diff --git a/CMakeLists.txt b/CMakeLists.txt index cb0c19d7ee..1d6b909fe7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,7 +29,8 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads") set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}") set(CMAKE_C_COMPILER clang) set(CMAKE_CXX_COMPILER clang++) - set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND}) + set(FLAGTREE_TLE OFF) + remove_definitions(-D__TLE__) elseif(FLAGTREE_BACKEND STREQUAL "aipu") set(CMAKE_C_COMPILER clang-16) set(CMAKE_CXX_COMPILER clang++-16) @@ -281,14 +282,19 @@ if(TRITON_BUILD_PYTHON_MODULE) include_directories(${PROJECT_BINARY_DIR}/third_party/${FLAGTREE_BACKEND}) add_subdirectory(third_party/hcu/proton/Dialect) add_subdirectory(third_party/nvidia) + elseif(FLAGTREE_BACKEND AND FLAGTREE_BACKEND STREQUAL "mthreads") + include_directories(${PROJECT_BINARY_DIR}/third_party/${FLAGTREE_BACKEND}) + add_subdirectory(third_party/mthreads/proton/Dialect) else() list(APPEND TRITON_PLUGIN_NAMES "proton") add_subdirectory(third_party/proton/Dialect) endif() # Add TLE plugin - list(APPEND TRITON_PLUGIN_NAMES "tle") - add_subdirectory(third_party/tle) + if(FLAGTREE_TLE) + list(APPEND TRITON_PLUGIN_NAMES "tle") + add_subdirectory(third_party/tle) + endif() if (DEFINED TRITON_PLUGIN_DIRS) foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS}) @@ -499,7 +505,9 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) endforeach() add_subdirectory(third_party/proton/Dialect) # flagtree tle - add_subdirectory(third_party/tle) + if(FLAGTREE_TLE) + add_subdirectory(third_party/tle) + endif() endif() find_package(Threads REQUIRED) diff --git a/python/setup_tools/utils/mthreads.py b/python/setup_tools/utils/mthreads.py new file mode 100644 index 0000000000..13ed22615b --- /dev/null +++ b/python/setup_tools/utils/mthreads.py @@ -0,0 +1,123 @@ +import sys +import shutil +import inspect +from pathlib import Path + +from setuptools import find_packages + +MTHREADS_PYTHON_ROOT = "third_party/mthreads/python" +FLAGTREE_PYTHON_ROOT = "python" +TLE_PACKAGE = "triton.experimental.tle" + + +def skip_package_dir(package): + return package == "triton" or package.startswith("triton.") + + +def get_package_dir(): + return { + "": MTHREADS_PYTHON_ROOT, + } + + +def _is_backend_package(package): + return package == "triton.backends" or package.startswith("triton.backends.") + + +def _is_language_extra_package(package): + return package == "triton.language.extra" or package.startswith("triton.language.extra.") + + +def _merge_mthreads_packages(existing_packages): + packages = [] + seen = set() + + def add(package): + if package not in seen: + packages.append(package) + seen.add(package) + + for package in find_packages(where=MTHREADS_PYTHON_ROOT, include=["triton", "triton.*"]): + add(package) + + for package in find_packages(where=FLAGTREE_PYTHON_ROOT, include=[TLE_PACKAGE, f"{TLE_PACKAGE}.*"]): + add(package) + + for package in existing_packages: + if (not package.startswith("triton.") or _is_backend_package(package) or _is_language_extra_package(package) + or package == "triton.profiler" or package.startswith("triton.profiler.")): + add(package) + + return packages + + +def _merge_mthreads_package_dir(existing_package_dir): + package_dir = dict(existing_package_dir or {}) + package_dir[""] = MTHREADS_PYTHON_ROOT + + for package in find_packages(where=MTHREADS_PYTHON_ROOT, include=["triton", "triton.*"]): + rel_package_path = package.replace(".", "/") + package_dir[package] = f"{MTHREADS_PYTHON_ROOT}/{rel_package_path}" + + for package in find_packages(where=FLAGTREE_PYTHON_ROOT, include=[TLE_PACKAGE, f"{TLE_PACKAGE}.*"]): + rel_package_path = package.replace(".", "/") + package_dir[package] = f"{FLAGTREE_PYTHON_ROOT}/{rel_package_path}" + + return package_dir + + +def _patch_mthreads_cmdclass(existing_cmdclass): + cmdclass = dict(existing_cmdclass or {}) + original_build_py = cmdclass.get("build_py") + if original_build_py is None: + return cmdclass + + class MthreadsBuildPy(original_build_py): + + def run(self): + self.force = True + build_triton_dir = Path(self.build_lib) / "triton" + if build_triton_dir.exists(): + shutil.rmtree(build_triton_dir) + return super().run() + + cmdclass["build_py"] = MthreadsBuildPy + return cmdclass + + +def _wrap_setup(original_setup): + if getattr(original_setup, "_mthreads_python_root_patched", False): + return original_setup + + def setup_with_mthreads_python_root(*args, **kwargs): + kwargs["packages"] = _merge_mthreads_packages(kwargs.get("packages", [])) + kwargs["package_dir"] = _merge_mthreads_package_dir(kwargs.get("package_dir", {})) + kwargs["cmdclass"] = _patch_mthreads_cmdclass(kwargs.get("cmdclass", {})) + return original_setup(*args, **kwargs) + + setup_with_mthreads_python_root._mthreads_python_root_patched = True + setup_with_mthreads_python_root._mthreads_original_setup = original_setup + return setup_with_mthreads_python_root + + +def _patch_setup_for_mthreads_python_root(): + patched = False + + frame = inspect.currentframe() + while frame is not None: + setup_func = frame.f_globals.get("setup") + if callable(setup_func): + frame.f_globals["setup"] = _wrap_setup(setup_func) + patched = True + frame = frame.f_back + + main_module = sys.modules.get("__main__") + if main_module is not None and hasattr(main_module, "setup"): + main_module.setup = _wrap_setup(main_module.setup) + patched = True + + if not patched: + raise RuntimeError("mthreads setup hook could not find setup() to patch") + + +_patch_setup_for_mthreads_python_root() diff --git a/setup.py b/setup.py index 4ea6801692..dee1fed33e 100644 --- a/setup.py +++ b/setup.py @@ -704,8 +704,6 @@ def get_packages(): if helper.flagtree_backend == "xpu": yield f"triton.language.extra.xpu" - elif helper.flagtree_backend == "mthreads": - yield f"triton/language/extra/musa" if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON yield "triton.profiler" diff --git a/third_party/mthreads/CMakeLists.txt b/third_party/mthreads/CMakeLists.txt new file mode 100644 index 0000000000..61a004cd8d --- /dev/null +++ b/third_party/mthreads/CMakeLists.txt @@ -0,0 +1,23 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/musa/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/musa/include) +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(musa) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonMthreads ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads.cc + LINK_LIBS TritonMUSAGPUToLLVM MTGPUToLLVM TritonMUSAGPUTransforms) + add_dependencies(TritonMthreads + MUSATableGen + MUSAAttrDefsIncGen + MTGPUTableGen + MTGPUTypesIncGen + MTGPUConversionPassIncGen + TritonMUSAGPUConversionPassIncGen + TritonMUSAGPUTransformsIncGen) + target_link_libraries(TritonMthreads PRIVATE Python3::Module pybind11::headers) +endif() +add_subdirectory(bin) diff --git a/third_party/mthreads/backend/__init__.py b/third_party/mthreads/backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/mthreads/backend/compiler.py b/third_party/mthreads/backend/compiler.py new file mode 100644 index 0000000000..5492e596fb --- /dev/null +++ b/third_party/mthreads/backend/compiler.py @@ -0,0 +1,933 @@ +from triton.backends.compiler import BaseBackend, GPUTarget, Language +from triton._C.libtriton import ir, passes, mthreads +from triton import knobs + +from dataclasses import dataclass +from pathlib import Path +import functools +from typing import Any, Dict, Tuple, Optional +import hashlib +import os +import re +import shutil +import shlex +import subprocess +import tempfile + + +def min_dot_size(target: GPUTarget): + + def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: + lhs_bitwidth = lhs_type.scalar.primitive_bitwidth + rhs_bitwidth = rhs_type.scalar.primitive_bitwidth + assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same" + return (1, 1, 1) + + return check_dot_compatibility + + +def _module_text(mod) -> str: + try: + return str(mod) + except Exception: + return "" + + +def _module_uses_sqmma(mod) -> bool: + text = _module_text(mod) + return "mtgpu.sqmma" in text + + +@functools.lru_cache() +def get_musa_version() -> str: + if env_ver := os.getenv("TRITON_MUSA_VERSION"): + return env_ver + try: + import torch_musa # type: ignore + return getattr(torch_musa, "__version__", "unknown") + except Exception: + return "unknown" + + +@functools.lru_cache(None) +def file_hash(path: str) -> str: + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +@functools.lru_cache(None) +def _tool_version_signature(path: str) -> str: + norm = _normalize_path(path) + if not norm: + return "" + tool_path = str(Path(norm).expanduser()) + version_text = "" + try: + out = subprocess.check_output([tool_path, "--version"], stderr=subprocess.STDOUT, text=True) + version_text = out.strip() + except Exception: + version_text = "" + binary_hash = "" + try: + if Path(tool_path).exists(): + binary_hash = file_hash(tool_path) + except Exception: + binary_hash = "" + return f"{tool_path}|{version_text}|{binary_hash}" + + +def _normalize_arch(arch: object) -> str: + if isinstance(arch, int): + return str(arch) + return str(arch).lower() + + +def _capability_from_arch(arch: object) -> int: + if isinstance(arch, int): + return arch + arch_str = _normalize_arch(arch) + if arch_str.isdigit(): + return int(arch_str) + if arch_str.startswith("ph1"): + return 31 + raise ValueError(f"Unsupported MUSA arch: {arch}") + + +def _normalize_path(path: Optional[str]) -> Optional[str]: + if not path: + return None + return str(Path(path).expanduser()) + + +def _maybe_tool_path(tool) -> Optional[str]: + try: + return _normalize_path(tool.path) + except Exception: + return None + + +def _select_tool_path(explicit_path: Optional[str], tool) -> Optional[str]: + path = _normalize_path(explicit_path) + if path: + return path + return _maybe_tool_path(tool) + + +def _resolve_toolchain_paths(options: "MUSAOptions") -> Tuple[str, str, Optional[str]]: + toolchain_path = _normalize_path(options.toolchain_path) + llc_path = _normalize_path(options.llc_path) + lld_path = _normalize_path(options.lld_path) + llc_asm_path = _normalize_path(options.llc_asm_path) + + if not toolchain_path: + mtcc_bin_path = os.getenv("MTCC_BIN_PATH") + if mtcc_bin_path: + toolchain_path = str(Path(mtcc_bin_path).expanduser()) + if not toolchain_path: + musa_home = os.getenv("MUSA_HOME") + if musa_home: + toolchain_path = str(Path(musa_home).expanduser() / "bin") + + if not llc_path and toolchain_path: + llc_path = str(Path(toolchain_path) / "llc") + if not lld_path and toolchain_path: + lld_path = str(Path(toolchain_path) / "ld.lld") + + return llc_path or "", lld_path or "", llc_asm_path + + +@functools.lru_cache(None) +def _detect_llvm_major_version(llc_path: str) -> Optional[int]: + llc = str(Path(llc_path).expanduser()) if llc_path else "" + if not llc: + return None + try: + out = subprocess.check_output([llc, "--version"], stderr=subprocess.STDOUT, text=True) + except Exception: + return None + match = re.search(r"LLVM version\s+(\d+)\.", out) + if not match: + return None + try: + return int(match.group(1)) + except Exception: + return None + + +def _tool_output(stdout: Optional[str], stderr: Optional[str]) -> str: + chunks = [] + if stdout and stdout.strip(): + chunks.append(stdout.strip()) + if stderr and stderr.strip(): + chunks.append(stderr.strip()) + return "\n".join(chunks) + + +def _run_tool_command(tool_name: str, cmd: list[str], *, repro_dir: Path, dump_log: bool = False) -> None: + proc = subprocess.run(cmd, check=False, text=True, capture_output=True) + output = _tool_output(proc.stdout, proc.stderr) + if dump_log and output: + print(f"// -----// MUSA {tool_name} Log //----- //") + print(output) + if proc.returncode == 0: + return + error = (f"`{tool_name}` failed with error code {proc.returncode}\n" + f"`{tool_name}` output:\n{output or ''}\n" + f"Repro command: {shlex.join(cmd)}\n" + f"Artifacts kept in: {repro_dir}") + raise RuntimeError(error) + + +def _should_apply_llvm_compat(llc_major: Optional[int]) -> bool: + return llc_major is None or llc_major < 19 + + +def _llc_opaque_pointer_options(llc_major: Optional[int]) -> list[str]: + return ["--opaque-pointers"] if llc_major is not None and llc_major < 15 else [] + + +def _strip_range_attributes(ir_text: str) -> str: + out = ir_text + pos = 0 + call_ret_re = re.compile(r"[^,\n@][^,\n@]*\s+@[A-Za-z_$.][A-Za-z0-9_$.]*\s*\(") + while True: + start = out.find("range(", pos) + if start < 0: + break + cur = start + len("range(") + depth = 1 + while cur < len(out) and depth > 0: + ch = out[cur] + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + cur += 1 + if depth != 0: + pos = start + 1 + continue + end = cur + while end < len(out) and out[end].isspace(): + end += 1 + tail = out[end:] + if end < len(out) and (out[end] == "%" or call_ret_re.match(tail)): + out = out[:start] + out[end:] + pos = start + else: + pos = end + return out + + +def _rewrite_bare_splat_operands(ir_text: str) -> str: + vec_re = re.compile(r"<\s*(\d+)\s+x\s*([A-Za-z0-9_.]+)\s*>") + bare_splat_re = re.compile(r"splat\s*\(\s*([A-Za-z0-9_.]+)\s+([^)]+)\s*\)") + + out_lines = [] + for line in ir_text.splitlines(): + search_pos = 0 + while search_pos < len(line): + match = bare_splat_re.search(line, search_pos) + if match is None: + break + elem_ty = match.group(1) + elem_val = match.group(2).strip() + prefix = line[:match.start()] + + lane_count = -1 + for vec_match in vec_re.finditer(prefix): + if vec_match.group(2) == elem_ty: + lane_count = int(vec_match.group(1)) + + if lane_count <= 0: + search_pos = match.end() + continue + + lane_str = str(lane_count) + vec_ty = f"<{lane_str} x {elem_ty}>" + mask_ty = f"<{lane_str} x i32>" + insert_expr = (f"insertelement ({vec_ty} undef, {elem_ty} {elem_val}, i32 0)") + replacement = (f"shufflevector ({vec_ty} {insert_expr}, {vec_ty} undef, " + f"{mask_ty} zeroinitializer)") + + line = line[:match.start()] + replacement + line[match.end():] + search_pos = match.start() + len(replacement) + + out_lines.append(line) + return "\n".join(out_lines) + ("\n" if ir_text.endswith("\n") else "") + + +def _rewrite_musa_isspacep_shared(ir_text: str) -> str: + call_re = re.compile( + r"^([ \t]*)(%[A-Za-z0-9_.]+|%\d+)\s*=\s*(?:tail\s+)?call\s+i1\s+" + r"@llvm\.musa\.isspacep\.shared\s*\(\s*ptr(?:\s+[^()%]+)*\s+(%[A-Za-z0-9_.]+|%\d+)\s*\)\s*(,.*)?$") + + def _tmp_name(base_pred: str, kind: str) -> str: + if re.fullmatch(r"%\d+", base_pred): + return f"%musa_isspacep_{kind}_{base_pred[1:]}" + return f"{base_pred}.isspacep.{kind}" + + out_lines = [] + for line in ir_text.splitlines(): + m = call_re.match(line) + if m is None: + out_lines.append(line) + continue + + indent, pred_name, ptr_name, dbg_suffix = m.groups() + dbg_suffix = dbg_suffix or "" + ptr_i64 = _tmp_name(pred_name, "i64") + ptr_hi32 = _tmp_name(pred_name, "hi32") + out_lines.append(f"{indent}{ptr_i64} = ptrtoint ptr {ptr_name} to i64{dbg_suffix}") + out_lines.append(f"{indent}{ptr_hi32} = lshr i64 {ptr_i64}, 32{dbg_suffix}") + out_lines.append(f"{indent}{pred_name} = icmp eq i64 {ptr_hi32}, 0{dbg_suffix}") + + out = "\n".join(out_lines) + if ir_text.endswith("\n"): + out += "\n" + + out = re.sub( + r"(?m)^[ \t]*declare\s+i1\s+@llvm\.musa\.isspacep\.shared\s*\(\s*ptr\s*\)\s*(?:#\d+)?\s*\n?", + "", + out, + ) + return out + + +def _rewrite_musa_ptr_gen_to_addrspace(ir_text: str) -> str: + specs = [("global", 1), ("shared", 3)] + ptr_as_map: Dict[str, int] = {} + out_lines = [] + + for line in ir_text.splitlines(): + rewritten = False + for space_name, as_id in specs: + call_re = re.compile( + rf"^([ \t]*)(%[A-Za-z0-9_.]+|%\d+)\s*=\s*(?:tail\s+)?call\s+ptr\s+" + rf"@llvm\.musa\.ptr\.gen\.to\.{space_name}\s*\(\s*ptr(?:\s+[^()%]+)*\s+(%[A-Za-z0-9_.]+|%\d+)\s*\)\s*(,.*)?$" + ) + m = call_re.match(line) + if m is None: + continue + indent, out_ptr, in_ptr, dbg_suffix = m.groups() + dbg_suffix = dbg_suffix or "" + out_lines.append(f"{indent}{out_ptr} = addrspacecast ptr {in_ptr} to ptr addrspace({as_id}){dbg_suffix}") + ptr_as_map[out_ptr] = as_id + rewritten = True + break + if not rewritten: + out_lines.append(line) + + out = "\n".join(out_lines) + if ir_text.endswith("\n"): + out += "\n" + for space_name, _ in specs: + out = re.sub( + rf"(?m)^[ \t]*declare\s+ptr\s+@llvm\.musa\.ptr\.gen\.to\.{space_name}\s*\(\s*ptr\s*\)\s*(?:#\d+)?\s*\n?", + "", + out, + ) + + for ptr_name, as_id in ptr_as_map.items(): + out = re.sub( + rf"\bcmpxchg\s+ptr\s+{re.escape(ptr_name)}\b", + f"cmpxchg ptr addrspace({as_id}) {ptr_name}", + out, + ) + return out + + +def _rewrite_llvm_is_fpclass_f32(ir_text: str) -> str: + call_re = re.compile(r"^([ \t]*)(%[A-Za-z0-9_.]+|%\d+)\s*=\s*(?:tail\s+)?call\s+i1\s+" + r"@llvm\.is\.fpclass\.f32\s*\(\s*float\s+([^,]+)\s*,\s*i32\s+64\s*\)\s*(,.*)?$") + out_lines = [] + changed = False + for line in ir_text.splitlines(): + m = call_re.match(line) + if m is None: + out_lines.append(line) + continue + indent, pred_name, val, dbg_suffix = m.groups() + dbg_suffix = dbg_suffix or "" + out_lines.append(f"{indent}{pred_name} = fcmp oeq float {val.strip()}, 0.000000e+00{dbg_suffix}") + changed = True + + out = "\n".join(out_lines) + if ir_text.endswith("\n"): + out += "\n" + if not changed: + return out + + out = re.sub( + r"(?m)^[ \t]*declare\s+i1\s+@llvm\.is\.fpclass\.f32\s*" + r"\(\s*float\s*,\s*i32\s+immarg\s*\)\s*(?:#\d+)?\s*\n?", + "", + out, + ) + return out + + +def _rewrite_lifetime_intrinsics_for_llvm14(ir_text: str) -> str: + out = ir_text + + out = re.sub( + r"(?m)^([ \t]*(?:tail\s+|musttail\s+|notail\s+)?call\s+void\s+@llvm\.lifetime\.(start|end)\.p0)" + r"\(\s*ptr(\s+[^()%,]+(?:\s+[^()%,]+)*)?\s+([^,)]+)\s*\)", + r"\1(i64 -1, ptr\3 \4)", + out, + ) + + out = re.sub( + r"(?m)^([ \t]*declare\s+void\s+@llvm\.lifetime\.(start|end)\.p0)" + r"\(\s*ptr(\s+[^()%,]+(?:\s+[^()%,]+)*)?\s*\)", + r"\1(i64 immarg, ptr\3)", + out, + ) + + return out + + +_SCMP_UCMP_CALL_RE = re.compile(r"^(\s*)(%\w+)\s*=\s*(?:tail\s+|musttail\s+|notail\s+)?call\s+(?Pi\d+)\s+" + r"@llvm\.(?Pscmp|ucmp)\.(?Pi\d+)\.(?P=opty)\s*" + r"\(\s*(?P=opty)\s+(?P[^,]+)\s*,\s*(?P=opty)\s+(?P[^)]+)\)\s*(?P.*)$") + + +def _rewrite_llvm_scmp_ucmp_to_icmp(ir_text: str) -> str: + pred = {"scmp": ("slt", "sgt"), "ucmp": ("ult", "ugt")} + out_lines: list[str] = [] + counter = 0 + for line in ir_text.splitlines(): + m = _SCMP_UCMP_CALL_RE.match(line) + if not m: + out_lines.append(line) + continue + counter += 1 + indent = m.group(1) + result = m.group(2) + ret_ty = m.group("ret") + kind = m.group("kind") + opty = m.group("opty") + a = m.group("a").strip() + b = m.group("b").strip() + tail = m.group("tail").rstrip() + p_lo, p_hi = pred[kind] + lt = f"%.musa_scmp_lt_{counter}" + gt = f"%.musa_scmp_gt_{counter}" + mid = f"%.musa_scmp_mid_{counter}" + out_lines.append(f"{indent}{lt} = icmp {p_lo} {opty} {a}, {b}") + out_lines.append(f"{indent}{gt} = icmp {p_hi} {opty} {a}, {b}") + out_lines.append(f"{indent}{mid} = select i1 {gt}, {ret_ty} 1, {ret_ty} 0") + last = f"{indent}{result} = select i1 {lt}, {ret_ty} -1, {ret_ty} {mid}" + if tail: + last = f"{last} {tail}" + out_lines.append(last) + + out = "\n".join(out_lines) + if ir_text.endswith("\n"): + out += "\n" + out = re.sub( + r"(?m)^[ \t]*declare\s+i\d+\s+@llvm\.(?:scmp|ucmp)\.i\d+\.i\d+\s*" + r"\(\s*i\d+\s*,\s*i\d+\s*\)[^\n]*\n", + "", + out, + ) + return out + + +def _llvm_compat(ir_text: str) -> str: + replacements = [ + ("memory\\(none\\)", "readnone"), + ("memory\\(read\\)", "readonly"), + ("memory\\(write\\)", "writeonly"), + ("memory\\(argmem: readwrite\\)", "argmemonly"), + ("memory\\(argmem: read\\)", "argmemonly readonly"), + ("memory\\(argmem: write\\)", "argmemonly writeonly"), + ("memory\\(inaccessiblemem: readwrite\\)", "inaccessiblememonly"), + ("memory\\(inaccessiblemem: read\\)", "inaccessiblememonly readonly"), + ("memory\\(inaccessiblemem: write\\)", "inaccessiblememonly writeonly"), + ("memory\\(argmem: readwrite, inaccessiblemem: readwrite\\)", "inaccessiblemem_or_argmemonly"), + ("memory\\(argmem: read, inaccessiblemem: read\\)", "inaccessiblemem_or_argmemonly readonly"), + ("memory\\(argmem: write, inaccessiblemem: write\\)", "inaccessiblemem_or_argmemonly writeonly"), + ] + out = ir_text + for new, old in replacements: + out = re.sub(new, old, out) + + out = re.sub(r"\bicmp\s+samesign\b", "icmp", out) + + splat_re = re.compile(r"<(\d+)\s+x\s+([^>]+)>\s+splat\s*\(\s*\2\s+([^)]+)\)") + + def _expand_splat(match: re.Match) -> str: + count = int(match.group(1)) + ty = match.group(2) + val = match.group(3) + elems = ", ".join([f"{ty} {val}"] * count) + return f"<{count} x {ty}> <{elems}>" + + out = splat_re.sub(_expand_splat, out) + out = _rewrite_bare_splat_operands(out) + out = _strip_range_attributes(out) + out = re.sub(r"\s+captures\(\s*none\s*\)", " nocapture", out) + out = re.sub(r"\s+captures\([^)]*\)", "", out) + out = re.sub(r"\bor\s+disjoint\s+", "or ", out) + out = re.sub(r"\bzext\s+nneg\s+", "zext ", out) + out = re.sub(r"\bsext\s+nneg\s+", "sext ", out) + out = re.sub(r"\buitofp\s+nneg\s+", "uitofp ", out) + out = re.sub(r"\bsitofp\s+nneg\s+", "sitofp ", out) + out = re.sub(r"\btrunc\s+nuw\s+nsw\s+", "trunc ", out) + out = re.sub(r"\btrunc\s+nsw\s+nuw\s+", "trunc ", out) + out = re.sub(r"\btrunc\s+nuw\s+", "trunc ", out) + out = re.sub(r"\btrunc\s+nsw\s+", "trunc ", out) + out = re.sub(r"\bgetelementptr\s+inbounds\s+nusw\s+", "getelementptr inbounds ", out) + out = re.sub(r"\bgetelementptr\s+inbounds\s+nuw\s+", "getelementptr inbounds ", out) + out = re.sub(r"\bgetelementptr\s+inbounds\s+nsw\s+", "getelementptr inbounds ", out) + out = re.sub(r"\bgetelementptr\s+nusw\s+", "getelementptr ", out) + out = re.sub(r"\bgetelementptr\s+nuw\s+", "getelementptr ", out) + out = re.sub(r"\bgetelementptr\s+nsw\s+", "getelementptr ", out) + out = _rewrite_musa_isspacep_shared(out) + out = _rewrite_musa_ptr_gen_to_addrspace(out) + out = _rewrite_llvm_is_fpclass_f32(out) + out = _rewrite_lifetime_intrinsics_for_llvm14(out) + for attr in ("nocallback", "nocreateundeforpoison", "mustprogress", "speculatable", "willreturn"): + out = re.sub(rf"(? str: + for line in ir_text.splitlines(): + if "nvvm.annotations" in line and "\"kernel\"" in line and "@" in line: + m = re.search(r"@([A-Za-z_][A-Za-z0-9_\\.]+)", line) + if m: + return m.group(1) + + matches = re.findall(r"^define\s+[^@]*@([A-Za-z_][A-Za-z0-9_\.]*)", ir_text, flags=re.MULTILINE) + if matches: + return matches[0] + raise RuntimeError("Unable to determine kernel name from LLVM IR") + + +def _llc_extra_options(metadata: Dict[str, object], options: "MUSAOptions") -> list[str]: + uses_mulhi = bool(metadata.get("uses_mulhi_helper")) + const_calc_opt = [] if uses_mulhi else ["-mtgpu-enable-const-calc=1"] + + uses_sqmma = bool(metadata.get("uses_sqmma")) + enable_backend_opt = bool(options.enable_llc_opt or options.enable_backend_opt) + llc_options_map = { + (False, False): [*const_calc_opt], + (True, False): { + *const_calc_opt, + "-mtgpu-alloc-shared-memory-from-zero=1", + }, + (False, True): [ + "-mtgpu-enable-const-calc=1", + "-mtgpu-tiny-offset-hint=1", + "-mtgpu-combine-instr-with-burst=1", + "-mtgpu-combine-fop-instr=1", + ], + (True, True): [ + "-mtgpu-opt-level=1", + "-mtgpu-combine-instr-with-burst=1", + "-mtgpu-combine-fop-instr=1", + "-misched=mtgpu-max-ilp", + ], + } + opts = llc_options_map[(uses_sqmma, enable_backend_opt)] + if options.llc_options: + opts.extend(shlex.split(options.llc_options)) + return opts + + +@dataclass(frozen=True) +class MUSAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + warp_size: int = 32 + maxnreg: Optional[int] = None + enable_fp_fusion: bool = True + launch_cooperative_grid: bool = False + supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", ) + supported_fp8_storage_dtypes: Tuple[str, ...] = ("fp8e5", ) + custom_fp8_dtypes: Tuple[str, ...] = () + deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = () + default_dot_input_precision: str = "ieee" + allowed_dot_input_precisions: Tuple[str, ...] = ("ieee", "tf32", "tf32x3", "bf16x3", "bf16x6") + max_num_imprecise_acc_default: int = 0 + sanitize_overflow: bool = True + toolchain_path: Optional[str] = None + llc_path: Optional[str] = None + lld_path: Optional[str] = None + llc_asm_path: Optional[str] = None + llc_options: Optional[str] = None + enable_llc_opt: bool = False + enable_backend_opt: bool = False + enable_fp8_burst2: bool = False + enable_llvm_compat: bool = True + extern_libs: Optional[tuple] = None + debug: bool = False + backend_name: str = "musa" + supports_noinline: bool = True + arch: Optional[str] = None + instrumentation_mode: str = "" + + def __post_init__(self): + default_libdir = Path(__file__).parent / "lib" + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get("libdevice", None): + extern_libs["libdevice"] = knobs.musa.libdevice_path or str(default_libdir / "libdevice.31.bc") + object.__setattr__(self, "extern_libs", tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + hash_dict = dict(self.__dict__) + llc_path, lld_path, llc_asm_path = _resolve_toolchain_paths(self) + hash_dict["effective_llc_path"] = llc_path + hash_dict["effective_lld_path"] = lld_path + hash_dict["effective_llc_asm_path"] = llc_asm_path or "" + hash_dict["effective_llc_major"] = _detect_llvm_major_version(llc_path) + hash_dict["llc_tool_signature"] = _tool_version_signature(llc_path) + hash_dict["lld_tool_signature"] = _tool_version_signature(lld_path) + hash_dict["llc_asm_tool_signature"] = _tool_version_signature(llc_asm_path or "") + if hash_dict["extern_libs"]: + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class MUSABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == "musa" + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.binary_ext = "mubin" + + def parse_options(self, opts) -> Any: + opts = dict(opts) + arch = knobs.runtime.override_arch or opts.get("arch", None) or self.target.arch + args = {"arch": _normalize_arch(arch)} + capability = _capability_from_arch(args["arch"]) + if opts.get("num_ctas", 1) > 1 and capability == 31: + raise ValueError("num_ctas > 1 requires MUSA cluster launch support. " + f"Current target is {args['arch']} (capability {capability}).") + if "enable_fp_fusion" not in opts: + args["enable_fp_fusion"] = knobs.language.default_fp_fusion + if "supported_fp8_dtypes" not in opts: + supported_fp8_dtypes = {"fp8e5"} + if capability >= 31: + supported_fp8_dtypes.add("fp8e4nv") + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + if "supported_fp8_storage_dtypes" not in opts: + supported_fp8_storage_dtypes = set(args.get("supported_fp8_dtypes", ())) + if capability >= 31: + supported_fp8_storage_dtypes.update({"fp8e4b15", "fp8e4b8", "fp8e5b16"}) + args["supported_fp8_storage_dtypes"] = tuple(sorted(supported_fp8_storage_dtypes)) + if "custom_fp8_dtypes" not in opts: + custom_fp8_dtypes = set() + if capability >= 31: + custom_fp8_dtypes.update({"fp8e4b15", "fp8e4b8", "fp8e5b16"}) + args["custom_fp8_dtypes"] = tuple(sorted(custom_fp8_dtypes)) + if "deprecated_fp8_dot_operand_dtypes" not in opts: + args["deprecated_fp8_dot_operand_dtypes"] = () + if "toolchain_path" not in opts: + toolchain_path = knobs.musa.toolchain_path + if not toolchain_path: + mtcc_bin_path = os.getenv("MTCC_BIN_PATH") + if mtcc_bin_path: + toolchain_path = mtcc_bin_path + else: + musa_home = os.getenv("MUSA_HOME") + toolchain_path = str(Path(musa_home) / "bin") if musa_home else None + args["toolchain_path"] = _normalize_path(toolchain_path) + if "llc_path" not in opts: + args["llc_path"] = _select_tool_path(knobs.musa.llc_path, knobs.musa.llc) + if "lld_path" not in opts: + args["lld_path"] = _select_tool_path(knobs.musa.lld_path, knobs.musa.lld) + if "llc_asm_path" not in opts: + args["llc_asm_path"] = _normalize_path(knobs.musa.llc_asm_path) + if "llc_options" not in opts: + args["llc_options"] = knobs.musa.llc_options + if "enable_llc_opt" not in opts: + args["enable_llc_opt"] = knobs.musa.enable_llc_opt + if "enable_fp8_burst2" not in opts: + args["enable_fp8_burst2"] = knobs.musa.enable_fp8_burst2 + if "enable_llvm_compat" not in opts: + args["enable_llvm_compat"] = knobs.musa.enable_llvm_compat + args.update({k: opts[k] for k in MUSAOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None}) + if "warp_size" not in args: + target_warp_size = getattr(self.target, "warp_size", None) + args["warp_size"] = int(target_warp_size) if target_warp_size else 32 + return MUSAOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + ) + + def get_codegen_implementation(self, options): + from triton.language.extra.musa import utils as musa_utils + + return { + "convert_custom_types": musa_utils.convert_custom_float8, + "min_dot_size": min_dot_size(self.target), + } + + def get_module_map(self) -> Dict[str, object]: + try: + from triton.language.extra.musa import libdevice as musa_libdevice # type: ignore + libdevice = musa_libdevice + except Exception: + from triton.language.extra import libdevice + return {"triton.language.extra.libdevice": libdevice} + + def load_dialects(self, ctx): + mthreads.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) + pm.run(mod, "make_ttir") + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, arch, capability): + if opt.maxnreg is not None: + mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg)) + + pm = ir.pass_manager(mod.context) + dump_enabled = pm.enable_debug() + emu_tf32 = capability >= 31 + + passes.ttir.add_convert_to_ttgpuir(pm, f"musa:{arch}", opt.num_warps, opt.warp_size, opt.num_ctas) + passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_f32_dot_tc(pm, emu_tf32) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + + mthreads.passes.ttgpuir.add_accelerate_matmul(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + mthreads.passes.ttgpuir.add_optimize_dot_operands(pm) + mthreads.passes.ttgpuir.add_optimize_descriptor_encoding(pm) + passes.ttir.add_loop_aware_cse(pm) + + if capability >= 31: + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_triton_licm(pm) + passes.common.add_canonicalizer(pm) + mthreads.passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + mthreads.passes.ttgpuir.add_optimize_sqmma_accumulator_layout(pm) + passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) + passes.ttgpuir.add_schedule_loops(pm) + mthreads.passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + else: + passes.ttir.add_triton_licm(pm) + + passes.common.add_canonicalizer(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.ttgpuir.add_prefetch(pm) + mthreads.passes.ttgpuir.add_optimize_dot_operands(pm) + passes.ttgpuir.add_coalesce_async_copy(pm) + mthreads.passes.ttgpuir.add_tme_lowering(pm) + mthreads.passes.ttgpuir.add_optimize_sqmma_accumulator_layout(pm) + mthreads.passes.ttgpuir.add_canonicalize_sqmma_result_conversions(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + mthreads.passes.ttgpuir.add_issue_barrier_insertion(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + mthreads.passes.ttgpuir.add_convert_sqmma_to_mtgpu(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_sccp(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) + if capability == 31: + mthreads.passes.ttgpuir.add_mark_inplace_loads(pm) + mthreads.passes.ttgpuir.add_finalize_barriers(pm) + pm.run(mod, "make_ttgir") + metadata["uses_sqmma"] = _module_uses_sqmma(mod) + metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() + return mod + + @staticmethod + def make_llir(src, metadata, options, arch): + from triton._C.libtriton import llvm + + mod = src + pm = ir.pass_manager(mod.context) + pm.enable_debug() + + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + mthreads.passes.ttgpuir.add_allocate_shared_memory(pm, _capability_from_arch(arch)) + mthreads.passes.ttgpuir.add_mtgpu_to_llvm(pm, _capability_from_arch(arch)) + mthreads.passes.ttgpuir.add_to_llvmir(pm, _capability_from_arch(arch)) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.convert.add_cf_to_llvmir(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + + if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables: + passes.llvmir.add_di_scope(pm) + + pm.run(mod, "make_llir") + + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + mthreads.attach_datalayout(llvm_mod) + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) + + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + maxntidx = max(1, int(options.num_warps) * int(options.warp_size)) + kernel_name_hint = src.get_entry_func_name() if hasattr(src, "get_entry_func_name") else "" + mthreads.decorate_kernel_abi(llvm_mod, kernel_name_hint, maxntidx) + metadata["uses_mulhi_helper"] = mthreads.module_uses_mulhi_helper(llvm_mod) + + metadata["shared"] = src.get_int_attr("ttg.shared") + + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + @staticmethod + def make_mubin(src, metadata, opt, arch): + if not isinstance(src, str): + raise TypeError("Expected LLVM IR as a string for MUSA codegen") + + llc_path, lld_path, llc_asm_path = _resolve_toolchain_paths(opt) + if not llc_path or not lld_path: + raise RuntimeError("MUSA toolchain not configured. Set TRITON_MUSA_TOOLCHAIN_PATH " + "or TRITON_MUSA_LLC_PATH/TRITON_MUSA_LLD_PATH (or MUSA_HOME).") + + ir_text = src + llc_major = _detect_llvm_major_version(llc_path) + if opt.enable_llvm_compat: + if _should_apply_llvm_compat(llc_major): + ir_text = _llvm_compat(ir_text) + ir_text = _rewrite_llvm_scmp_ucmp_to_icmp(ir_text) + + if knobs.musa.dump_llir: + print("// -----// MUSA LLVMIR Dump //----- //") + print(ir_text) + + capability = _capability_from_arch(arch) + llc_opt_level = "-O2" + llc_opts = [ + "-march=mtgpu", + f"-mcpu=mp_{capability}", + *_llc_opaque_pointer_options(llc_major), + llc_opt_level, + "-filetype=obj", + ] + llc_opts.extend(_llc_extra_options(metadata, opt)) + + tmp_dir = tempfile.mkdtemp(prefix="triton-musa-") + tmp_path = Path(tmp_dir) + keep_artifacts = True + try: + tmp_path = Path(tmp_dir) + ll_file = tmp_path / "kernel.ll" + obj_file = tmp_path / "kernel.o" + mubin_file = tmp_path / "kernel.mubin" + + ll_file.write_text(ir_text) + + replace_llir = knobs.musa.replace_llir + if replace_llir and Path(replace_llir).exists(): + ll_file = Path(replace_llir) + kernel_name = _extract_kernel_name(ll_file.read_text()) + + if llc_asm_path: + llc_asm_major = _detect_llvm_major_version(llc_asm_path) + asm_file = tmp_path / "kernel.s" + asm_cmd = [ + llc_asm_path, + str(ll_file), + "-march=mtgpu", + f"-mcpu=mp_{capability}", + *_llc_opaque_pointer_options(llc_asm_major), + llc_opt_level, + "-filetype=asm", + "-o", + str(asm_file), + ] + asm_cmd.extend(_llc_extra_options(metadata, opt)) + _run_tool_command( + "llc-asm", + asm_cmd, + repro_dir=tmp_path, + dump_log=knobs.musa.dump_toolchain_log, + ) + if knobs.musa.dump_muasm: + print("// -----// MUASM Dump //----- //") + print(asm_file.read_text()) + + llc_cmd = [llc_path, str(ll_file), *llc_opts, "-o", str(obj_file)] + _run_tool_command( + "llc", + llc_cmd, + repro_dir=tmp_path, + dump_log=knobs.musa.dump_toolchain_log, + ) + + lld_cmd = [lld_path, "-flavor", "gnu", "-shared", str(obj_file), "-o", str(mubin_file)] + _run_tool_command( + "ld.lld", + lld_cmd, + repro_dir=tmp_path, + dump_log=knobs.musa.dump_toolchain_log, + ) + + replace_mubin = knobs.musa.replace_mubin + if replace_mubin and Path(replace_mubin).exists(): + mubin_file = Path(replace_mubin) + + metadata["name"] = kernel_name + result = mubin_file.read_bytes() + keep_artifacts = False + return result + finally: + if not keep_artifacts: + shutil.rmtree(tmp_dir, ignore_errors=True) + + def add_stages(self, stages, options, language): + arch = options.arch + capability = _capability_from_arch(arch) + if language == Language.TRITON: + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, arch, capability) + elif language == Language.GLUON: + raise RuntimeError("MUSA backend does not support GLUON yet") + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, arch) + stages["mubin"] = lambda src, metadata: self.make_mubin(src, metadata, options, arch) + if knobs.runtime.add_stages_inspection_hook is not None: + knobs.runtime.add_stages_inspection_hook(self, stages, options, language, arch) + + @functools.lru_cache() + def hash(self): + version = get_musa_version() + return f"{version}-{self.target.arch}" diff --git a/third_party/mthreads/backend/driver.c b/third_party/mthreads/backend/driver.c new file mode 100644 index 0000000000..6b9bcc6c4d --- /dev/null +++ b/third_party/mthreads/backend/driver.c @@ -0,0 +1,349 @@ +#include "musa.h" +#include +#include +#include +#include +#define PY_SSIZE_T_CLEAN +#include + +// Raises a Python exception and returns false if code is not MUSA_SUCCESS. +static bool gpuAssert(MUresult code, const char *file, int line) { + if (code == MUSA_SUCCESS) + return true; + + const char *prefix = "Triton Error [MUSA]: "; + const char *str; + muGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MUSA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + MUdevice device; + muDeviceGet(&device, device_id); + + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &max_shared_mem, MU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &max_num_regs, MU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &multiprocessor_count, MU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + MUSA_CHECK_AND_RETURN_NULL( + muDeviceGetAttribute(&warp_size, MU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &sm_clock_rate, MU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &mem_clock_rate, MU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &mem_bus_width, MU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + if (data_size == 0) { + PyErr_SetString(PyExc_RuntimeError, + "Empty MUSA binary: codegen is not available yet."); + return NULL; + } + MUfunction fun; + MUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + int32_t n_max_threads = 0; + MUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muCtxGetCurrent(&pctx)); + if (!pctx) { + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muDevicePrimaryCtxRetain(&pctx, device)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muCtxSetCurrent(pctx)); + } + + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muModuleLoadData(&mod, data)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muModuleGetFunction(&fun, mod, name)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncGetAttribute(&n_regs, MU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncGetAttribute(&n_spills, MU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muFuncGetAttribute( + &n_max_threads, MU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); + + int shared_optin = 0; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muDeviceGetAttribute( + &shared_optin, MU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + + int shared_static = 0; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muFuncGetAttribute( + &shared_static, MU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + int max_dynamic_shared = shared_optin - shared_static; + if (max_dynamic_shared < 0) + max_dynamic_shared = 0; + int requested_dynamic_shared = shared; + if (requested_dynamic_shared > max_dynamic_shared) + requested_dynamic_shared = max_dynamic_shared; + if (requested_dynamic_shared > 0) { + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncSetAttribute(fun, MU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + requested_dynamic_shared)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills, n_max_threads); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + + MUcontext ctx = NULL; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muCtxGetCurrent(&ctx)); + if (!ctx) { + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muDevicePrimaryCtxRetain(&ctx, /*device=*/0)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muCtxSetCurrent(ctx)); + } + + size_t oldSize = 0; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muCtxGetLimit(&oldSize, MU_LIMIT_PRINTF_FIFO_SIZE)); + if (oldSize != (size_t)size) { + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muCtxSetLimit(MU_LIMIT_PRINTF_FIFO_SIZE, size)); + } + + Py_END_ALLOW_THREADS; + Py_INCREF(Py_None); + return Py_None; +} + +static bool getTensorDescriptorDataType(int elementSize, + MUtensorDescriptorDataType *type) { + switch (elementSize) { + case 1: + *type = MU_TENSOR_DESCRIPTOR_DATA_TYPE_UINT8; + return true; + case 2: + *type = MU_TENSOR_DESCRIPTOR_DATA_TYPE_UINT16; + return true; + case 4: + *type = MU_TENSOR_DESCRIPTOR_DATA_TYPE_UINT32; + return true; + default: + PyErr_SetString(PyExc_ValueError, "element_size must be 1, 2, or 4 bytes"); + return false; + } +} + +static bool validateTMEDescriptorBlockBytes(unsigned rank, + const uint32_t *block_dims, + int element_size) { + uint64_t block_bytes = (uint64_t)element_size; + for (unsigned i = 0; i < rank; ++i) + block_bytes *= (uint64_t)block_dims[i]; + if (block_bytes >= 32) + return true; + + char err[64] = {0}; + snprintf(err, sizeof(err), "%uD block bytes must be >= 32", rank); + PyErr_SetString(PyExc_ValueError, err); + return false; +} + +static PyObject * +fillTMEDescriptorImpl(unsigned rank, unsigned long long global_address, + const uint64_t *dims, const uint32_t *block_dims, + int element_size, unsigned long long desc_address) { + MUtensorDescriptorDataType type; + if (!getTensorDescriptorDataType(element_size, &type)) + return NULL; + if (!validateTMEDescriptorBlockBytes(rank, block_dims, element_size)) + return NULL; + + uint64_t global_strides[5] = {0}; + global_strides[0] = dims[0] * (uint64_t)element_size; + for (unsigned i = 1; i < rank; ++i) + global_strides[i] = global_strides[i - 1] * dims[i]; + + MUtensorDescriptor desc; + MUSA_CHECK_AND_RETURN_NULL(muTensorDescriptorEncode( + &desc, type, /*tensorRank=*/rank, (void *)global_address, dims, + global_strides, MU_TENSOR_DESCRIPTOR_INTERLEAVE_NONE, /*swizzle=*/0)); + MUSA_CHECK_AND_RETURN_NULL( + muMemcpyHtoD((MUdeviceptr)desc_address, &desc, sizeof(desc))); + Py_INCREF(Py_None); + return Py_None; +} + +static PyObject *fill1DTMEDescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address = 0; + uint64_t dims[1]; + uint32_t block_dims[1]; + int element_size = 0; + unsigned long long desc_address = 0; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dims[0], + &block_dims[0], &element_size, &desc_address)) + return NULL; + + return fillTMEDescriptorImpl(/*rank=*/1, global_address, dims, block_dims, + element_size, desc_address); +} + +static PyObject *fill2DTMEDescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address = 0; + uint64_t dims[2]; + uint32_t block_dims[2]; + int element_size = 0; + unsigned long long desc_address = 0; + if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], + &block_dims[1], &block_dims[0], &element_size, + &desc_address)) + return NULL; + + return fillTMEDescriptorImpl(/*rank=*/2, global_address, dims, block_dims, + element_size, desc_address); +} + +static PyObject *fill3DTMEDescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address = 0; + uint64_t dims[3]; + uint32_t block_dims[3]; + int element_size = 0; + unsigned long long desc_address = 0; + if (!PyArg_ParseTuple(args, "KKKKiiiiK", &global_address, &dims[2], &dims[1], + &dims[0], &block_dims[2], &block_dims[1], + &block_dims[0], &element_size, &desc_address)) + return NULL; + + return fillTMEDescriptorImpl(/*rank=*/3, global_address, dims, block_dims, + element_size, desc_address); +} + +static PyObject *fill4DTMEDescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address = 0; + uint64_t dims[4]; + uint32_t block_dims[4]; + int element_size = 0; + unsigned long long desc_address = 0; + if (!PyArg_ParseTuple(args, "KKKKKiiiiiK", &global_address, &dims[3], + &dims[2], &dims[1], &dims[0], &block_dims[3], + &block_dims[2], &block_dims[1], &block_dims[0], + &element_size, &desc_address)) + return NULL; + + return fillTMEDescriptorImpl(/*rank=*/4, global_address, dims, block_dims, + element_size, desc_address); +} + +static PyObject *fill5DTMEDescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address = 0; + uint64_t dims[5]; + uint32_t block_dims[5]; + int element_size = 0; + unsigned long long desc_address = 0; + if (!PyArg_ParseTuple(args, "KKKKKKiiiiiiK", &global_address, &dims[4], + &dims[3], &dims[2], &dims[1], &dims[0], &block_dims[4], + &block_dims[3], &block_dims[2], &block_dims[1], + &block_dims[0], &element_size, &desc_address)) + return NULL; + + return fillTMEDescriptorImpl(/*rank=*/5, global_address, dims, block_dims, + element_size, desc_address); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided mubin into MUSA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Set printf FIFO size"}, + {"fill_1d_tma_descriptor", fill1DTMEDescriptor, METH_VARARGS, + "Fill a 1D TMA descriptor"}, + {"fill_2d_tma_descriptor", fill2DTMEDescriptor, METH_VARARGS, + "Fill a 2D TMA descriptor"}, + {"fill_3d_tma_descriptor", fill3DTMEDescriptor, METH_VARARGS, + "Fill a 3D TMA descriptor"}, + {"fill_4d_tma_descriptor", fill4DTMEDescriptor, METH_VARARGS, + "Fill a 4D TMA descriptor"}, + {"fill_5d_tma_descriptor", fill5DTMEDescriptor, METH_VARARGS, + "Fill a 5D TMA descriptor"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "musa_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_musa_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/third_party/mthreads/backend/driver.py b/third_party/mthreads/backend/driver.py new file mode 100644 index 0000000000..f5e496a0ac --- /dev/null +++ b/third_party/mthreads/backend/driver.py @@ -0,0 +1,895 @@ +import functools +import os +import subprocess +import weakref +from collections import OrderedDict +from pathlib import Path + +from triton import knobs +from triton.backends.compiler import GPUTarget +from triton.backends.driver import DriverBase +from triton.runtime.build import compile_module_from_src + +dirname = os.path.dirname(os.path.realpath(__file__)) +_TENSORDESC_CACHE_LIMIT = 1024 + + +def _split_paths(value: str): + return [p for p in value.split(":") if p] + + +@functools.lru_cache() +def _musa_home_dirs(): + candidates = [] + for key in ("MUSA_HOME", "MUSA_ROOT"): + if val := os.getenv(key): + candidates.append(val) + return candidates + + +@functools.lru_cache() +def _musa_include_dirs(): + include_dirs = [os.path.join(dirname, "include")] + if env_inc := os.getenv("TRITON_MUSA_INCLUDE_PATH"): + include_dirs.append(env_inc) + for home in _musa_home_dirs(): + include_dirs.append(os.path.join(home, "include")) + + # Validate that musa.h exists in one of the include dirs. + for inc in include_dirs: + if os.path.exists(os.path.join(inc, "musa.h")): + return include_dirs + raise RuntimeError("Cannot find musa.h. Set TRITON_MUSA_INCLUDE_PATH or MUSA_HOME/MUSA_ROOT to a valid MUSA SDK.") + + +@functools.lru_cache() +def _libmusa_dirs(): + + def has_libmusa(path: str) -> bool: + return (os.path.exists(os.path.join(path, "libmusa.so")) or os.path.exists(os.path.join(path, "libmusa.so.1"))) + + paths = [] + + if env_lib := os.getenv("TRITON_LIBMUSA_PATH") or os.getenv("TRITON_MUSA_LIB_PATH"): + if os.path.isfile(env_lib): + paths.append(os.path.dirname(env_lib)) + else: + paths.append(env_lib) + + for home in _musa_home_dirs(): + paths.append(os.path.join(home, "lib")) + paths.append(os.path.join(home, "lib64")) + + env_ld = os.getenv("LD_LIBRARY_PATH") + if env_ld: + paths.extend(_split_paths(env_ld)) + + # Try ldconfig cache + try: + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore") + locs = [line.split()[-1] for line in libs.splitlines() if "libmusa.so" in line] + paths.extend([os.path.dirname(loc) for loc in locs]) + except Exception: + pass + + # Filter to existing directories that contain libmusa. + valid = [p for p in paths if has_libmusa(p)] + if not valid: + raise RuntimeError( + "libmusa.so/libmusa.so.1 not found. Set TRITON_LIBMUSA_PATH/TRITON_MUSA_LIB_PATH or MUSA_HOME/MUSA_ROOT, " + "or update LD_LIBRARY_PATH.") + return valid + + +def _library_dirs(): + return [os.path.join(dirname, "lib"), *_libmusa_dirs()] + + +# ------------------------ +# Utils +# ------------------------ + + +class MusaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(MusaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + src = Path(os.path.join(dirname, "driver.c")).read_text() + mod = compile_module_from_src( + src=src, + name="musa_utils", + include_dirs=_musa_include_dirs(), + library_dirs=_library_dirs(), + libraries=["musa"], + ) + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + self.set_printf_fifo_size = mod.set_printf_fifo_size + for name in ("fill_1d_tma_descriptor", "fill_2d_tma_descriptor", "fill_3d_tma_descriptor", + "fill_4d_tma_descriptor", "fill_5d_tma_descriptor"): + if hasattr(mod, name): + setattr(self, name, getattr(mod, name)) + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + # Align ABI mapping with NVIDIA/AMD for host-side signatures. + if ty[0] == '*': + return "MUdeviceptr" + if ty.startswith("tensordesc<"): + return "MUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "double", + "bf16": "double", + "fp32": "double", + "f32": "double", + "fp64": "double", + "constexpr": "int64_t", + }[ty] + + +def ty_to_cpp_param(ty): + if ty[0] == '*': + return "MUdeviceptr" + if ty.startswith("tensordesc<"): + return "MUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "uint16_t", + "bf16": "uint16_t", + "fp32": "float", + "f32": "float", + "fp64": "double", + "constexpr": "int64_t", + }[ty] + + +def _parse_tensordesc_type(ty: str): + if not isinstance(ty, str) or not ty.startswith("tensordesc<") or not ty.endswith(">"): + return None + body = ty[len("tensordesc<"):-1] + dtype, sep, shape = body.partition("[") + if not sep or not shape.endswith("]"): + return None + dims = [dim.strip() for dim in shape[:-1].split(",") if dim.strip()] + if not dtype or not dims: + return None + return dtype.strip(), len(dims) + + +def _get_tensordesc_abi_expanded_args(rank: int, metadata): + full_abi_args = 1 + 2 * rank + if metadata is None: + return full_abi_args + abi_args = int(metadata.get("abi_expanded_args", full_abi_args)) + if abi_args not in (1, full_abi_args): + raise ValueError( + f"unsupported MUSA tensor descriptor ABI expansion: expected 1 or {full_abi_args}, got {abi_args}") + return abi_args + + +def _expand_tensordesc_signature(signature_types, tensordesc_meta=None): + expanded_types = [] + expanded_index = {} + tensordesc_idx = 0 + for i, ty in enumerate(signature_types): + desc_info = _parse_tensordesc_type(ty) + if desc_info is None: + expanded_index[i] = [len(expanded_types)] + expanded_types.append(ty) + continue + + dtype, rank = desc_info + desc_meta = None + if tensordesc_meta is not None and tensordesc_idx < len(tensordesc_meta): + desc_meta = tensordesc_meta[tensordesc_idx] + abi_expanded_args = _get_tensordesc_abi_expanded_args(rank, desc_meta) + mapped = [] + expanded_types.append(f"*{dtype}") + mapped.append(len(expanded_types) - 1) + if abi_expanded_args != 1: + for _ in range(rank): + expanded_types.append("i32") + mapped.append(len(expanded_types) - 1) + for _ in range(rank): + expanded_types.append("i64") + mapped.append(len(expanded_types) - 1) + expanded_index[i] = mapped + tensordesc_idx += 1 + + return expanded_types, expanded_index + + +def _normalize_arg_path(key): + if isinstance(key, int): + return (key, ) + if isinstance(key, tuple): + return key + raise TypeError(f"unsupported signature path key: {key!r}") + + +def _expand_signature_tree(signature_types, tensordesc_meta=None): + expanded_types = [] + expanded_index = {} + tensordesc_idx = 0 + + def visit(ty, path): + nonlocal tensordesc_idx + + if isinstance(ty, tuple): + mapped = [] + for child_idx, child_ty in enumerate(ty): + mapped.extend(visit(child_ty, path + (child_idx, ))) + expanded_index[path] = mapped + return mapped + + desc_info = _parse_tensordesc_type(ty) + if desc_info is None: + expanded_index[path] = [len(expanded_types)] + expanded_types.append(ty) + return expanded_index[path] + + _, rank = desc_info + desc_meta = None + if tensordesc_meta is not None and tensordesc_idx < len(tensordesc_meta): + desc_meta = tensordesc_meta[tensordesc_idx] + abi_expanded_args = _get_tensordesc_abi_expanded_args(rank, desc_meta) + mapped = [] + dtype = desc_info[0] + expanded_types.append(f"*{dtype}") + mapped.append(len(expanded_types) - 1) + if abi_expanded_args != 1: + for _ in range(rank): + expanded_types.append("i32") + mapped.append(len(expanded_types) - 1) + for _ in range(rank): + expanded_types.append("i64") + mapped.append(len(expanded_types) - 1) + expanded_index[path] = mapped + tensordesc_idx += 1 + return mapped + + for top_idx, ty in enumerate(signature_types): + visit(ty, (top_idx, )) + + return expanded_types, expanded_index + + +def _expand_tensordesc_kernel_arg(arg, rank: int, metadata): + if not (hasattr(arg, "base") and hasattr(arg, "shape") and hasattr(arg, "strides")): + raise TypeError("tensor descriptor argument must provide base/shape/strides") + shape = [int(v) for v in arg.shape] + strides = [int(v) for v in arg.strides] + if len(shape) != rank or len(strides) != rank: + raise ValueError( + f"tensor descriptor rank mismatch: expected {rank}, got shape={len(shape)} strides={len(strides)}") + + if metadata is not None and "block_size" in metadata: + block_shape = [int(v) for v in metadata["block_size"]] + else: + block_shape = [int(v) for v in getattr(arg, "block_shape", ())] + if len(block_shape) != rank: + raise ValueError(f"tensor descriptor block rank mismatch: expected {rank}, got {len(block_shape)}") + + if metadata is not None and "elem_size" in metadata: + elem_size = int(metadata["elem_size"]) + elif hasattr(arg.base, "element_size"): + elem_size = int(arg.base.element_size()) + else: + raise TypeError("cannot infer tensor descriptor element size") + + import torch + import triton + + descriptor = torch.empty((64, ), dtype=torch.uint8, device=arg.base.device) + fill_name = f"fill_{rank}d_tma_descriptor" + fill_fn = getattr(triton.runtime.driver.active.utils, fill_name, None) + if fill_fn is None: + raise RuntimeError(f"musa driver utils missing {fill_name}") + + if rank > 5: + raise RuntimeError(f"MUSA tensor descriptor rank {rank} is unsupported in launcher") + fill_fn(arg.base.data_ptr(), *shape, *block_shape, elem_size, descriptor.data_ptr()) + + if hasattr(torch, "musa"): + torch.musa.synchronize() + + abi_expanded_args = _get_tensordesc_abi_expanded_args(rank, metadata) + if abi_expanded_args == 1: + return [descriptor], descriptor + return [descriptor, *shape, *strides], descriptor + + +def _make_tensordesc_cache_key(arg, rank: int, metadata): + base = getattr(arg, "base", None) + if base is None or not hasattr(base, "data_ptr"): + return None + + device = getattr(base, "device", None) + device_type = getattr(device, "type", None) + device_index = getattr(device, "index", None) + + try: + shape = tuple(int(v) for v in arg.shape) + strides = tuple(int(v) for v in arg.strides) + except Exception: + return None + + if metadata is not None and "block_size" in metadata: + block_shape = tuple(int(v) for v in metadata["block_size"]) + else: + try: + block_shape = tuple(int(v) for v in getattr(arg, "block_shape", ())) + except Exception: + return None + + if metadata is not None and "elem_size" in metadata: + elem_size = int(metadata["elem_size"]) + elif hasattr(base, "element_size"): + elem_size = int(base.element_size()) + else: + return None + abi_expanded_args = _get_tensordesc_abi_expanded_args(rank, metadata) + + return ( + int(base.data_ptr()), + device_type, + device_index, + shape, + strides, + block_shape, + elem_size, + int(rank), + abi_expanded_args, + ) + + +def make_launcher(constants, signature, ids, warp_size): + params = [i for i, ty in signature.items() if ty != "constexpr" and i not in constants] + arg_decls = ', '.join(f"{ty_to_cpp_param(signature[i])} arg{i}" for i in params) + + def _parse_type(ty): + if ty[0] == '*': + return "PyObject*" + if ty == "constexpr": + # 3.5 runtime forwards constexpr Python objects in launch args. + # They are compile-time only and should not be interpreted as C scalars. + return "PyObject*" + if ty in ("fp16", "bf16", "fp32", "f32", "fp64"): + return "double" + return ty_to_cpp_param(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "L", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_parse_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + packed_decls = [] + packed_inits = [] + launch_args = [] + for i in params: + ty = signature[i] + if ty[0] == "*": + launch_args.append(f"ptr_info{i}.dev_ptr") + continue + if ty == "fp16": + packed_decls.append(f" uint16_t arg{i};") + packed_inits.append(f" arg{i} = pack_fp16(_arg{i});") + launch_args.append(f"arg{i}") + elif ty == "bf16": + packed_decls.append(f" uint16_t arg{i};") + packed_inits.append(f" arg{i} = pack_bf16(_arg{i});") + launch_args.append(f"arg{i}") + elif ty in ("fp32", "f32"): + packed_decls.append(f" float arg{i} = (float)_arg{i};") + launch_args.append(f"arg{i}") + else: + launch_args.append(f"_arg{i}") + + packed_decls_src = "\n".join(packed_decls) + packed_inits_src = "\n".join(packed_inits) + + src = f""" +#include \"musa.h\" +#include +#include +#include +#include + +static inline uint16_t pack_fp16(double val) {{ + uint16_t result; +#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && \\ + !defined(PYPY_VERSION) + _PyFloat_Pack2(val, (unsigned char *)&result, 1); +#else + PyFloat_Pack2(val, (char *)&result, 1); +#endif + return result; +}} + +static inline uint16_t pack_bf16(double val) {{ + float f32 = (float)val; + uint32_t u32 = *(uint32_t *)&f32; + return (uint16_t)(u32 >> 16); +}} + +static inline void gpuAssert(MUresult code, const char *file, int line) +{{ + if (code != MUSA_SUCCESS) + {{ + const char* prefix = \"Triton Error [MUSA]: \"; + const char* str; + muGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define MUSA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +typedef MUresult (*muLaunchKernelEx_t)(const MUlaunchConfig *config, MUfunction f, void **kernelParams, void **extra); + +static muLaunchKernelEx_t getLaunchKernelExHandle() {{ + void* handle = dlopen(\"libmusa.so\", RTLD_LAZY); + if (!handle) {{ + handle = dlopen(\"libmusa.so.1\", RTLD_LAZY); + }} + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, \"Failed to open libmusa.so or libmusa.so.1\"); + return NULL; + }} + dlerror(); + muLaunchKernelEx_t muLaunchKernelExHandle = (muLaunchKernelEx_t)dlsym(handle, \"muLaunchKernelEx\"); + const char *dlsym_error = dlerror(); + if (dlsym_error) {{ + PyErr_SetString(PyExc_RuntimeError, \"Failed to retrieve muLaunchKernelEx from libmusa\"); + return NULL; + }} + return muLaunchKernelExHandle; +}} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int shared_memory, MUstream stream, MUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + MUdeviceptr global_scratch_ptr = 0; + MUdeviceptr profile_scratch_ptr = 0; + void *params[] = {{ {', '.join([*(f"&arg{i}" for i in params), "&global_scratch_ptr", "&profile_scratch_ptr"]) } }}; + if (gridX*gridY*gridZ > 0) {{ + if (num_ctas == 1) {{ + MUSA_CHECK(muLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else {{ + MUlaunchAttribute launchAttr[2]; + launchAttr[0].id = MU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = num_ctas; + launchAttr[0].value.clusterDim.y = 1; + launchAttr[0].value.clusterDim.z = 1; + launchAttr[1].id = MU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = MU_CLUSTER_SCHEDULING_POLICY_SPREAD; + MUlaunchConfig config; + config.gridDimX = gridX * num_ctas; + config.gridDimY = gridY; + config.gridDimZ = gridZ; + config.blockDimX = {warp_size} * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 2; + static muLaunchKernelEx_t muLaunchKernelExHandle = NULL; + if (muLaunchKernelExHandle == NULL) {{ + muLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + MUSA_CHECK(muLaunchKernelExHandle(&config, function, params, 0)); + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + MUdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, \"data_ptr\"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, \"data_ptr method of Pointer object must return 64-bit int\"); + ptr_info.valid = false; + Py_DECREF(ret); + return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = muPointerGetAttribute(&dev_ptr, MU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == MUSA_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + \"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)\", idx); + ptr_info.valid = false; + }} else if (status != MUSA_SUCCESS) {{ + MUSA_CHECK((MUresult)status); // Catch any other musa API errors + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, \"Pointer argument must be either uint64 or have data_ptr method\"); + ptr_info.valid = false; + return ptr_info; +}} + +static void ensureMusaContext() {{ + MUcontext pctx; + MUSA_CHECK(muCtxGetCurrent(&pctx)); + if (!pctx) {{ + // Ensure device context. + MUdevice device; + MUSA_CHECK(muDeviceGet(&device, 0)); + MUSA_CHECK(muDevicePrimaryCtxRetain(&pctx, device)); + MUSA_CHECK(muCtxSetCurrent(pctx)); + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + // ensure musa context is valid before calling any MUSA APIs, e.g. before getPointer + // calls muPointerGetAttributes + ensureMusaContext(); + + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_parse_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + int num_warps, num_ctas, shared_memory; + if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{ + PyErr_SetString(PyExc_TypeError, \"kernel_metadata must be a tuple\"); + return NULL; + }} + + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue(\"(O)\", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + {'; '.join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" for i in params if signature[i][0] == "*"])}; +{packed_decls_src} +{packed_inits_src} + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, shared_memory, (MUstream)_stream, (MUfunction)_function{', ' + ', '.join(launch_args) if len(launch_args) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue(\"(O)\", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + "__triton_launcher", + NULL, + -1, + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class MusaLauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + + def cst_key(i): + if isinstance(i, str): + return src.fn.arg_names.index(i) + if isinstance(i, tuple) and len(i) == 1: + return i[0] + return i + + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + + ordered_sig_keys = sorted(signature.keys()) + self._signature_types = [signature[key] for key in ordered_sig_keys] + self._has_structured_args = any(isinstance(ty, tuple) for ty in self._signature_types) + self._has_tensordesc = any( + _parse_tensordesc_type(ty) is not None for ty in self._walk_signature_types(self._signature_types)) + self._needs_runtime_expansion = self._has_structured_args or self._has_tensordesc + self._tensordesc_meta = getattr(metadata, "tensordesc_meta", None) + self._tensordesc_keepalive = [] + self._tensordesc_object_cache = OrderedDict() + self._tensordesc_cache = OrderedDict() + + expanded_signature_types, expanded_index = _expand_signature_tree(self._signature_types, self._tensordesc_meta) + expanded_signature = {idx: ty for idx, ty in enumerate(expanded_signature_types)} + expanded_constants = {} + for key, value in constants.items(): + path = _normalize_arg_path(key) + if path not in expanded_index: + continue + for expanded_pos in expanded_index[path]: + expanded_constants[expanded_pos] = value + + expanded_ids = {"ids_of_const_exprs": tuple()} + if ids["ids_of_const_exprs"]: + expanded_constexpr_ids = [] + for key in ids["ids_of_const_exprs"]: + path = _normalize_arg_path(key) + expanded_constexpr_ids.extend(expanded_index.get(path, ())) + expanded_ids = {"ids_of_const_exprs": tuple(expanded_constexpr_ids)} + + target = getattr(metadata, "target", None) + target_warp_size = getattr(target, "warp_size", None) + warp_size = int(target_warp_size) if target_warp_size else 32 + src = make_launcher(expanded_constants, expanded_signature, expanded_ids, warp_size) + mod = compile_module_from_src( + src=src, + name="__triton_launcher", + include_dirs=_musa_include_dirs(), + library_dirs=_library_dirs(), + libraries=["musa"], + ) + self.launch = mod.launch + + @staticmethod + def _walk_signature_types(signature_types): + for ty in signature_types: + if isinstance(ty, tuple): + yield from MusaLauncher._walk_signature_types(ty) + else: + yield ty + + def _expand_tensordesc_arg(self, arg, ty, tensordesc_idx): + _, rank = _parse_tensordesc_type(ty) + desc_meta = None + if self._tensordesc_meta is not None and tensordesc_idx < len(self._tensordesc_meta): + desc_meta = self._tensordesc_meta[tensordesc_idx] + + cached = None + cache_key = _make_tensordesc_cache_key(arg, rank, desc_meta) + object_cache_key = None + object_ref = None + try: + object_ref = weakref.ref(arg) + object_cache_key = (id(arg), cache_key) + except TypeError: + object_ref = None + + if object_cache_key is not None: + cached = self._tensordesc_object_cache.get(object_cache_key) + if cached is not None: + cached_ref, cached_base_ptr, expanded_arg_values, keepalive = cached + current_base = getattr(getattr(arg, "base", None), "data_ptr", None) + current_base_ptr = int(current_base()) if current_base is not None else None + if cached_ref() is arg and current_base_ptr == cached_base_ptr: + self._tensordesc_object_cache.move_to_end(object_cache_key) + cached = (expanded_arg_values, keepalive) + else: + self._tensordesc_object_cache.pop(object_cache_key, None) + cached = None + + if cached is None: + cached = self._tensordesc_cache.get(cache_key) if cache_key is not None else None + if cached is None: + expanded_arg_values, keepalive = _expand_tensordesc_kernel_arg(arg, rank, desc_meta) + expanded_arg_values = tuple(expanded_arg_values) + if object_cache_key is not None: + self._tensordesc_object_cache[object_cache_key] = ( + object_ref, + int(arg.base.data_ptr()), + expanded_arg_values, + keepalive, + ) + self._tensordesc_object_cache.move_to_end(object_cache_key) + if len(self._tensordesc_object_cache) > _TENSORDESC_CACHE_LIMIT: + self._tensordesc_object_cache.popitem(last=False) + if cache_key is not None: + # Reuse encoded descriptors across launches of the same tensor/view + # so repeated TME kernels do not re-encode and synchronize every + # descriptor argument on the host path. + cached = (expanded_arg_values, keepalive) + self._tensordesc_cache[cache_key] = cached + self._tensordesc_cache.move_to_end(cache_key) + if len(self._tensordesc_cache) > _TENSORDESC_CACHE_LIMIT: + self._tensordesc_cache.popitem(last=False) + else: + expanded_arg_values, keepalive = cached + if cache_key is not None: + self._tensordesc_cache.move_to_end(cache_key) + return expanded_arg_values, keepalive + + def _expand_runtime_arg(self, arg, ty, expanded_kernel_args, launch_keepalive, tensordesc_state): + if isinstance(ty, tuple): + if not isinstance(arg, tuple): + raise RuntimeError("launcher tuple argument does not match structured signature") + if len(arg) != len(ty): + raise RuntimeError("launcher tuple argument arity mismatch") + for child_arg, child_ty in zip(arg, ty): + self._expand_runtime_arg(child_arg, child_ty, expanded_kernel_args, launch_keepalive, tensordesc_state) + return + + desc_info = _parse_tensordesc_type(ty) + if desc_info is None: + expanded_kernel_args.append(arg) + return + + expanded_arg_values, keepalive = self._expand_tensordesc_arg(arg, ty, tensordesc_state[0]) + tensordesc_state[0] += 1 + expanded_kernel_args.extend(expanded_arg_values) + launch_keepalive.append(keepalive) + + def __call__(self, *args, **kwargs): + if not self._needs_runtime_expansion: + self.launch(*args, **kwargs) + return + + # launch(gridX, gridY, gridZ, stream, function, kernel_metadata, + # launch_metadata, launch_enter_hook, launch_exit_hook, *kernel_args) + launch_prefix = args[:9] + kernel_args = args[9:] + if len(kernel_args) != len(self._signature_types): + raise RuntimeError("launcher argument count mismatch while expanding tensor descriptors") + + expanded_kernel_args = [] + launch_keepalive = [] + tensordesc_state = [0] + for arg, ty in zip(kernel_args, self._signature_types): + self._expand_runtime_arg(arg, ty, expanded_kernel_args, launch_keepalive, tensordesc_state) + + self._tensordesc_keepalive.extend(launch_keepalive) + if len(self._tensordesc_keepalive) > 4096: + self._tensordesc_keepalive = self._tensordesc_keepalive[-4096:] + self.launch(*launch_prefix, *expanded_kernel_args, **kwargs) + + +class MusaDriver(DriverBase): + + def __init__(self): + self.utils = MusaUtils() + self.launcher_cls = MusaLauncher + import torch + if not hasattr(torch, "musa"): + raise RuntimeError("torch.musa is not available") + self._torch = torch + + @staticmethod + def is_active(): + try: + import torch + return hasattr(torch, "musa") and torch.musa.is_available() + except Exception: + return False + + def map_python_to_cpp_type(self, ty: str) -> str: + return ty_to_cpp(ty) + + def get_current_target(self): + arch = knobs.runtime.override_arch or os.getenv("TRITON_MUSA_ARCH") or "ph1" + warp_size = 32 + return GPUTarget("musa", arch, warp_size) + + def get_active_torch_device(self): + return self._torch.device("musa", self.get_current_device()) + + def get_current_device(self): + return self._torch.musa.current_device() + + def set_current_device(self, device): + self._torch.musa.set_device(device) + + def get_current_stream(self, device): + stream = self._torch.musa.current_stream(device) + return getattr(stream, "musa_stream", getattr(stream, "cuda_stream", stream)) + + def get_device_interface(self): + return self._torch.musa + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_empty_cache_for_benchmark(self): + cache_size = 256 * 1024 * 1024 + return self._torch.empty(int(cache_size // 4), dtype=self._torch.int, device="musa") + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/mthreads/backend/lib/libdevice.31.bc b/third_party/mthreads/backend/lib/libdevice.31.bc new file mode 100644 index 0000000000..6833ca5d01 Binary files /dev/null and b/third_party/mthreads/backend/lib/libdevice.31.bc differ diff --git a/third_party/mthreads/bin/CMakeLists.txt b/third_party/mthreads/bin/CMakeLists.txt new file mode 100644 index 0000000000..d29fec6ecb --- /dev/null +++ b/third_party/mthreads/bin/CMakeLists.txt @@ -0,0 +1,57 @@ +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_executable(triton-opt triton-opt.cpp) + +target_compile_options(triton-opt PRIVATE -fno-rtti -fno-exceptions) +target_link_libraries(triton-opt PRIVATE + ${triton_libs} + # MLIR core + MLIROptLib + MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-opt) + +add_executable(triton-reduce triton-reduce.cpp) +mlir_check_all_link_libraries(triton-reduce) +target_compile_options(triton-reduce PRIVATE -fno-rtti -fno-exceptions) + +target_link_libraries(triton-reduce PRIVATE + ${triton_libs} + # MLIR core + MLIRReduceLib + MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-reduce) + + +add_executable(triton-llvm-opt triton-llvm-opt.cpp) +add_dependencies(triton-llvm-opt intrinsics_gen) +target_compile_options(triton-llvm-opt PRIVATE -fno-rtti -fno-exceptions) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) + + +add_executable(triton-tensor-layout triton-tensor-layout.cpp) +target_compile_options(triton-tensor-layout PRIVATE -fno-rtti -fno-exceptions) +target_link_libraries(triton-tensor-layout PRIVATE + ${triton_libs} + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRTransforms + ) diff --git a/third_party/mthreads/bin/RegisterTritonDialects.h b/third_party/mthreads/bin/RegisterTritonDialects.h new file mode 100644 index 0000000000..8bd8ac4da4 --- /dev/null +++ b/third_party/mthreads/bin/RegisterTritonDialects.h @@ -0,0 +1,129 @@ +#pragma once +#ifndef TRITON_ENABLE_NVIDIA +#define TRITON_ENABLE_NVIDIA 1 +#endif +#ifndef TRITON_ENABLE_MUSA +#define TRITON_ENABLE_MUSA 1 +#endif + +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.h" +#include "MTGPUToLLVM/Passes.h" +#include "TritonMUSAGPUToLLVM/Passes.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h" +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h" +#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonInstrument/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" + +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/InitAllPasses.h" + +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" + +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace mlir { +namespace test { +void registerTestAliasPass(); +void registerTestAlignmentPass(); +void registerTestAllocationPass(); +void registerTestBufferRegionPass(); +void registerTestMembarPass(); +void registerTestLoopPeelingPass(); +namespace proton { +void registerTestScopeIdAllocationPass(); +} // namespace proton +} // namespace test +} // namespace mlir + +inline void registerTritonDialects(mlir::DialectRegistry ®istry) { + mlir::registerAllPasses(); + mlir::triton::registerTritonPasses(); + mlir::triton::gpu::registerTritonGPUPasses(); + mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses(); + mlir::triton::instrument::registerTritonInstrumentPasses(); + mlir::triton::gluon::registerGluonPasses(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::registerRelayoutTritonGPUPass(); + mlir::triton::gpu::registerAllocateSharedMemoryPass(); + mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); + mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass(); + mlir::registerLLVMDIScope(); + mlir::LLVM::registerInlinerInterface(registry); + mlir::NVVM::registerInlinerInterface(registry); + mlir::registerLLVMDILocalVariable(); + mlir::ub::registerConvertUBToLLVMInterface(registry); + mlir::registerConvertNVVMToLLVMInterface(registry); + mlir::registerConvertMathToLLVMInterface(registry); + mlir::cf::registerConvertControlFlowToLLVMInterface(registry); + mlir::arith::registerConvertArithToLLVMInterface(registry); + + mlir::triton::registerTritonMUSAGPUToLLVMPasses(); + mlir::triton::registerMTGPUToLLVMPasses(); + mlir::registerTritonMUSAGPUPasses(); + + // Plugin passes + if (std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + !filename.empty()) { + + TritonPlugin TP(filename); + std::vector passNames; + if (auto result = TP.getPassHandles(passNames); !result) + llvm::report_fatal_error(result.takeError()); + + for (const char *passName : passNames) + if (auto result = TP.registerPass(passName); !result) + llvm::report_fatal_error(result.takeError()); + + std::vector dialectNames; + if (auto result = TP.getDialectHandles(dialectNames); !result) + llvm::report_fatal_error(result.takeError()); + + for (unsigned i = 0; i < dialectNames.size(); ++i) { + const char *dialectName = dialectNames.data()[i]; + auto result = TP.getDialectPluginInfo(dialectName); + if (!result) + llvm::report_fatal_error(result.takeError()); + ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; + dialectPluginInfo.registerDialectRegistryCallbacks(®istry); + } + } + + registry.insert< + mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, + mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, + mlir::triton::gpu::TritonGPUDialect, + mlir::triton::instrument::TritonInstrumentDialect, + mlir::triton::musa::MUSADialect, mlir::triton::mtgpu::MTGPUDialect, + mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, + mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect, + mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect, + mlir::triton::gluon::GluonDialect>(); +} diff --git a/third_party/mthreads/bin/triton-llvm-opt.cpp b/third_party/mthreads/bin/triton-llvm-opt.cpp new file mode 100644 index 0000000000..3beeeabdc1 --- /dev/null +++ b/third_party/mthreads/bin/triton-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. +#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple(Triple::normalize(TargetTriple))); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + // Default to standard output. + if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/mthreads/bin/triton-opt.cpp b/third_party/mthreads/bin/triton-opt.cpp new file mode 100644 index 0000000000..2d2570771a --- /dev/null +++ b/third_party/mthreads/bin/triton-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/mthreads/bin/triton-reduce.cpp b/third_party/mthreads/bin/triton-reduce.cpp new file mode 100644 index 0000000000..8235f8fc8c --- /dev/null +++ b/third_party/mthreads/bin/triton-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/mthreads/bin/triton-tensor-layout.cpp b/third_party/mthreads/bin/triton-tensor-layout.cpp new file mode 100644 index 0000000000..6a73e7a8ad --- /dev/null +++ b/third_party/mthreads/bin/triton-tensor-layout.cpp @@ -0,0 +1,237 @@ +#include "RegisterTritonDialects.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/IR/MLIRContext.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace mlir; + +// A CLI tool to print the layout of a tensor. +// +// clang-format off +// Example usage: +// +// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view +// +// An input file usually looks like: +// ''' +// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +// ''' +// clang-format on + +//===--------------------------------------------------------------------===// +// CLI options +//===--------------------------------------------------------------------===// + +static cl::OptionCategory &getPrinterCategory() { + static cl::OptionCategory PrinterCategory( + "Available Print Options", "Options for the tensor layout printing."); + return PrinterCategory; +} + +static cl::opt InputFile( + "i", cl::desc("File that contains the tensor data layout attributes"), + cl::init(""), cl::value_desc("filename"), cl::cat(getPrinterCategory())); + +static cl::opt + OutputFile("o", cl::desc("Output file to write the layout into"), + cl::init(""), cl::value_desc("filename"), + cl::cat(getPrinterCategory())); + +static cl::opt + DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"), + cl::value_desc("layout-string"), cl::init(""), + cl::cat(getPrinterCategory())); + +static cl::list + AliasName("alias-names", + cl::desc("A list of alias names (separated by comma) of the " + "layout attributes in the input file"), + cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated, + cl::ZeroOrMore, cl::cat(getPrinterCategory())); + +static cl::opt UseHWPointOfView( + "use-hw-view", + llvm::cl::desc( + "Print the layout in hardware point of view. This means the output is " + "from the warp's perspective. Otherwise, the output is from the " + "tensor's perspective (e.g., each element maps to xxx thread)."), + cl::init(false), cl::cat(getPrinterCategory())); + +static cl::opt TensorStr( + "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"), + cl::init(""), cl::value_desc("tensor-type"), cl::cat(getPrinterCategory())); + +//===--------------------------------------------------------------------===// +// Helper functions +//===--------------------------------------------------------------------===// + +static LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { + // DistributedEncodingTrait and SharedEncodingTrait implements the + // toLinearLayout interface. + mlir::Attribute layout = tensorType.getEncoding(); + if (isa(layout)) { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); +} + +static LogicalResult printLayoutFromFile(MLIRContext *context, + StringRef filename, + ArrayRef names, + TensorType tensorTy, + raw_string_ostream &ss) { + if (filename.empty()) + return success(); + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + ParserConfig config(context); + auto asmState = AsmParserState(); + + Block parsedIR; + if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { + llvm::errs() << "Fail to parse the input file: " << filename << "\n"; + return failure(); + } + + auto printLambda = [&](StringRef name, mlir::Attribute attr) { + ss << "Print layout attribute: #" << name << " = " << attr << "\n"; + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), attr); + + return layoutPrint(rankedTensorTy, ss); + }; + + if (names.empty()) + // If no alias name is given, we print all layout attributes in the file. + for (const auto &def : asmState.getAttributeAliasDefs()) { + if (failed(printLambda(def.name, def.value))) + return failure(); + } + else { + // Print the layout attributes with the given alias names. + for (const auto &alias : names) { + auto def = asmState.getAttributeAliasDef(alias); + if (!def) { + llvm::errs() << "Can't find the layout attribute: " << alias << "\n"; + return failure(); + } + + if (failed(printLambda(alias, def->value))) + return failure(); + + ss << "\n"; + } + } + + return success(); +} + +static LogicalResult printLayoutFromString(MLIRContext *context, + StringRef layoutAttrStr, + TensorType tensorTy, + raw_string_ostream &ss) { + if (layoutAttrStr.empty()) + return success(); + + mlir::Attribute layout = parseAttribute(layoutAttrStr, context); + if (!layout) { + llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n"; + return failure(); + } + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + + ss << "Print layout attribute: " << layout << "\n"; + + return layoutPrint(rankedTensorTy, ss); +} + +//===--------------------------------------------------------------------===// +// Main entry point +//===--------------------------------------------------------------------===// + +int main(int argc, char **argv) { + cl::HideUnrelatedOptions(getPrinterCategory()); + cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); + + DialectRegistry registry; + registerTritonDialects(registry); + + MLIRContext ctx(registry); + ctx.loadAllAvailableDialects(); + + if (TensorStr.empty()) { + llvm::errs() << "Must specify the tensor type argument\n"; + return 1; + } + + mlir::Type parsedTy = parseType(TensorStr, &ctx); + if (!parsedTy) { + llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr + << "\n"; + return 1; + } + + TensorType tensorType = dyn_cast(parsedTy); + if (!tensorType) { + llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n"; + return 1; + } + + std::string storage; + raw_string_ostream ss(storage); + + if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss))) + return 1; + + if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss))) + return 1; + + if (OutputFile.empty()) { + llvm::outs() << ss.str(); + } else { + std::error_code ec; + llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text); + if (ec) { + llvm::errs() << "Error: " << ec.message() << " : unable to open " + << OutputFile << " for output\n"; + return 1; + } + outFs << ss.str(); + outFs.close(); + } + + return 0; +} diff --git a/third_party/mthreads/include/CMakeLists.txt b/third_party/mthreads/include/CMakeLists.txt new file mode 100644 index 0000000000..109c292fea --- /dev/null +++ b/third_party/mthreads/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) diff --git a/third_party/mthreads/include/triton/Analysis/Alias.h b/third_party/mthreads/include/triton/Analysis/Alias.h new file mode 100644 index 0000000000..199238bea7 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Alias.h @@ -0,0 +1,96 @@ +#ifndef TRITON_ANALYSIS_ALIAS_H +#define TRITON_ANALYSIS_ALIAS_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +class AliasInfo { +public: + AliasInfo() = default; + AliasInfo(Value value) { insert(value); } + + void insert(Value value) { allocs.insert(value); } + + const DenseSet &getAllocs() const { return allocs; } + + bool operator==(const AliasInfo &other) const { + return allocs == other.allocs; + } + + /// The pessimistic value state of a value without alias + static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } + + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + + void print(raw_ostream &os) const { + llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); + } + +private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased value produced by the following + /// situations: + /// 1. values returned by scf.yield + /// 2. block arguments in scf.for + /// Example: + /// alloc v1 alloc v2 + /// | | + /// |--------------| |------------| + /// scf.for v3 scf.for v4 scf.for v5 + /// | + /// scf.yield v6 + /// + /// v1's alloc [v1] + /// v2's alloc [v2] + /// v3's alloc [v1] + /// v4's alloc [v1, v2] + /// v5's alloc [v2] + /// v6's alloc [v1] + /// + /// Therefore, v1's liveness range is the union of v3, v4, and v6 + /// v2's liveness range is the union of v4 and v5. + DenseSet allocs; +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Alias Analysis +//===----------------------------------------------------------------------===// +class SharedMemoryAliasAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +public: + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. + AliasResult alias(Value lhs, Value rhs); + + /// Returns the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged(lattice, + lattice->join(AliasInfo::getPessimisticValueState( + lattice->getAnchor()))); + } + + /// Computes if the alloc set of the results are changed. + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALIAS_H diff --git a/third_party/mthreads/include/triton/Analysis/Allocation.h b/third_party/mthreads/include/triton/Analysis/Allocation.h new file mode 100644 index 0000000000..bd4568a16f --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Allocation.h @@ -0,0 +1,265 @@ +#ifndef TRITON_ANALYSIS_ALLOCATION_H +#define TRITON_ANALYSIS_ALLOCATION_H + +#include "triton/Analysis/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +#include + +namespace mlir { + +namespace triton { +class AllocationAnalysis; + +/// Callback to allow backends to specify target-specific scratch sizes for +/// some operations. +using AllocationAnalysisScratchSizeFn = std::function; + +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op); + +unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy, + RankedTensorType dstTy); + +} // namespace triton + +/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h +/// A class that represents an interval, specified using a start and an end +/// values: [Start, End). +template class Interval { +public: + Interval() {} + Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); } + T start() const { return Start; } + T end() const { return End; } + T size() const { return End - Start; } + bool contains(T Addr) const { return Start <= Addr && Addr < End; } + bool intersects(const Interval &R) const { + return Start < R.End && R.Start < End; + } + bool operator==(const Interval &R) const { + return Start == R.Start && End == R.End; + } + bool operator!=(const Interval &R) const { return !(*this == R); } + bool operator<(const Interval &R) const { + return std::make_pair(Start, End) < std::make_pair(R.Start, R.End); + } + +private: + T Start = std::numeric_limits::min(); + T End = std::numeric_limits::max(); +}; + +template Interval(T, T) -> Interval; + +class Allocation { +public: + /// A unique identifier for shared memory buffers + using BufferId = size_t; + using BufferIdSetT = DenseSet; + using FuncAllocMapT = triton::CallGraph::FuncDataMapT; + + static constexpr BufferId InvalidBufferId = + std::numeric_limits::max(); + + Allocation() = default; + /// Creates a new Allocation analysis that computes the shared memory + /// information for all associated shared memory values. + explicit Allocation(Operation *operation) : operation(operation) {} + + /// Runs allocation analysis on the given top-level operation. + void run(FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter); + + /// Returns the operation this analysis was constructed from. + Operation *getOperation() const { return operation; } + + /// Returns the offset of the given buffer in the shared memory. + size_t getOffset(BufferId bufferId) const { + return bufferSet.at(bufferId).offset; + } + + /// Returns the size of the given buffer in the shared memory. + size_t getAllocatedSize(BufferId bufferId) const { + return bufferSet.at(bufferId).size; + } + + /// Returns the allocated interval of the given buffer. + Interval getAllocatedInterval(BufferId bufferId) const { + auto &buffer = bufferSet.at(bufferId); + return Interval(buffer.offset, buffer.offset + buffer.size); + } + + /// Returns the buffer id of the given value. + /// This interface only returns the allocated buffer id. + /// If you want to get all the buffer ids that are associated with the given + /// value, including alias buffers, use getBufferIds. + BufferId getBufferId(Value value) const { + if (valueBuffer.count(value)) { + return valueBuffer.lookup(value)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns all the buffer ids of the given value, including alias buffers. + BufferIdSetT getBufferIds(Value value) const { + BufferIdSetT bufferIds; + auto allocBufferId = getBufferId(value); + if (allocBufferId != InvalidBufferId) + bufferIds.insert(allocBufferId); + for (auto *buffer : aliasBuffer.lookup(value)) { + if (buffer->id != InvalidBufferId) + bufferIds.insert(buffer->id); + } + return bufferIds; + } + + /// Returns the scratch buffer id of the given value. + BufferId getBufferId(Operation *operation) const { + if (opScratch.count(operation)) { + return opScratch.lookup(operation)->id; + } else if (opVirtual.count(operation)) { + return opVirtual.lookup(operation)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns if the given buffer is a virtual buffer. + bool isVirtualBuffer(BufferId bufferId) const { + return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual; + } + + /// Returns the size of total shared memory allocated + size_t getSharedMemorySize() const { return sharedMemorySize; } + + /// Returns mapping from operation to list of live LDS buffers + std::map> getLiveBuffers(); + +private: + /// A class that represents a shared memory buffer + struct BufferT { + /// Explicit: ttg.local_alloc + /// Scratch: ttg.convert_layout + /// Virtual: triton.call + enum class BufferKind { Explicit, Scratch, Virtual }; + + BufferKind kind; + BufferId id; + Operation *owner; + size_t size; + size_t alignment; + size_t offset; + + bool operator==(const BufferT &other) const { return id == other.id; } + bool operator<(const BufferT &other) const { return id < other.id; } + + BufferT(BufferKind kind, BufferId id, Operation *owner, size_t size, + size_t alignment = 4, size_t offset = 0) + : kind(kind), id(id), owner(owner), size(size), alignment(alignment), + offset(offset) {} + + size_t setOffsetAligned(size_t newOffset) { + return offset = llvm::alignTo(newOffset, alignment); + } + }; + + /// Op -> Scratch Buffer + using OpScratchMapT = llvm::MapVector; + /// Value -> Explicit Buffer + using ValueBufferMapT = llvm::MapVector; + /// Value -> Alias Buffer + using AliasBufferMapT = llvm::MapVector>; + /// BufferId -> Buffer + using BufferSetT = std::map; + +private: + template + void addBuffer(KeyType &key, Args &&...args) { + BufferId nextId = bufferIdCounter++; + auto [it, inserted] = bufferSet.insert_or_assign( + nextId, BufferT(Kind, nextId, key, std::forward(args)...)); + BufferT *buffer = &it->second; + if constexpr (Kind == BufferT::BufferKind::Explicit) { + valueBuffer[key] = buffer; + } else if constexpr (Kind == BufferT::BufferKind::Virtual) { + opVirtual[key] = buffer; + } else { + opScratch[key] = buffer; + } + } + + void addAlias(Value value, Value alloc) { + aliasBuffer[value].insert(valueBuffer[alloc]); + } + +private: + Operation *operation = nullptr; + OpScratchMapT opScratch; + OpScratchMapT opVirtual; + ValueBufferMapT valueBuffer; + AliasBufferMapT aliasBuffer; + BufferSetT bufferSet; + size_t sharedMemorySize = 0; + + size_t bufferIdCounter = 0; + + friend class triton::AllocationAnalysis; +}; + +/// Static analysis that computes the allocation of shared memory buffers +/// of the entire call graph. +/// The allocation is performed in a post-order walk of the call graph. +/// Each call op is treated like convert_layout that allocates a scratch buffer. +/// At each call, we compute the start offset of the scratch buffer and pass it +/// as an argument to the callee. +class ModuleAllocation : public triton::CallGraph { +public: + using FuncOffsetMapT = DenseMap; + + ModuleAllocation(ModuleOp moduleOp, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter = + triton::defaultAllocationAnalysisScratchSizeFn) + : triton::CallGraph(moduleOp) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); + if (inserted) + iter->second.run(funcMap, scratchSizeGetter); + }); + } + + size_t getSharedMemorySize() { + size_t size = 0; + for (auto funcOp : getRoots()) { + auto *alloc = getFuncData(funcOp); + size = std::max(size, alloc->getSharedMemorySize()); + } + return size; + } + + size_t getSharedMemorySize(FunctionOpInterface funcOp) { + return getFuncData(funcOp)->getSharedMemorySize(); + } + + void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) { + sharedMemoryValue[funcOp] = value; + } + + Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) { + return sharedMemoryValue[funcOp]; + } + +private: + FuncOffsetMapT sharedMemoryValue; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALLOCATION_H diff --git a/third_party/mthreads/include/triton/Analysis/AxisInfo.h b/third_party/mthreads/include/triton/Analysis/AxisInfo.h new file mode 100644 index 0000000000..f252081f43 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/AxisInfo.h @@ -0,0 +1,271 @@ +#ifndef TRITON_ANALYSIS_AXISINFO_H +#define TRITON_ANALYSIS_AXISINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" + +#include + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// AxisInfo +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfo { +public: + typedef SmallVector DimVectorT; + +public: + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy, std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } + + // contiguity[d] is the length of the shortest sequence of contiguous integers + // along dimension d. + // + // If we have an array of N elements with a contiguity value C, then the array + // can be divided into a list of N/C sequences of C contiguous elements. + // Since we have N = 2^k, C must be a power of two. + // + // For example, the 2D array + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has contiguity [1, 4], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27], + // [18, 22, 26, 30], + // [19, 23, 27, 31]] + // + // has contiguity [2, 1]. + int64_t getContiguity(size_t dim) const { return contiguity[dim]; } + const DimVectorT &getContiguity() const { return contiguity; } + + // divisibility[d] is the largest power of two that divides the first element + // of all groups of length contiguity[d] along dimension d. + // + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + // + // On the other hand, + // + // [0, 1, 2, 0, 4, 5, 6, 7] + // + // has divisibility 1 because its contiguity is 1. + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + // constancy[d] is the length of the shortest sequence of repeating integers + // along dimension d. + // + // This is particularly useful to infer the contiguity of operations (e.g. + // add) involving a constant. + // + // If we have an array of N elements, with a constancy value C, then the array + // can be divided into a list of N/C sequences of C elements with the same + // value. Since we have N = 2^k, C must be a power of two. + // + // For example + // + // [[8, 8, 8, 8, 12, 12, 12, 12], + // [16, 16, 16, 16, 20, 20, 20, 20]] + // + // has constancy [1, 4]. + int64_t getConstancy(size_t dim) const { return constancy[dim]; } + const DimVectorT &getConstancy() const { return constancy; } + + int getRank() const { return contiguity.size(); } + + std::optional getConstantValue() const { return constantValue; } + + static void initPessimisticStateFromFunc(int argNumber, + FunctionOpInterface funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy); + + static void initDimVectorFromHint(Attribute attr, DimVectorT *vec); + + bool operator==(const AxisInfo &other) const { + return contiguity == other.contiguity && + divisibility == other.divisibility && constancy == other.constancy && + constantValue == other.constantValue; + } + + static AxisInfo getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("contiguity", contiguity); + print(", divisibility", divisibility); + print(", constancy", constancy); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; +}; + +class AxisInfoVisitor { +public: + AxisInfoVisitor() = default; + virtual ~AxisInfoVisitor() = default; + + bool isContiguousDim(const AxisInfo &info, ArrayRef shape, int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + bool isConstantDim(const AxisInfo &info, ArrayRef shape, int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +class AxisInfoVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfo apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); + return AxisInfo(); + } + +private: + std::vector> visitors; +}; + +namespace axisinfo { +using CallbackType = std::function; +} // namespace axisinfo + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. +using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp, + axisinfo::CallbackType callback = nullptr) + : CallGraph(moduleOp) { + SmallVector funcs; + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoMapT{}); + }); + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp, callback); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = dyn_cast( + callOp.resolveCallableInTable(&symbolTable)); + update(callOp, callee); + }); + } + } + + AxisInfo *getAxisInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + + unsigned getContiguity(Value value); + unsigned getAlignment(Value value); + + // Overloads of the above methods but have separated elementBitWidth to + // calculate the contiguity. These are useful for computing axis info when + // lowering to hardware intrinsics that require a scalar/warp-uniform base ptr + // with separate per lane offsets like AMD buffer operations. + // + // As a concrete example, instead of a single tensor<128x64x!tt.ptr> + // value, now we have two separate values: !tt.ptr for the base pointer + // and tensor<128x64xi32> for the offset. For such cases, we want to compute + // the contiguity on the offsets but use the pointee element type bit width + // instead of the offset element type bit width for alignment + unsigned getContiguity(Value offsetsValue, unsigned elementBitWidth); + unsigned getAlignment(Value offsetsValue, unsigned elementBitWidth); + + unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp, + axisinfo::CallbackType callback = nullptr); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/include/triton/Analysis/BufferRegion.h b/third_party/mthreads/include/triton/Analysis/BufferRegion.h new file mode 100644 index 0000000000..7018ea8369 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/BufferRegion.h @@ -0,0 +1,166 @@ +#ifndef TRITON_ANALYSIS_BUFFER_REGION_H +#define TRITON_ANALYSIS_BUFFER_REGION_H + +#include +#include + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Value.h" + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// BufferRegion: a single logical region derived from an alloc +//===----------------------------------------------------------------------===// +struct BufferRegion { + uint32_t baseOffset; + uint32_t length; + + bool operator==(const BufferRegion &other) const { + return baseOffset == other.baseOffset && length == other.length; + } + + bool operator<(const BufferRegion &other) const { + if (baseOffset != other.baseOffset) + return baseOffset < other.baseOffset; + return length < other.length; + } + + template void print(T &os) const { + os << "[" << baseOffset << ", " << length << "]"; + } +}; + +} // namespace mlir::triton + +namespace llvm { + +using namespace mlir::triton; + +template <> struct DenseMapInfo { + static BufferRegion getEmptyKey() { + constexpr uint32_t empty = std::numeric_limits::max(); + return BufferRegion{empty, empty}; + } + static BufferRegion getTombstoneKey() { + constexpr uint32_t tombstone = std::numeric_limits::max() - 1; + return BufferRegion{tombstone, tombstone}; + } + static unsigned getHashValue(const BufferRegion &r) { + return llvm::hash_combine(r.baseOffset, r.length); + } + static bool isEqual(const BufferRegion &a, const BufferRegion &b) { + return a == b; + } +}; + +} // namespace llvm + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// RegionInfo lattice +//===----------------------------------------------------------------------===// +// +// This wraps a set of BufferRegions and provides lattice semantics +// +struct RegionInfo { + using RegionList = llvm::DenseSet; + RegionList regions; + + RegionInfo() = default; + RegionInfo(const RegionList &r) : regions(r) {} + + // Lattice join: union of regions + static RegionInfo join(const RegionInfo &lhs, const RegionInfo &rhs) { + RegionInfo result = lhs; + for (const auto ® : rhs.regions) + if (llvm::find(result.regions, reg) == result.regions.end()) + result.regions.insert(reg); + return result; + } + + bool operator==(const RegionInfo &other) const { + if (regions.size() != other.regions.size()) + return false; + for (auto &r : regions) + if (llvm::find(other.regions, r) == other.regions.end()) + return false; + return true; + } + + template void print(T &os) const { + llvm::SmallVector sortedRegions(regions.begin(), + regions.end()); + llvm::sort(sortedRegions, [](const BufferRegion &a, const BufferRegion &b) { + return a < b; + }); + llvm::interleaveComma(sortedRegions, os, + [&](const BufferRegion &r) { r.print(os); }); + } + + static RegionInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return RegionInfo(); // means "unknown / empty" + } + static RegionInfo getPessimisticValueState(Value) { return RegionInfo(); } +}; + +//===----------------------------------------------------------------------===// +// BufferRegionAnalysis (Sparse Forward Dataflow) +//===----------------------------------------------------------------------===// +// +// Produces a RegionInfo lattice for each MemDesc/ptr-like SSA value, +// and also collects a global list of all discovered BufferRegions. +// +class BufferRegionAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { + +public: + using Base = + dataflow::SparseForwardDataFlowAnalysis>; + using Base::getLatticeElement; + using Base::SparseForwardDataFlowAnalysis; + + enum RegionType { SHARED_MEMORY, TENSOR_MEMORY, BARRIER, NUM_REGION_TYPES }; + + static bool isMemoryAccessOperation(Operation *op); + + // ------------------------------ + // Public API for ConSan + // ------------------------------ + + /// Return the list of all unique (alloc,offset,len) buffer regions + /// discovered by the analysis. + llvm::SmallVector + getAllUsedBufferRegions(RegionType type) const { + return llvm::to_vector(usedBufferRegions[type]); + } + + void calculateUsedBufferRegions(Operation *op); + + // ------------------------------ + // Required overrides + // ------------------------------ + + void setToEntryState(dataflow::Lattice *lat) override { + propagateIfChanged( + lat, lat->join(RegionInfo::getPessimisticValueState(lat->getAnchor()))); + } + + LogicalResult visitOperation( + Operation *op, + llvm::ArrayRef *> operands, + llvm::ArrayRef *> results) override; + + LogicalResult initialize(Operation *top) override; + +private: + // Global registry of all regions + std::set usedBufferRegions[NUM_REGION_TYPES]; + + static void verifyOpIsSupported(Operation *op); +}; + +} // namespace mlir::triton + +#endif // TRITON_ANALYSIS_BUFFER_REGION_H diff --git a/third_party/mthreads/include/triton/Analysis/Membar.h b/third_party/mthreads/include/triton/Analysis/Membar.h new file mode 100644 index 0000000000..7222a2577c --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Membar.h @@ -0,0 +1,263 @@ +#ifndef TRITON_ANALYSIS_MEMBAR_H +#define TRITON_ANALYSIS_MEMBAR_H + +#include "Allocation.h" + +#include "llvm/Support/raw_ostream.h" +#include +#include + +namespace mlir { + +class OpBuilder; + +/// Callback to allow backend to provide more information on whether a barrier +/// is needed between two operations. Even though two operations access the same +/// shared memory they may not require a barrier in between them. +using MembarFilterFn = + std::function; + +// Represents the access to a slice of an allocation +// It contains information both on physical memory (the interval) and a +// logical view on it (layout, subslice offsets and shape for the access) +struct AllocationSlice { +public: + // Create allocation slice from a value, collecting subslice offsets + AllocationSlice(Value value, Interval allocationInterval); + + // Builder for accesses that represent accesses to the whole + // allocation (scratch buffers, ArriveBarrierOp, ..) + AllocationSlice(Interval interval) + : allocationInterval(interval), accessTy(nullptr) {} + + bool operator<(const AllocationSlice &other) const { + return asTuple() < other.asTuple(); + } + + bool operator==(const AllocationSlice &other) const { + return asTuple() == other.asTuple(); + } + + // Check if a AllocationSlice intersects with another other. + // This happens if their subslice regions intersect in all dimensions. + // Returns true if it can't prove the AllocationSlices are disjoint. + bool intersects(const AllocationSlice &other) const; + + void print(raw_ostream &os) const; + +private: + std::tuple, const void *, llvm::ArrayRef> + asTuple() const { + return {allocationInterval, accessTy.getAsOpaquePointer(), subsliceOffsets}; + } + // Offsets from subslice. Empty when offsets are unknown + SmallVector subsliceOffsets; + // The allocated interval for this buffer + Interval allocationInterval; + // Type of the memory descriptor for this access + triton::gpu::MemDescType accessTy; +}; + +struct BlockInfo { + using SliceMapT = std::map>; + + SliceMapT syncReadSlices; + SliceMapT syncWriteSlices; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { + for (auto &slice : other.syncReadSlices) + syncReadSlices[slice.first].insert(slice.second.begin(), + slice.second.end()); + + for (auto &slice : other.syncWriteSlices) + syncWriteSlices[slice.first].insert(slice.second.begin(), + slice.second.end()); + return *this; + } + + void dump() { + auto &err = llvm::errs(); + err << "Block Interval:\n"; + err << " Read Intervals:\n"; + for (auto &[slice, ops] : syncReadSlices) { + err << " "; + slice.print(err); + err << " "; + for (auto &op : ops) + err << op->getName() << " "; + err << "\n"; + } + err << " Write Intervals:\n"; + for (auto &[slice, ops] : syncWriteSlices) { + err << " "; + slice.print(err); + err << " "; + for (auto &op : ops) + err << op->getName() << " "; + err << "\n"; + } + } + + /// Returns true if Slices in two BlockInfo objects are intersected. + bool isIntersected(const BlockInfo &other, MembarFilterFn filter, + Allocation *allocation) const { + return /*RAW*/ isIntersected(syncWriteSlices, other.syncReadSlices, filter, + allocation) || + /*WAR*/ + isIntersected(syncReadSlices, other.syncWriteSlices, filter, + allocation) || + /*WAW*/ + isIntersected(syncWriteSlices, other.syncWriteSlices, filter, + allocation); + } + + /// Clears the slices because a barrier is inserted. + void sync() { + syncReadSlices.clear(); + syncWriteSlices.clear(); + } + + /// Compares two BlockInfo objects. + bool operator==(const BlockInfo &other) const { + return syncReadSlices == other.syncReadSlices && + syncWriteSlices == other.syncWriteSlices; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + +private: + bool isIntersected(const SliceMapT &lhsSlices, const SliceMapT &rhsSlices, + MembarFilterFn filter, Allocation *allocation) const { + for (auto &lhs : lhsSlices) + for (auto &rhs : rhsSlices) + if (lhs.first.intersects(rhs.first)) + for (auto lhsOp : lhs.second) + for (auto rhsOp : rhs.second) + if (!filter || !filter(lhsOp, rhsOp, allocation)) + return true; + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Barrier Analysis +//===----------------------------------------------------------------------===// + +// Common class to analyze membar and fence placement. +class MembarOrFenceAnalysis { + using VirtualBlock = std::pair; + +public: + using FuncBlockInfoMapT = triton::CallGraph::FuncDataMapT; + /// Creates a new Membar analysis that generates the shared memory barrier + /// in the following circumstances: + /// - RAW: If a shared memory write is followed by a shared memory read, and + /// their addresses are intersected, a barrier is inserted. + /// - WAR: If a shared memory read is followed by a shared memory write, and + /// their addresses are intersected, a barrier is inserted. + /// The following circumstances do not require a barrier: + /// - WAW: not possible because overlapped memory allocation is not allowed. + /// - RAR: no write is performed. + /// Temporary storage of operations such as Reduce are considered as both + /// a shared memory read. If the temporary storage is written but not read, + /// it is considered as the problem of the operation itself but not the membar + /// analysis. + MembarOrFenceAnalysis() = default; + explicit MembarOrFenceAnalysis(Allocation *allocation, MembarFilterFn filter) + : allocation(allocation), filter(filter) {} + + virtual ~MembarOrFenceAnalysis() = default; + + /// Runs the membar analysis to the given operation, inserts a barrier if + /// necessary. + void run(FuncBlockInfoMapT &funcBlockInfoMap); + +protected: + /// Applies the barrier analysis based on the SCF dialect, in which each + /// region has a single basic block only. + /// Example: + /// region1 + /// op1 + /// op2 (scf.if) + /// region2 + /// op3 + /// op4 + /// region3 + /// op5 + /// op6 + /// op7 + /// TODO: Explain why we don't use ForwardAnalysis: + void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder); + + /// Collects the successors of the terminator + void visitTerminator(Operation *operation, + SmallVector &successors); + + /// Updates the BlockInfo operation based on the operation. + virtual void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) = 0; + + Allocation *allocation = nullptr; + MembarFilterFn filter = nullptr; +}; + +class MembarAnalysis : public MembarOrFenceAnalysis { +public: + MembarAnalysis() = default; + explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter) + : MembarOrFenceAnalysis(allocation, filter) {} + + ~MembarAnalysis() override = default; + +private: + /// Updates the BlockInfo operation based on the operation. + virtual void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) override; + + void insertBarrier(Operation *operation, OpBuilder *builder); +}; + +/// Postorder traversal on the callgraph to insert membar instructions +/// of each function. +/// Each function maintains a BlockInfo map that includes all potential buffers +/// after returning. This way users do not have to explicitly insert membars +/// before and after function calls, but might be a bit conservative. +template +class ModuleMembarOrFenceAnalysis : public triton::CallGraph { +public: + ModuleMembarOrFenceAnalysis(ModuleAllocation *moduleAllocation, + MembarFilterFn filter = nullptr) + : triton::CallGraph(moduleAllocation->getModuleOp()), + moduleAllocation(moduleAllocation), filter(filter) {} + + void run() { + walk( + // Pre-order walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order walk callback + [&](FunctionOpInterface funcOp) { + auto *allocation = moduleAllocation->getFuncData(funcOp); + auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); + if (inserted) { + AnalysisType analysis(allocation, filter); + analysis.run(funcMap); + } + }); + } + +private: + ModuleAllocation *moduleAllocation; + MembarFilterFn filter; +}; + +typedef ModuleMembarOrFenceAnalysis ModuleMembarAnalysis; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_MEMBAR_H diff --git a/third_party/mthreads/include/triton/Analysis/Utility.h b/third_party/mthreads/include/triton/Analysis/Utility.h new file mode 100644 index 0000000000..d5f5e5f936 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Utility.h @@ -0,0 +1,399 @@ +#ifndef TRITON_ANALYSIS_UTILITY_H +#define TRITON_ANALYSIS_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" + +namespace mlir { + +inline bool isZeroConst(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + +class ReduceOpHelper { +public: + explicit ReduceOpHelper(triton::ReduceOp op) + : op(op.getOperation()), axis(op.getAxis()) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcTy = firstTy; + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + + ArrayRef getSrcShape() { return srcShape; } + + Attribute getSrcLayout() { return srcEncoding; } + + triton::ReduceOp getOperation() { return op; } + + unsigned getThreadOffsetOnReductionAxis(); + + bool isWarpSynchronous(); + + unsigned getInterWarpSizeWithUniqueData(); + + unsigned getIntraWarpSizeWithUniqueData(); + + // The shape of the shared memory space needed for the reduction. + SmallVector getScratchRepShape(); + + SmallVector getOrderWithAxisAtBeginning(); + + unsigned getScratchSizeInBytes(); + + bool isReduceWithinCTA(); + + bool isAssociative(); + +private: + triton::ReduceOp op; + RankedTensorType srcTy; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; + int axis; +}; + +class ScanLoweringHelper { +public: + explicit ScanLoweringHelper(triton::ScanOp op); + // Return true if the lowering of the scan op is supported. + bool isSupported(); + // Return the number of elements per thread along axis dim. + unsigned getAxisNumElementsPerThread(); + // Return the number of elements per thread along non-axis dims. + unsigned getNonAxisNumElementsPerThread(); + // Return the number of threads per warp along non-axis dims. + unsigned getNonAxisNumThreadsPerWarp(); + // Return the flat numbers of threads computing independent scan results. + unsigned getNonAxisNumThreadsPerCTA(); + // Return the number of warps per CTA along axis dim with unique data. + unsigned getAxisNumWarpsWithUniqueData(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); + // Return the number of blocks along axis dim. + unsigned getAxisNumBlocks(); + // Return the number of blocks along non axis dim. + unsigned getNonAxisNumBlocks(); + // Return the size of the scratch space needed for scan lowering. + unsigned getScratchSizeInBytes(); + // Return the number of elements of the scratch space needed for scan + // lowering. + unsigned getScratchSizeInElems(); + + // Stride between contiguous element along axis dim. + unsigned getAxisElementStride(); + // Stride between contiguous threads along axis dim. + unsigned getAxisThreadStride(); + // Stride between contiguous blocks along axis dim. + unsigned getAxisBlockStride(); + + Location getLoc() { return scanOp.getLoc(); } + unsigned getAxis() { return scanOp.getAxis(); } + bool getReverse() { return scanOp.getReverse(); } + triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; } + llvm::ArrayRef getShape() { return srcShape; } + unsigned getNumOperands() { return scanOp.getNumOperands(); } + SmallVector getElementTypes() { return srcElementTypes; } + SmallVector getOrder() { return order; } + Region &getCombineOp(); + +private: + triton::ScanOp scanOp; + triton::gpu::LinearEncodingAttr srcEncoding; + Attribute legacyEncoding; + llvm::ArrayRef srcShape; + SmallVector srcElementTypes; + SmallVector order; +}; + +// Helper class for lowering `tt.gather` operations. This class shares lowering +// logic between shared memory allocation and LLVM codegen. +class GatherLoweringHelper { +public: + GatherLoweringHelper(triton::GatherOp gatherOp); + + // Get the shared memory scratch size required by this op. + unsigned getScratchSizeInBytes(); + // Determine if the gather can be performed completely within a warp. + bool isWarpLocal(); + +private: + triton::GatherOp gatherOp; + RankedTensorType srcTy; + RankedTensorType dstTy; +}; + +// This struct represents the factorization of a warp-local layout conversion +// into three components: a register-only permutation, a lane-only permutation, +// and a set of swaps between lane and register basis vectors. Algebraically, it +// represents the factorization P = P_mixed \circ P_lane \circ P_reg. It is used +// to aid in the implementation of the layout conversion using warp-shuffles. +// +// `pReg` and `pLane` are square layouts each with only one input and output +// dimension. `mixedTranspositions` holds pairs of integers (i, j) +// corresponding to the transposition (r_i l_j) of the i-th register basis +// vector with the j-th lane basis vector along with 16-bit selectors for byte +// permute instructions (where each of the four nybbles is in the range [0, 7]). +// `nPack` gives the number of basis vectors that can be used for register +// packing while ensuring packed elements arrive at the same destination lane. +struct DecomposedWarpConversion { + struct TranspositionInfo { + std::pair transposition; + uint16_t topPreSel = 0x3210; + uint16_t botPreSel = 0x7654; + uint16_t topPostSel = 0x3210; + uint16_t botPostSel = 0x7654; + }; + + triton::LinearLayout pReg, pLane; + SmallVector mixedTranspositions; + int nPack; +}; + +// Produces a decomposition of a permutation describing a warp-local layout +// conversion as described in `DecomposedWarpConversion` above. +// +// This function handles cases where the numbers of register and lane basis +// vectors differ between the two layouts. This is done by padding the smaller +// dimension(s) with zero vectors, ensuring that the layout conversion can be +// represented as a permutation. +DecomposedWarpConversion +getWarpLayoutConvertDecomposition(RankedTensorType srcTy, + RankedTensorType dstTy, int bitwidth); + +// Decomposes a reshape into simpler pieces. +// +// As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. +// You might explain what this does as follows. +// +// - Split the first input dimension into [2,2]. +// - Take the remaining two input dimensions, merge them into a single [16] +// dim, and then split that into [8,2]. +// +// In general, a reshape can be described a sequence of smushing one or more +// input dimensions together and then breaking them apart into one or more +// output dimensions. So we could represent the example above as follows. +// +// [ +// ([0], [0, 1]), # input dim [0] -> output dims [0, 1] +// ([1, 2], [2, 3]), # input dims [1, 2] -> output dims [2, 3] +// ] +// +// Notice that the input dims (first tuple elems) appear in sequential order if +// you read left-to-right-top-to-bottom, and so do the output dims. +// +// This function returns the above decomposition. +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, ArrayRef dstShape); + +// Returns the number of elements in the scratch space needed. +// If shape is empty, it means no shared memory is needed. +unsigned getNumScratchElements(ArrayRef shape); + +bool supportWMMA(triton::DotOp op); + +bool supportMMA(triton::DotOp op, int version); + +bool supportMMA(Value value, int version); + +// Conversion from `srcTy` to `dstTy` involving the minimum amount of data +// transfer provided that both types can be converted to LL (if it can't it'll +// return nullopt). The output will be such that layout.getInDimNames() == +// layout.getOutDimNames() and the conversion will not include kBlock (resp. +// kWarp or kLane) if it can be avoided +triton::LinearLayout minimalCvtLayout(Type srcTy, Type dstTy); + +// Conversion from `srcTy` to `dstTy` only involves reordering of registers. +// There is no need for data exchange across threads, warps, or blocks. +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy); + +// Conversion from `srcTy` to `dstTy` involves data exchange across threads +// within a warp. No data exchange across warps or blocks is needed. +bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy); + +// Conversion from `srcTy` to `dstTy` involves data exchange across threads, +// warps, and possibly blocks. +bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); + +// TODO: Move utility functions that belong to ConvertLayoutOp to class +// ConvertLayoutOpHelper in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); + +/// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +namespace triton { + +/// This class represents a call graph for a given ModuleOp and holds +/// data of type T associated with each FunctionOpInterface. +template class CallGraph { +public: + using FuncDataMapT = DenseMap; + + /// Constructor that builds the call graph for the given moduleOp. + explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); } + + /// Walks the call graph and applies the provided update functions + /// to the edges and nodes. + template + void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) { + DenseSet visited; + for (auto root : roots) { + doWalk(root, visited, updateEdgeFn, + updateNodeFn); + } + } + + /// Retrieves the data associated with a function + T *getFuncData(FunctionOpInterface funcOp) { + if (funcMap.count(funcOp)) { + return &funcMap[funcOp]; + } + return nullptr; + } + + /// Getters + ModuleOp getModuleOp() const { return moduleOp; } + SmallVector getRoots() const { return roots; } + size_t getNumFunctions() const { return funcMap.size(); } + + /// Returns true if the given function is a root. + bool isRoot(FunctionOpInterface funcOp) const { + return llvm::is_contained(roots, funcOp); + } + + /// Maps the data and the graph nodes associated with a funcOp to a + /// targetFuncOp. + template + void mapFuncOp(FROM funcOp, TO targetFuncOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.second == funcOp) { + edge.second = targetFuncOp; + } + } + } + graph[targetFuncOp] = graph[funcOp]; + // Replace in roots + for (auto it = roots.begin(); it != roots.end(); ++it) { + if (*it == funcOp) { + *it = targetFuncOp; + break; + } + } + // Replace in funcMap + funcMap[targetFuncOp] = funcMap[funcOp]; + } + + /// Maps the graph edges associated with a callOp to a targetCallOp. + template + void mapCallOp(FROM callOp, TO targetCallOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.first == callOp) { + edge.first = targetCallOp; + } + } + } + } + +private: + void build() { + SymbolTableCollection symbolTable; + DenseSet visited; + // Build graph + moduleOp.walk([&](Operation *op) { + auto caller = op->getParentOfType(); + if (auto callOp = dyn_cast(op)) { + auto *callee = callOp.resolveCallableInTable(&symbolTable); + auto funcOp = dyn_cast_or_null(callee); + if (funcOp) { + graph[caller].emplace_back( + std::pair(callOp, funcOp)); + visited.insert(funcOp); + } + } + }); + // Find roots + moduleOp.walk([&](FunctionOpInterface funcOp) { + if (!visited.count(funcOp)) { + roots.push_back(funcOp); + } + }); + } + + template + void doWalk(FunctionOpInterface funcOp, + DenseSet &visited, UpdateEdgeFn updateEdgeFn, + UpdateNodeFn updateNodeFn) { + if (visited.count(funcOp)) { + llvm::report_fatal_error("Cycle detected in call graph"); + } + if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) { + updateNodeFn(funcOp); + } + for (auto [callOp, callee] : graph[funcOp]) { + if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) { + updateEdgeFn(callOp, callee); + } + doWalk(callee, visited, updateEdgeFn, + updateNodeFn); + if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) { + updateEdgeFn(callOp, callee); + } + } + if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) { + updateNodeFn(funcOp); + } + visited.erase(funcOp); + } + +protected: + ModuleOp moduleOp; + DenseMap>> + graph; + FuncDataMapT funcMap; + SmallVector roots; +}; + +} // namespace triton + +// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +bool isCvtWarpSync(const triton::LinearLayout &srcLayout, + const triton::LinearLayout &dstLayout); + +} // namespace mlir + +#endif // TRITON_ANALYSIS_UTILITY_H diff --git a/third_party/mthreads/include/triton/CMakeLists.txt b/third_party/mthreads/include/triton/CMakeLists.txt new file mode 100644 index 0000000000..27c703b3cf --- /dev/null +++ b/third_party/mthreads/include/triton/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) diff --git a/third_party/mthreads/include/triton/Conversion/CMakeLists.txt b/third_party/mthreads/include/triton/Conversion/CMakeLists.txt new file mode 100644 index 0000000000..730f5cadd2 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonToTritonGPU) diff --git a/third_party/mthreads/include/triton/Conversion/MLIRTypes.h b/third_party/mthreads/include/triton/Conversion/MLIRTypes.h new file mode 100644 index 0000000000..dd8d4be4c2 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/MLIRTypes.h @@ -0,0 +1,46 @@ +#ifndef TRITON_CONVERSION_MLIR_TYPES_H +#define TRITON_CONVERSION_MLIR_TYPES_H + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// This file redefines some common MLIR types for easy usage. +namespace mlir { +namespace triton { +namespace type { + +// Integer types +inline Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); } +inline Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); } +inline Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); } +inline Type u32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Unsigned); +} +inline Type u1Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 1, IntegerType::Unsigned); +} + +// Float types +inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); } +inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); } +inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); } +inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); } + +inline bool isFloat8(Type type) { + return isa(type); +} + +inline bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128() || + type.isBF16() || llvm::isa(type) || + isFloat8(type); +} + +inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } + +} // namespace type +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_MLIR_TYPES_H diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h new file mode 100644 index 0000000000..46a06ac65d --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h @@ -0,0 +1,17 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "triton/Analysis/Allocation.h" + +namespace mlir::triton::gpu { + +/// Attach shared memory related attributes to module and operations inside it. +/// This includes total shared memory consumption in module and shared memory +/// offsets of buffers associated with operations. +void attachAllocationSizeAndOffsetAttr(ModuleOp mod, + ModuleAllocation &allocation); + +} // namespace mlir::triton::gpu + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ALLOCATE_UTILITY_H_ diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h new file mode 100644 index 0000000000..00ec880890 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h @@ -0,0 +1,27 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +inline std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter) { + return llvm::join(strs.begin(), strs.end(), delimiter); +} + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..93f8374e59 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM) +add_public_tablegen_target(TritonGPUConversionPassIncGen) diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h new file mode 100644 index 0000000000..ee8ff65aa0 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -0,0 +1,206 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { + +namespace gpu { + +Type getElementType(Value value); + +class MultipleOperandsRange + : public iterator_range>::iterator> { + using ContainerT = SmallVector>; + +public: + using iterator_range::iterator_range; + ContainerT::reference operator[](ContainerT::size_type idx) { + return begin()[idx]; + } + ContainerT::const_reference operator[](ContainerT::size_type idx) const { + return begin()[idx]; + } + ContainerT::size_type size() const { return end() - begin(); } +}; + +// Base pattern for elementwise conversion using ConcreteT. Unpacks individual +// elements from a `!llvm.struct` via `llvm.extactvalue`, calls +// ConcreteT::createDestOps on each element, and packs them back into an +// `!llvm.struct` using `llvm.insertvalue`. +// +// Also supports processing the inputs in a vectorized form by consuming and +// producing multiple operand sets in ConcreteT::createDestOps. +template +class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ElementwiseOpConversionBase( + LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit), + axisAnalysisPass(axisAnalysisPass) {} + + // Try to deduplicate the resultVals based on the + // constancy properties of the result discovered by + // the axis analysis pass. If possible, redundant + // computation is eliminated. + SmallVector maybeDeduplicate(SourceOp op, + SmallVector resultVals) const { + auto ctx = op.getContext(); + if (!isMemoryEffectFree(op)) + // the op has side effects: can't dedup + return resultVals; + SmallVector results = op->getResults(); + if (results.size() == 0 || results.size() > 1) + // there must be exactly 1 result + return resultVals; + Value result = results[0]; + RankedTensorType rtType = dyn_cast(result.getType()); + if (!rtType) + // the result must be a tensor + return resultVals; + + // Bail out if we don't have the constancy analysis + AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result); + if (!axisInfo) + return resultVals; + SmallVector constancy = axisInfo->getConstancy(); + + if (llvm::all_of(constancy, [](int64_t c) { return c == 1; })) + return resultVals; + + // We zero out the bases that are constant + auto kReg = StringAttr::get(ctx, "register"); + auto ll = toLinearLayout(rtType); + auto dims = to_vector(ll.getOutDimNames()); + auto llReg = ll.sublayout({kReg}, dims); + auto inv = ll.pseudoinvert(); + auto invReg = inv.sublayout(dims, {kReg}); + auto bases_inv = invReg.getBases(); + for (auto [c, d] : llvm::zip(constancy, dims)) { + assert(llvm::isPowerOf2_32(c)); + for (int i = 0; i < llvm::Log2_32(c); i++) { + bases_inv[d][i] = {0}; + } + } + auto invBroadcast = LinearLayout(std::move(bases_inv), invReg.getOutDims(), + /*isSurjective=*/false); + auto cvt = llReg.compose(invBroadcast); + + // Deduplicate the result values + SmallVector outVals(resultVals.size()); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = cvt.apply({{kReg, i}}).begin()->second; + outVals[i] = resultVals[srcIdx]; + } + return outVals; + } + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType(); + Location loc = op->getLoc(); + // element type + auto resultElementTy = getElementTypeOrSelf(resultTy); + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + allOperands.resize(subOperands.size()); + for (auto v : llvm::enumerate(subOperands)) + allOperands[v.index()].push_back(v.value()); + } + if (allOperands.size() == 0) + allOperands.push_back({}); + + SmallVector resultVals; + for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) { + auto curr = static_cast(this)->createDestOps( + op, adaptor, rewriter, elemTy, MultipleOperandsRange(it, end), loc); + if (curr.size() == 0) + return failure(); + for (auto v : curr) { + if (!static_cast(v)) + return failure(); + resultVals.push_back(v); + } + it += curr.size(); + } + resultVals = maybeDeduplicate(op, resultVals); + Value view = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + rewriter.replaceOp(op, view); + + return success(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +// Trivial case where we map elementwise to an existing LLVM operator +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {DestOp::create(rewriter, loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +template +struct ElementwiseToIntrinsicOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseToIntrinsicOpConversion> { + using Base = + ElementwiseOpConversionBase; + using OpAdaptor = typename Base::OpAdaptor; + + using Base::Base; + + explicit ElementwiseToIntrinsicOpConversion( + LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, StringRef intrinsic, + PatternBenefit benefit = patternBenefitDefault) + : Base(typeConverter, axisAnalysisPass, benefit), intrinsic(intrinsic) {} + + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, elemTy, + operands[0]) + .getResult(0)}; + } + +private: + StringRef intrinsic; +}; + +} // namespace gpu + +} // namespace mlir::triton +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h new file mode 100644 index 0000000000..907d36ed45 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h @@ -0,0 +1,35 @@ +#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H +#define TRITON_CONVERSION_FMA_DOT_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::gpu { + +/// Abstract interface for scalar multiplication of Value vectors. +/// +/// Enable generation of hardware specific code in different backends. +class FMAVectorMultiplier { +public: + /// \returns scalar product of two arrays, plus c: a·b + c + virtual Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) = 0; + + virtual ~FMAVectorMultiplier() = default; +}; + +/// Implements a framework for FMA dot conversion to llvm. +/// +/// This function implements architecture independent part of FMA dot +/// conversion and calls "multiplier" object, which is defined by caller +/// and implements architecture dependant part of conversion. +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier); + +} // namespace mlir::triton::gpu + +#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.h new file mode 100644 index 0000000000..2a3a67a594 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.h @@ -0,0 +1,25 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H + +#include "mlir/Pass/Pass.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton::gpu { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +} // namespace triton::gpu + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.td new file mode 100644 index 0000000000..fa3cc63c72 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -0,0 +1,45 @@ +#ifndef TRITONCOMMONGPU_CONVERSION_PASSES +#define TRITONCOMMONGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { + let summary = "Add metadata for shared memory allocation"; + + let description = [{ + This pass uses the `ModuleAllocation` analysis to: + - Annotate modules with an attribute with the amount of shared/local + memory used. + - Annotate operations with an offset into the total shared/local memory. + }]; +} + +def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> { + let summary = "Assign global scratch memory allocation"; + + let description = [{ + Decide on global scratch space memory allocation and assign attributes to each allocation. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + +def TritonGPUAllocateWarpGroups : Pass<"tritongpu-allocate-warp-groups", "mlir::ModuleOp"> { + let summary = "Allocate warp groups"; + + let description = [{ + The `tritongpu-allocate-warp-groups` pass performs warpgroup allocation for + a GPU program. When a GPU program contains warp specialization, additional + warps are launched in addition to the "default" warp group. The "default" + warpgroup executes top-level code in a `tt.func` and its size is specified + by the user via the `num_warps` argument. + + This pass analyzes `ttg.warp_specialize` ops in the program and determines + the total number of needed warps, then attaches the range of warp IDs to + each warpgroup function. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 0000000000..b94bc656ef --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,117 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfoBase.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +namespace mlir { +namespace triton { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; +constexpr int patternBenefitNvidiaTensorCoreSubviewPattern = 20; + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors. +void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateBarrierOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMinMaxFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool hwNanPropagationSupported, + PatternBenefit benefit); +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h new file mode 100644 index 0000000000..66e5136beb --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -0,0 +1,115 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H + +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton { +enum class ProgramIDDim : uint32_t; + +class TargetInfoBase { +public: + virtual bool supportMaximumMinimum() const = 0; + + virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; + + virtual Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const = 0; + + // Emit a block/CTA level barrier that guarantees visibility for the + // target address space + virtual void barrier(Location loc, RewriterBase &rewriter, + triton::gpu::AddrSpace targets) const = 0; + // Insert a warp syncronization barrier that also guarantees local address + // space visibility at warp level when supported by the backend. + // Backends that do not support warp-level barriers should conservatively + // emit a block-level barrier with local address space visibility. + virtual void warpSync(Location loc, RewriterBase &rewriter) const = 0; + + // Store/load a value from shared memory, either in the same CTA or, if + // `ctaId` is non-nullopt, in another CTA in the same group. + // + // A target that does not support cross-CTA transfers will assert if ctaId is + // non-nullopt. + // + // Assumes the address is aligned to the width of `val`. + virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const = 0; + virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, Value pred, + Operation *localLoadOp = nullptr) const = 0; + + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const { + storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred); + } + Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred) const { + return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy, + pred); + } + + virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const = 0; + + virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b, + Value selector) const = 0; + + virtual Value programId(RewriterBase &rewriter, Location loc, + ModuleOp moduleOp, ProgramIDDim axis) const = 0; + + virtual bool warpReduce(RewriterBase &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, + unsigned interleave) const = 0; + + virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; + // Emits LLVM code with |rewriter| to print a message following the given + // format from the device. |formatStrStart| is the pointer to the start of + // the format string global variable; |args| are the arguments to fill + // placeholders in the format string. + virtual void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args, + ArrayRef isSigned = {}) const = 0; + + // Emits LLVM code with |rewriter| to print a message, particularly useful for + // backend debug. |msg| is the message to print, |args| are the arguments to + // fill placeholders in the |msg|. + // NOTE: This function is used for backend debug. DO NOT DELETE. + // Example use: targetInfo.printf(rewriter,"index: %d, value: %f", {index, + // value}); + virtual void printf(RewriterBase &rewriter, StringRef msg, ValueRange args, + ArrayRef isSigned = {}) const = 0; + + // Emits LLVM code with |rewriter| to perform assertion failure with the given + // |message| from the given |func| in |file|. + virtual void assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const = 0; + + virtual int getSharedAddressSpace() const = 0; + + virtual int getAddressSpace(Attribute addressSpace) const = 0; + + virtual bool supportVectorizedAtomics() const = 0; + + virtual bool supportLdMatrix() const { return false; } + virtual bool supportStMatrix() const { return false; } + virtual bool supportLdStMatrixB8() const { return false; } + virtual bool isCuda() const { return false; } + + // Annotate target specific information to local load operations during + // lowering to LLVM. `llLoadOp` is the generated LLVM load op. + virtual void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp, + Operation *llLoadOp) const {} + + virtual ~TargetInfoBase() {} +}; +} // namespace mlir::triton +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h new file mode 100644 index 0000000000..1adbbee4e3 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -0,0 +1,39 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, + const LowerToLLVMOptions &option, + const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis = nullptr); + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, + const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonTensorType(RankedTensorType type, + const TargetInfoBase &targetInfo); + Type convertMemDescType(triton::gpu::MemDescType type, + const TargetInfoBase &targetInfo); + Type convertAsyncTokenType(triton::gpu::AsyncTokenType type); + + template void convertFP8Type() { + (addConversion([&](T type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }), + ...); + } +}; + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Utility.h new file mode 100644 index 0000000000..10b60d0a6b --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -0,0 +1,652 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/STLExtras.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::LLVM { +using namespace mlir::triton; + +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v); +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); +Value createConstantF16(Location loc, OpBuilder &rewriter, float v); +Value createConstantBF16(Location loc, OpBuilder &rewriter, float v); +Value createConstantF32(Location loc, OpBuilder &rewriter, float v); +Value createConstantF64(Location loc, OpBuilder &rewriter, double v); +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value); +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value); + +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args); +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args); +} // namespace mlir::LLVM + +namespace mlir::triton { + +struct TritonLLVMOpBuilder { + TritonLLVMOpBuilder(Location loc, OpBuilder &builder) + : loc(loc), builder(&builder) {} + + // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive + // Operators + template LLVM::SIToFPOp inttofloat(Args &&...args) { + return LLVM::SIToFPOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::IntToPtrOp inttoptr(Args &&...args) { + return LLVM::IntToPtrOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::PtrToIntOp ptrtoint(Args &&...args) { + return LLVM::PtrToIntOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::ZExtOp zext(Args &&...args) { + return LLVM::ZExtOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::SExtOp sext(Args &&...args) { + return LLVM::SExtOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::FPExtOp fpext(Args &&...args) { + return LLVM::FPExtOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::FPTruncOp fptrunc(Args &&...args) { + return LLVM::FPTruncOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::TruncOp trunc(Args &&...args) { + return LLVM::TruncOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::UDivOp udiv(Args &&...args) { + return LLVM::UDivOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::SDivOp sdiv(Args &&...args) { + return LLVM::SDivOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::URemOp urem(Args &&...args) { + return LLVM::URemOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::AddOp add(Args &&...args) { + return LLVM::AddOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::SubOp sub(Args &&...args) { + return LLVM::SubOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::FAddOp fadd(Args &&...args) { + return LLVM::FAddOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::MulOp mul(Args &&...args) { + return LLVM::MulOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::FMulOp fmul(Args &&...args) { + return LLVM::FMulOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::FMAOp fma(Args &&...args) { + return LLVM::FMAOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::FNegOp neg(Args &&...args) { + return LLVM::FNegOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::SMaxOp smax(Args &&...args) { + return LLVM::SMaxOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::UMaxOp umax(Args &&...args) { + return LLVM::UMaxOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::MaxNumOp fmax(Args &&...args) { + return LLVM::MaxNumOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::SMinOp smin(Args &&...args) { + return LLVM::SMinOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::UMinOp umin(Args &&...args) { + return LLVM::UMinOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::MinNumOp fmin(Args &&...args) { + return LLVM::MinNumOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::ShlOp shl(Args &&...args) { + return LLVM::ShlOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::LShrOp lshr(Args &&...args) { + return LLVM::LShrOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::AShrOp ashr(Args &&...args) { + return LLVM::AShrOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::AndOp and_(Args &&...args) { + return LLVM::AndOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::XOrOp xor_(Args &&...args) { + return LLVM::XOrOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::OrOp or_(Args &&...args) { + return LLVM::OrOp::create(*builder, loc, std::forward(args)...); + } + LLVM::BitcastOp bitcast(Value val, Type type) { + return LLVM::BitcastOp::create(*builder, loc, type, val); + } + template + LLVM::AddrSpaceCastOp addrspacecast(Args &&...args) { + return LLVM::AddrSpaceCastOp::create(*builder, loc, + std::forward(args)...); + } + template LLVM::GEPOp gep(Args &&...args) { + return LLVM::GEPOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::InsertValueOp insert_val(Args &&...args) { + return LLVM::InsertValueOp::create(*builder, loc, + std::forward(args)...); + } + template LLVM::ExtractValueOp extract_val(Args &&...args) { + return LLVM::ExtractValueOp::create(*builder, loc, + std::forward(args)...); + } + template + LLVM::InsertElementOp insert_element(Args &&...args) { + return LLVM::InsertElementOp::create(*builder, loc, + std::forward(args)...); + } + template + LLVM::ExtractElementOp extract_element(Args &&...args) { + return LLVM::ExtractElementOp::create(*builder, loc, + std::forward(args)...); + } + template LLVM::LoadOp load(Args &&...args) { + return LLVM::LoadOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::StoreOp store(Args &&...args) { + return LLVM::StoreOp::create(*builder, loc, std::forward(args)...); + } + LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) { + return LLVM::FCmpOp::create(*builder, loc, builder->getI1Type(), + LLVM::FCmpPredicate::ogt, lhs, rhs); + } + LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) { + return LLVM::FCmpOp::create(*builder, loc, builder->getI1Type(), + LLVM::FCmpPredicate::olt, lhs, rhs); + } + LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) { + return LLVM::FCmpOp::create(*builder, loc, builder->getI1Type(), + LLVM::FCmpPredicate::oeq, lhs, rhs); + } + template LLVM::ICmpOp icmp_eq(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::eq, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ne(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::ne, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_slt(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::slt, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_sle(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::sle, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_sgt(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::sgt, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_sge(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::sge, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ult(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::ult, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ule(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::ule, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ugt(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::ugt, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_uge(Args &&...args) { + return LLVM::ICmpOp::create(*builder, loc, LLVM::ICmpPredicate::uge, + std::forward(args)...); + } + template LLVM::SelectOp select(Args &&...args) { + return LLVM::SelectOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::AddressOfOp address_of(Args &&...args) { + return LLVM::AddressOfOp::create(*builder, loc, + std::forward(args)...); + } + mlir::triton::gpu::BarrierOp barrier(triton::gpu::AddrSpace addrspace) { + return mlir::triton::gpu::BarrierOp::create(*builder, loc, addrspace); + } + template LLVM::UndefOp undef(Args &&...args) { + return LLVM::UndefOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::ZeroOp null(Args &&...args) { + return LLVM::ZeroOp::create(*builder, loc, std::forward(args)...); + } + template LLVM::CallOp call(Args &&...args) { + return LLVM::CallOp::create(*builder, loc, std::forward(args)...); + } + // Constants + Value int_val(short bitwidth, int64_t val) { + Type ty = builder->getIntegerType(bitwidth); + return LLVM::ConstantOp::create(*builder, loc, ty, + builder->getIntegerAttr(ty, val)); + } + Value i1_val(int64_t val) { return int_val(1, val); } + Value true_val() { return int_val(1, true); } + Value false_val() { return int_val(1, false); } + Value f16_val(float v) { return LLVM::createConstantF16(loc, *builder, v); } + Value bf16_val(float v) { return LLVM::createConstantBF16(loc, *builder, v); } + Value f32_val(float v) { return LLVM::createConstantF32(loc, *builder, v); } + Value f64_val(double v) { return LLVM::createConstantF64(loc, *builder, v); } + Value i8_val(int64_t val) { return int_val(8, val); } + Value i16_val(int64_t val) { return int_val(16, val); } + Value i32_val(int64_t val) { return int_val(32, val); } + Value i64_val(int64_t val) { return int_val(64, val); } + + Location loc; + OpBuilder *builder; +}; + +// This builder combines an IRRewriter and a TritonLLVMOpBuilder into one, +// making it easy to create operations with an implicit location and create LLVM +// operations with shorthands. +class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder { +public: + // Create a builder with an implicit location. Arguments are forwarded to + // IRRewriter's constructor. + template + TritonLLVMIRRewriter(Location loc, Args &&...args) + : IRRewriter(std::forward(args)...), + TritonLLVMOpBuilder(loc, *this) {} + + // Get the implicit location. + Location getLoc() const { return loc; } + // Set the implicit location used to build ops. + void setLoc(Location loc) { this->loc = loc; } + + // Wrapper for op creation that passes an implicit location. + template OpTy create(Args &&...args) { + return OpBuilder::create(loc, std::forward(args)...); + } +}; +} // namespace mlir::triton + +// Types +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define int_ty(width) rewriter.getIntegerType(width) +#define i16_ty rewriter.getIntegerType(16) +#define i32_ty rewriter.getIntegerType(32) +#define i64_ty rewriter.getIntegerType(64) +#define ui32_ty rewriter.getIntegerType(32, false) +#define ui64_ty rewriter.getIntegerType(64, false) +#define f16_ty rewriter.getF16Type() +#define bf16_ty rewriter.getBF16Type() +#define i8_ty rewriter.getIntegerType(8) +#define i1_ty rewriter.getI1Type() +#define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() +#define vec_ty(type, num) VectorType::get(num, type) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) +#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) + +// Attributes +#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__}) +#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__}) +#define str_attr(str) ::mlir::StringAttr::get(ctx, (str)) + +namespace mlir { + +// See FuncOpToLLVM.cpp for details about Triton's function calling conventions +constexpr int kProfileScratchBufferOffset = -1; +constexpr int kGlobalScratchBufferOffset = -2; +constexpr int kSharedMemoryOffset = -3; + +namespace triton { + +namespace gpu { + +std::pair, SmallVector> +getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth); + +Type getFunctionType(Type resultType, ValueRange operands); + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname = "", + StringRef libpath = ""); + +// Multiply a square layout with 1 input and output dimension with a vector +Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x); +} // namespace gpu + +} // namespace triton + +namespace LLVM { +using namespace mlir::triton; + +class SharedMemoryObject { +public: + SharedMemoryObject(Value base, Type baseElemType, ArrayRef offsets); + + SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc, + RewriterBase &rewriter); + + SmallVector getOffsets() const { return offsets; } + Value getBase() const { return base; } + Type getBaseElemType() const { return baseElemType; } + + SmallVector getElems() const; + + SmallVector getTypes() const; + + // Returns a mask representing all the bits of the memdesc offsets that + // may be modified by an affine offset coming from a memdesc_subslice. + // The offsets are considered to be in the type of the memdesc. + // For padded layouts, we return the offsets without padding. + static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy); + + // Returns whether the shared memory access had a memdesc_subslice + // that is rank-preserving (soon to be called memdesc_slice) + static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) { + return getMaskSpanOffsets(srcTy) != 0; + } + + Value getShmemOffset(Location loc, RewriterBase &rewriter, + triton::gpu::MemDescType srcTy) const; + Value getShmemAffineBase(Location loc, RewriterBase &rewriter, + triton::gpu::MemDescType srcTy) const; + + // TODO(Keren): deprecate the method once AMD backend has cleaned up + Value getCSwizzleOffset(int dim) const { + assert(dim >= 0 && dim < offsets.size()); + return offsets[dim]; + } + + // TODO(Keren): deprecate the method once AMD backend has cleaned up + Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const; + +private: + Value base; // i32 ptr. The start address of the shared memory object. + Type baseElemType; + SmallVector + offsets; // i32 int. The offsets are zero at the initial allocation. +}; + +Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter); + +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape); + +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order); + +// Returns a tuple with the delinearized coordinates and a boolean which is true +// iff the Value is not broadcasted (equivalently, if the value is the "first" +// lane/thread/etc. that holds the given value). In mathy terms, the boolean is +// true if the element is the canonical representative of the class. +std::tuple, Value> +delinearize(RewriterBase &rewriter, Location loc, + triton::gpu::DistributedEncodingTrait layout, + ArrayRef shape, StringAttr dimName, Value linear); + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape); + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + triton::gpu::LinearEncodingAttr encoding, StringAttr dimName); + +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order); + +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content); + +Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp); + +Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, + const TargetInfoBase &targetInfo, + FunctionOpInterface funcOp, Value allocOffset); + +Value getProfileScratchPtr(Location loc, RewriterBase &rewriter, + FunctionOpInterface funcOp); + +Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Operation *op); + +// ----------------------------------------------------------------------- +// MXFP utilities +// ----------------------------------------------------------------------- + +// Scale a mxfp4 value by a given scale. +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale, + bool fastMath); + +} // namespace LLVM + +// ----------------------------------------------------------------------- +// Hardware Indices +// ----------------------------------------------------------------------- + +// If an operation is contained within a warp specialize region, this returns +// the warp ID offset of that warpgroup. +std::optional getWarpGroupStartWarpId(Block *block); + +// If an operation is contained within a warp specialize region, this returns +// the thread ID offset of that warpgroup. +std::optional getWarpGroupStartThreadId(Block *block); + +// Returns CTA level thread ID. +Value getThreadId(OpBuilder &rewriter, Location loc); + +// Get the lane ID, which is index of the thread within its warp. +Value getLaneId(OpBuilder &rewriter, Location loc); + +// Get the lane ID and warp ID. +std::pair getLaneAndWarpId(OpBuilder &rewriter, Location loc); + +// ----------------------------------------------------------------------- +// Shared memory utilities +// ----------------------------------------------------------------------- +using LLVM::SharedMemoryObject; +using ::mlir::LLVM::delinearize; +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides); + +// "Applies" the given layout by computing layout(indices) and returning the +// resulting Values. +// +// In other words, this generates LLVM-dialect MLIR code to "run" the layout +// function. +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices); + +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type); + +// Emit indices calculation within each ConversionPattern, and returns a +// [elemsPerThread X rank] index matrix. +// +// For example, for a thread a owns `elemsPerThread` elements of a tensor with +// type `type` and layout `layout`, the result will contain `elemsPerThread` +// vectors. Each vector contains the SSA values of the indices required to +// access the corresponding element, starting from the inner dimension. +SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset); + +// Calculates the required interval chunking and padding logical-shift values +// for shared memory padding, depending on elements' bit width and whether +// offsets count the number of bytes or number of elements. +SmallVector> +getPaddedSharedShifts(Attribute enc, unsigned bitwidth, bool offsetInBytes); + +// Applies padding to base offset values in shared memory. +Value applyPadding(Location loc, RewriterBase &rewriter, Value baseOffset, + ArrayRef> shifts); +uint32_t applyPadding(uint32_t baseOffset, + ArrayRef> shifts); + +// Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp +// We might want to merge them at some point, but having to support +// ldmatrix.trans makes the code in lowerLdStMatrix a bit specific +// Lowers to st when valArrays is empty, and to ld when it is not, +// and returns the output values. +// `paddingShifts` encodes shared memory padding if any. +SmallVector +lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt, + ArrayRef valsArray, // Input for store, output for load + Type llvmElemTy, Value smemBase, + ArrayRef> paddingShifts, + Value affineOffset, uint64_t maskSpanAffineOffset, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, + std::optional maybeMaxVecElems = {}, + Operation *localLoadOp = nullptr); + +// Lower an ld/st-like operation given a layout and a callback that creates the +// PTX instruction Lowers to st when valArrays is empty, and to ld when it is +// not, and returns the output values. +// calcPaddedOffset is a lambda that takes a base offset (mlir::Value) +// and computes a new offset (mlir::Value) by applying padding based on +// shared memory layout. +SmallVector lowerLdSt( + Location loc, MLIRContext *ctx, LinearLayout cvt, + ArrayRef valsArray, // Input for store, output for load + Type llvmElemTy, Value smemBase, + ArrayRef> paddingShifts, Value affineOffset, + uint64_t maskSpanAffineOffset, Value laneId, Value warpId, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, + std::optional maybeMaxVecElems, + std::function(RewriterBase &, Location, ArrayRef, + Value, int, VectorType)> + lowerInst); + +// Lower local_load/local_store via ld.shared/st.shared +SmallVector +lowerLocalLdSt(Location loc, MLIRContext *ctx, + LinearLayout cvt, // Map from registers to offset + ArrayRef valsArray, // Input for store, empty for load + Type llvmElemTy, triton::gpu::MemDescType srcTy, + SharedMemoryObject smemObj, RewriterBase &rewriter, + const TargetInfoBase &targetInfo, + Operation *localLoadOp = nullptr); + +SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter); + +Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, + ValueRange resultVals, RewriterBase &rewriter, Type type); + +SmallVector unpackLLVector(Location loc, Value llvmVec, + RewriterBase &rewriter); + +Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter); + +std::optional matchAtomicOp(RMWOp atomicOp); + +std::optional getMemoryOrdering(MemSemantic memOrdering); + +llvm::MapVector getAllFreeVarMasks(MLIRContext *ctx); + +llvm::MapVector getFreeVariableMasks(Type type); + +inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) { + return (index & freeVarMask) == 0; +} + +// Certain lowerings may introduce references to function arguments. Keep warp +// group code isolated from above by invoking this function. +void makeAllWarpGroupsIsolatedFromAbove(Operation *op); + +// Set the correct loop annotation on LLVM branch ops. +void fixUpLoopAnnotation(ModuleOp mod); + +void transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src, + const TargetInfoBase &targetInfo, + const LLVMTypeConverter *typeConverter, + RewriterBase &rewriter); + +SmallVector inlineRegionImpl(RewriterBase &rewriter, Region ®ion, + ArrayRef args, + mlir::TypeID terminatorTypeId, + Location loc); + +template +SmallVector inlineRegion(RewriterBase &rewriter, Region ®ion, + ArrayRef args, Location loc) { + return inlineRegionImpl(rewriter, region, args, + mlir::TypeID::get(), loc); +} + +void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy, + ConversionPatternRewriter &rewriter, + SmallVector &resultVals, + Type valueElemTy, TritonLLVMOpBuilder &b, + Value threadPred, + const TargetInfoBase &targetInfo, + const LLVMTypeConverter *typeConverter); + +// ----------------------------------------------------------------------- +// FuncOp conversion utilities +// ----------------------------------------------------------------------- +void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result); +triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo); +void handleArgPtrDatatype(triton::FuncOp funcOp, LLVM::LLVMFuncOp &llvmFuncOp); +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.h new file mode 100644 index 0000000000..96a816c609 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.h @@ -0,0 +1,78 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_WARPSPECIALIZEUTILITY_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_WARPSPECIALIZEUTILITY_H + +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SetVector.h" +#include + +namespace mlir { +namespace triton { + +// Forward declaration +class TritonLLVMIRRewriter; + +//===----------------------------------------------------------------------===// +// convertOpTypes +//===----------------------------------------------------------------------===// + +/// Convert operand types, region argument types, and result types of a +/// an operation using the provided type converter. This is used for +/// WarpSpecializeOp and related operations during lowering to LLVM. +void convertOpTypes(Operation *op, const TypeConverter &typeConverter); + +//===----------------------------------------------------------------------===// +// elideTrivialCaptures +//===----------------------------------------------------------------------===// + +/// Attempt to eliminate captures by rematerializing trivial computations into +/// each partition region. +void elideTrivialCaptures(LLVM::LLVMFuncOp func, + ArrayRef wsOps); + +//===----------------------------------------------------------------------===// +// lowerWarpSpecializeCommon +//===----------------------------------------------------------------------===// + +/// Phase indicator for register reallocation during warp specialization. +enum class RegisterReallocPhase { + SwitchLoopStart, // Reallocate at the beginning of switch loop + WorkerPartitionStart, // Reallocate at worker partition region start + WorkerPartitionEnd, // Reallocate at worker partition region end + DefaultPartitionStart, // Reallocate at default partition region start + DefaultPartitionEnd // Reallocate at default partition region end +}; + +/// Callbacks for backend-specific operations during warp specialization +/// lowering. +struct WarpSpecializeCallbacks { + /// Create a barrier to synchronize threads across the whole CTA + std::function createAllBarrier; + + /// Reallocate registers. + /// regionNumber is only used for WorkerPartitionStart and WorkerPartitionEnd + /// phases. + std::function + reallocRegisters; +}; + +/// Common implementation of warp specialize lowering. +/// Uses callbacks for backend-specific barrier and register reallocation +/// operations. +LogicalResult lowerWarpSpecializeCommon( + LLVM::LLVMFuncOp func, ArrayRef wsOps, Block *entry, + Block *header, Block *switchLoop, Value wid, MLIRContext *ctx, + unsigned defaultNumWarps, unsigned totalNumWarps, + const TargetInfoBase &targetInfo, const WarpSpecializeCallbacks &callbacks, + unsigned switchLoopBarrierIdx); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_WARPSPECIALIZEUTILITY_H diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 0000000000..99d90c4d75 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) +add_public_tablegen_target(TritonConversionPassIncGen) diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.h b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.h new file mode 100644 index 0000000000..054f9ea959 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_PASSES_H +#define TRITON_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir::triton { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" + +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.td b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.td new file mode 100644 index 0000000000..2449637eb1 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -0,0 +1,56 @@ +#ifndef TRITON_CONVERSION_PASSES +#define TRITON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonGPU"; + let description = [{ + This pass converts the Triton Dialect into the TritonGPU Dialect. + This is a partial conversion that also affects other dialects + (namely `Arith`, `Math`, `SCF` and `CF`). + For these dialects, and many Triton dialect operations the conversions + mainly consists of enhancing the tensor type and the `tt.ptr>` + type with an appropriate layout encoding (these encodings generally + include information on `numWarps`, `threadsPerWarp` and `numCTAs`). + }]; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"target", "target", + "std::string", /*default*/"\"\"", + "the GPU target, e.g., cuda:80, hip:gfx942">, + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps">, + Option<"threadsPerWarp", "threads-per-warp", + "int32_t", /*default*/"32", + "number of threads per warp">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of ctas in a cga">, + Option<"enableSourceRemat", "enable-source-remat", + "bool", /*default*/"false", + "enable trivial source rematerialization">, + ]; +} + +def RelayoutTritonGPU : Pass<"relayout-tritongpu", "mlir::ModuleOp"> { + let summary = "relayout pass for `ttg` and `ttng` operations"; + let description = [{ + The `relayout-tritongpu` pass is used during relayout of TTGIR + during warp specialization. Warp specialization may change the number of + warps for a partition, which requires reassigning layouts to all the + operations in the partition. However, those operations may include TritonGPU + and TritonNvidiaGPU dialect operations with specific layout requirements, + so they have to be re-inferred during this pass. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..19ca22ec3b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) +add_subdirectory(TritonInstrument) +add_subdirectory(Gluon) +add_subdirectory(NVGPU) +add_subdirectory(NVWS) diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/CMakeCache.txt b/third_party/mthreads/include/triton/Dialect/Gluon/CMakeCache.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/CMakeCache.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Gluon/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Gluon/IR/CMakeLists.txt new file mode 100644 index 0000000000..8e42fc0904 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS GluonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_mlir_doc(GluonOps GluonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS GluonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=gluon) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=gluon) +add_mlir_doc(GluonDialect GluonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS GluonAttrDefs.td) +mlir_tablegen(GluonAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(GluonAttrDefs.cpp.inc -gen-attrdef-defs) + +add_public_tablegen_target(GluonTableGen) diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/Gluon/IR/Dialect.h new file mode 100644 index 0000000000..3004e71a62 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/IR/Dialect.h @@ -0,0 +1,11 @@ +#pragma once +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "triton/Dialect/Gluon/IR/Dialect.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/Gluon/IR/GluonAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/Gluon/IR/Ops.h.inc" diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonAttrDefs.td b/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonAttrDefs.td new file mode 100644 index 0000000000..f2b0da23a9 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonAttrDefs.td @@ -0,0 +1,23 @@ +#ifndef GLUON_ATTRDEFS +#define GLUON_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Gluon/IR/GluonDialect.td" + +def Gluon_AutoEncodingAttr : AttrDef { + let mnemonic = "auto_encoding"; + let attrName = "gluon.auto_encoding"; + let description = [{ + An encoding that is inferred from neighboring ops in the graph. + }]; +} + +def Gluon_CoalescedEncodingAttr : AttrDef { + let mnemonic = "coalesced_encoding"; + let attrName = "gluon.coalesced_encoding"; + let description = [{ + An encoding that is optimized for load/store performance. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonDialect.td b/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonDialect.td new file mode 100644 index 0000000000..37e55f12ed --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonDialect.td @@ -0,0 +1,22 @@ +#ifndef GLUON_DIALECT +#define GLUON_DIALECT + +include "mlir/IR/OpBase.td" + +def Gluon_Dialect : Dialect { + let name = "gluon"; + let cppNamespace = "::mlir::triton::gluon"; + let description = [{ + Gluon dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + "mlir::gpu::GPUDialect", + ]; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonOps.td b/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonOps.td new file mode 100644 index 0000000000..d268c0e515 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/IR/GluonOps.td @@ -0,0 +1,32 @@ +#ifndef GLUON_OPS +#define GLUON_OPS + +include "triton/Dialect/Gluon/IR/GluonDialect.td" +include "triton/Dialect/Gluon/IR/GluonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" + +class Gluon_Op traits = []> : + Op { +} + +def Gluon_SetAutoLayoutOp : Gluon_Op<"set_auto_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType]> { + let summary = "set auto encoding to a concrete encoding type"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let builders = [ + OpBuilder<(ins "Attribute":$encoding, "Value":$value)> + ]; + + let hasVerifier = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +#endif // GLUON_OPS diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..a2d298d0c1 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Gluon) +add_public_tablegen_target(GluonTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/InferLayoutUtils.h b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/InferLayoutUtils.h new file mode 100644 index 0000000000..3cd4b0d508 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/InferLayoutUtils.h @@ -0,0 +1,20 @@ +#ifndef TRITON_DIALECT_GLUON_TRANSFORMS_INFERLAYOUTUTILS_H_ +#define TRITON_DIALECT_GLUON_TRANSFORMS_INFERLAYOUTUTILS_H_ + +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PriorityWorklist.h" + +namespace mlir::triton::gluon { + +LogicalResult +inferLayout(FuncOp func, llvm::function_ref typeCheck, + const SmallVector> &seedEncodings); + +LogicalResult doubleCheckEncodings(ModuleOp &mod, + llvm::function_ref typeCheck); + +} // namespace mlir::triton::gluon + +#endif // TRITON_DIALECT_GLUON_TRANSFORMS_INFERLAYOUTUTILS_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/Passes.h new file mode 100644 index 0000000000..353d21e04f --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/Passes.h @@ -0,0 +1,13 @@ +#pragma once +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include + +namespace mlir::triton::gluon { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Gluon/Transforms/Passes.h.inc" + +} // namespace mlir::triton::gluon diff --git a/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/Passes.td new file mode 100644 index 0000000000..04c75dbdd0 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Gluon/Transforms/Passes.td @@ -0,0 +1,54 @@ +#ifndef GLUON_PASSES +#define GLUON_PASSES + +include "mlir/Pass/PassBase.td" + +def GluonResolveAutoEncodingsPass : Pass<"gluon-resolve-auto-encodings", "mlir::ModuleOp"> { + let summary = "Resolve automatic encodings"; + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + ]; +} + +def GluonInferCoalescedEncodingsPass : Pass<"gluon-infer-coalesced-encodings", "mlir::ModuleOp"> { + let summary = "Infer coalesced encodings based on axis analysis"; + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + ]; +} + +def GluonCanonicalize: Pass<"gluon-canonicalize"> { + let summary = "reduced set of simplifications for TTGIR"; + + let description = [{ + The `gluon-canonicalize` pass applies a reduced set of simplification + and canonicalization patterns to the module. + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::cf::ControlFlowDialect", + "mlir::scf::SCFDialect", + ]; +} + +def GluonInline: Pass<"gluon-inline"> { + let summary = "reduced set of simplifications for TTGIR"; + + let description = [{ + The `gluon-inline` pass applies a reduced set of simplification + and canonicalization patterns to the module. + }]; + let dependentDialects = []; +} + +def GluonSimplifyControlFlow: Pass<"gluon-slimplify-control-flow"> { + let summary = "simplications for control flow ops"; + + let description = [{ + The `gluon-simplify-control-flow` pass applies a reduced set of + simplification and canonicalization patterns for control flow ops. + }]; + let dependentDialects = []; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVGPU/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/include/triton/Dialect/NVGPU/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..21551d2b63 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS NVGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvg) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvg) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(NVGPUDialect NVGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(NVGPUOps NVGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(NVGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td) +mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(NVGPUAttrDefsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/NVGPU/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/Dialect.h new file mode 100644 index 0000000000..a27b556fed --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/Dialect.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h.inc" +#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/NVGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace nvgpu {} // namespace nvgpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td new file mode 100644 index 0000000000..c904824ef0 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td @@ -0,0 +1,33 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_ATTRDEFS +#define NVGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "NVGPUDialect.td" + +class NVGPU_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td new file mode 100644 index 0000000000..95ad790379 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td @@ -0,0 +1,40 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_DIALECT +#define NVGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def NVGPU_Dialect : Dialect { + let name = "nvg"; + let cppNamespace = "::mlir::triton::nvgpu"; + + let description = [{ + NVGPU Dialect. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect" + ]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUOps.td new file mode 100644 index 0000000000..9a62ebba43 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVGPU/IR/NVGPUOps.td @@ -0,0 +1,134 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_OPS +#define NVGPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "NVGPUDialect.td" +include "NVGPUAttrDefs.td" + +def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; +def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; +def LLVM_PointerTensorMemory : LLVM_PointerInAddressSpace<6>; + + +def NVGPU_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def NVGPU_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def NVGPU_ScalarLike : AnyTypeOf<[NVGPU_Float, NVGPU_Int]>; + + +def NVGPU_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton::nvgpu"; +} + +class NVGPU_Op traits = []> : + LLVM_OpBase; + +def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", [DeclareOpInterfaceMethods, + AllTypesMatch<["input", "output"]>]> { + let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings); + let results = (outs LLVM_AnyStruct:$output); + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout", + "wgmma layout, either 'row' or 'col'", + [ + I32EnumAttrCase<"row", 0>, + I32EnumAttrCase<"col", 1> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", + "wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'", + [ + I32EnumAttrCase<"s8", 0>, + I32EnumAttrCase<"s32", 1>, + I32EnumAttrCase<"e4m3", 2>, + I32EnumAttrCase<"e5m2", 3>, + I32EnumAttrCase<"f16", 4>, + I32EnumAttrCase<"bf16", 5>, + I32EnumAttrCase<"tf32", 6>, + I32EnumAttrCase<"f32", 7> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; + +def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, + WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); + let results = (outs LLVM_AnyStruct:$res); + let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; +} + +def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_LoadAcquireOp : NVGPU_Op<"ld_acquire", [MemoryEffects<[MemRead]>]> { + let arguments = ( + ins LLVM_PointerGlobal:$addr, + Optional:$mask, + NVGPU_MemSemanticAttr:$sem, + NVGPU_MemSyncScopeAttr:$scope + ); + let results = (outs NVGPU_ScalarLike:$result); + let assemblyFormat = "$sem `,` $scope `,` $addr (`,` $mask^)? attr-dict `:` functional-type($addr, $result)"; +} + +def NVGPU_TensorMemoryBaseAddress : NVGPU_Op<"tensor_memory_base", [Pure]> { + let description = [{ + Op to represent base address of tensor memory in a kernel. + This is used to simplify lowering from TritonGPU to LLVM. + }]; + let results = (outs LLVM_PointerTensorMemory:$result); + let assemblyFormat = "attr-dict"; +} + + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/NVWS/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/NVWS/IR/CMakeLists.txt new file mode 100644 index 0000000000..dd7d06b7ba --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/CMakeLists.txt @@ -0,0 +1,24 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS NVWSOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvws) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvws) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=nvws) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=nvws) +add_mlir_doc(NVWSDialect NVWSDialect dialects/ -gen-dialect-doc) +add_mlir_doc(NVWSOps NVWSOps dialects/ -gen-op-doc) +add_public_tablegen_target(NVWSTableGen) + +set(LLVM_TARGET_DEFINITIONS NVWSAttrDefs.td) +mlir_tablegen(NVWSAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(NVWSAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(NVWSAttrEnums.h.inc -gen-enum-decls) +mlir_tablegen(NVWSAttrEnums.cpp.inc -gen-enum-defs) + +set(LLVM_TARGET_DEFINITIONS NVWSOpInterfaces.td) +mlir_tablegen(NVWSOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(NVWSOpInterfaces.cpp.inc -gen-op-interface-defs) + +add_public_tablegen_target(NVWSAttrDefsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/NVWS/IR/Dialect.h new file mode 100644 index 0000000000..2c001b6f7c --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/Dialect.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef DIALECT_NVWS_IR_DIALECT_H_ +#define DIALECT_NVWS_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "triton/Dialect/NVWS/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#include "triton/Dialect/NVWS/IR/NVWSAttrEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/NVWS/IR/NVWSAttrDefs.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/NVWS/IR/Types.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/NVWS/IR/NVWSOpInterfaces.h.inc" +#include "triton/Dialect/NVWS/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace nvws {} // namespace nvws +} // namespace triton +} // namespace mlir + +#endif // DIALECT_NVWS_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSAttrDefs.td b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSAttrDefs.td new file mode 100644 index 0000000000..18772f27e0 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSAttrDefs.td @@ -0,0 +1,70 @@ +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVWS_ATTRDEFS +#define NVWS_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" +include "NVWSDialect.td" + +class NVWS_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +def NVWS_TypeArray : ArrayOfAttr {} +def NVWS_IntArray : ArrayOfAttr {} + +// Type for synchronization tokens. +def NVWS_TokenLoadTypeAttr : I32EnumAttr< + "TokenLoadType", "", + [ + I32EnumAttrCase<"None", 0, "none">, + I32EnumAttrCase<"AsyncLoadOp", 1, "asyncLoadOp">, + I32EnumAttrCase<"TMALoadOp", 2, "tmaLoadOp">, + I32EnumAttrCase<"LocalStoreOp", 3, "localStoreOp">, + I32EnumAttrCase<"TmemLoadOp", 4, "TmemLoadOp">, + ]>{ + let cppNamespace = "::mlir::triton::nvws"; +} + +def NVWS_AsyncOpAttr: I32EnumAttr< + "AsyncOp", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"TMALoad", 1, "tma_load">, + I32EnumAttrCase<"TC5MMA", 2, "tc5mma">, + I32EnumAttrCase<"TMEMCopy", 3, "tmem_copy">, + I32EnumAttrCase<"CpAsync", 4, "cp_async">, + I32EnumAttrCase<"WGMMA", 5, "wgmma">, + ]> { + let cppNamespace = "::mlir::triton::nvws"; + let genSpecializedAttr = 0; +} + +def NVWS_AsyncOpEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def NVWS_AsyncOpArrayAttr : TypedArrayAttrBase; + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSDialect.td b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSDialect.td new file mode 100644 index 0000000000..18724a0590 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSDialect.td @@ -0,0 +1,45 @@ +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVWS_DIALECT +#define NVWS_DIALECT + +include "mlir/IR/OpBase.td" + +def NVWS_Dialect : Dialect { + let name = "nvws"; + let cppNamespace = "::mlir::triton::nvws"; + + let description = [{ + Nvidia Warp Specialization Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + ]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSOpInterfaces.td b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSOpInterfaces.td new file mode 100644 index 0000000000..5826e498c2 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSOpInterfaces.td @@ -0,0 +1,37 @@ +#ifndef NVWS_OP_INTERFACES +#define NVWS_OP_INTERFACES + +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" + +def NVWS_DescriptorLoadOpInterface : OpInterface<"DescriptorLoadOpInterface", [TT_DescriptorOpInterface]> { + let cppNamespace = "::mlir::triton::nvws"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the transaction counts", + /*retType=*/"int", + /*methodName=*/"getTxCount", + /*args=*/(ins)>, + ]; +} + +def NVWS_ArefStageInterface : OpInterface<"ArefStageInterface"> { + let cppNamespace = "::mlir::triton::nvws"; + + let description = [{ + This interface implements setStage/getStage for aref ops + }]; + + // We can add more methods as needed. + let methods = [ + InterfaceMethod<"Return aref stage", + "::mlir::Value", + "getStage">, + InterfaceMethod<"Set aref stage", + "void", + "setStage", + (ins "::mlir::Value":$stage)>, + ]; +} + +#endif // NVWS_OP_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSOps.td b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSOps.td new file mode 100644 index 0000000000..efc17cb874 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSOps.td @@ -0,0 +1,342 @@ +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVWS_OPS +#define NVWS_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/ControlFlowInterfaces.td" // RegionBranchOpInterface +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ViewLikeInterface.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "NVWSDialect.td" +include "NVWSTypes.td" +include "NVWSAttrDefs.td" +include "NVWSOpInterfaces.td" + +class NVWS_Op traits = []> : + Op; + +def NVWS_ArefCreateOp : NVWS_Op<"aref.create", [ + RangedTypesMatchWith<"input types match Aref output type", + "result", "buffers", "::llvm::cast($_self).getBaseType()">, Pure]> { + let summary = "Create an asynchronous reference."; + let description = [{ + Create an asynchronous reference. + + Takes as inputs a variadic number of buffers, and returns an ARef. + The inputs are expected to be array-like (i.e., Tensor, MemDesc, etc) + and the first axis of the shape should match between all inputs, representing + multi-buffering of the values. + }]; + let arguments = (ins Variadic:$buffers); + + let results = (outs NVWS_ArefType:$result); + + let assemblyFormat = [{$buffers attr-dict `:` type($result)}]; + let hasVerifier = 1; +} + +def NVWS_ArefBufferOp : NVWS_Op<"aref.buffer", [DeclareOpInterfaceMethods]> { + let summary = "Get buffer from aref"; + + let arguments = (ins NVWS_ArefType:$aref, + TTG_AsyncToken:$token, + Optional:$stage); + let results = (outs Variadic:$buffers); + let assemblyFormat = [{ + $aref (`[` $stage^ `]`)? `,` $token attr-dict + `:` type($aref) `,` type($token) `->` type(results) + }]; + + let builders = [ + OpBuilder<(ins "Value":$aref, "TypeRange":$bufferTypes, "Value":$token), [{ + build($_builder, $_state, bufferTypes, aref, token, Value()); + }]> + ]; +} + +def NVWS_ArefGetEnterOp : NVWS_Op<"aref.get.enter", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let summary = "Enter ArefGet region where the buffer can be used to read data"; + let description = [{ Enter a "region" where you can freely read from the buffer) + These ArefGet "regions" can span multiple iterations. }]; + + let arguments = (ins NVWS_ArefType:$aref, + Optional:$stage, + Optional:$phase); + let results = (outs Variadic:$buffers, + TTG_AsyncToken:$token); + let hasVerifier=1; + let assemblyFormat = [{ + $aref ( `[` $stage^ `,` $phase `]`)? attr-dict + `:` type($aref) `->` type(results) + }]; + + let builders = [ + OpBuilder<(ins "Value":$aref, "TypeRange":$bufferTypes, "Type":$tokenType), [{ + build($_builder, $_state, bufferTypes, tokenType, aref, Value(), Value()); + }]> + ]; +} + +def NVWS_ArefGetExitOp : NVWS_Op<"aref.get.exit", [DeclareOpInterfaceMethods]> { + let summary = "Exit ArefGet region, where the buffer should no longer be used"; + let description = [{ Leave the region where you can freely read from the buffer). + These ArefGet "regions" can span multiple iterations. }]; + + let arguments = (ins NVWS_ArefType:$aref, + TTG_AsyncToken:$token, + Optional:$stage, + NVWS_AsyncOpArrayAttr:$async_ops); + let assemblyFormat = [{ + $aref (`[` $stage^ `]`)? `,` $token $async_ops attr-dict + `:` type($aref) `,` type($token) + }]; + + let builders = [ + OpBuilder<(ins "Value":$aref, "Value":$token, "ArrayAttr":$async_ops), [{ + build($_builder, $_state, aref, token, Value(), async_ops); + }]> + ]; +} + +def NVWS_ArefPutEnterOp : NVWS_Op<"aref.put.enter", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let summary = "Enter ArefPut region where the buffer can be used to read data"; + let description = [{ Enter a "region" where you can freely write to the buffer) + These ArefPut "regions" can span multiple iterations. }]; + + let arguments = (ins NVWS_ArefType:$aref, + Optional:$stage, + Optional:$phase); + let results = (outs Variadic:$buffers, + TTG_AsyncToken:$token); + let hasVerifier=1; + let assemblyFormat = [{ + $aref ( `[` $stage^ `,` $phase `]`)? attr-dict + `:` type($aref) `->` type(results) + }]; + + let builders = [ + OpBuilder<(ins "Value":$aref, "TypeRange":$bufferTypes, "Type":$tokenType), [{ + build($_builder, $_state, bufferTypes, tokenType, aref, Value(), Value()); + }]> + ]; +} + +def NVWS_ArefPutExitOp : NVWS_Op<"aref.put.exit", [DeclareOpInterfaceMethods]> { + let summary = "Exit ArefPut region, where the buffer should no longer be used"; + let description = [{ Leave the region where you can freely write to the buffer). + These ArefPut "regions" can span multiple iterations. }]; + + let arguments = (ins NVWS_ArefType:$aref, + TTG_AsyncToken:$token, + Optional:$stage, + NVWS_AsyncOpArrayAttr:$async_ops); + let assemblyFormat = [{ + $aref (`[` $stage^ `]`)? `,` $token $async_ops attr-dict + `:` type($aref) `,` type($token) + }]; + + let builders = [ + OpBuilder<(ins "Value":$aref, "Value":$token, "ArrayAttr":$async_ops), [{ + build($_builder, $_state, aref, token, Value(), async_ops); + }]> + ]; +} + +def NVWS_WarpGroupOp : NVWS_Op<"warp_group", [ + RecursiveMemoryEffects, RecursivelySpeculatable, +]> { + let summary = "Container Op for Warp Specialization"; + let description = [{ + Higher level container for Warp Specialization Analysis. + + Contains a variadic number warp groups, with + the number of warps in each group, plus a region to hold the + computation for that warp group. + + The results of this op, if any, are those of the first region, as returned by + nvws.warp_group.yield op. + + nvws.warp_group should be lowered to ttg.warp_specialize + before execution. + }]; + + let arguments = (ins DenseI32ArrayAttr:$numWarps); + let results = (outs Variadic:$results); + let regions = (region VariadicRegion>:$partitionRegions); + let hasVerifier=1; + let hasCustomAssemblyFormat = 1; +} + +def NVWS_WarpGroupYieldOp : NVWS_Op<"warp_group.yield", [ + Pure, Terminator, ReturnLike, HasParent<"WarpGroupOp">, + DeclareOpInterfaceMethods +]> { + let summary = "yield from the first region of `nvws.warp_group`"; + let description = [{ + This op is equivalent to ttg.warp_yield op for ttg.warp_specialize op. + + TODO: Decide if we should move nvws.warp_group to TritonGPU, or continue to + have TritonGPU depend on NVWS. In the former case, this op can be removed. + The latter one involves a circular dependency between TritonGPU and NVWS. + }]; + + let arguments = (ins Variadic:$values); + + let assemblyFormat = "($values^)? attr-dict (`:` type($values)^)?"; +} + +def NVWS_WarpGroupReturnOp : NVWS_Op<"warp_group.return", [ + Pure, Terminator, HasParent<"WarpGroupOp"> +]> { + let summary = "Terminator for a warp group region"; + let description = [{ + Warp groups are expected to return values via referential modification + of their inputs. Thus, the warp_group.return op takes no values to + return from the warp group. + }]; + + let assemblyFormat = "attr-dict"; +} + +def NVWS_CreateTokenOp : NVWS_Op<"create_token"> { + let summary = "Create a token to be used for synchronizations in communication channels"; + let description = [{ A token will be used by the producer and consumer to synchronize. + The producer will acquire and hold the token, until it has filled the buffers, + and signal the waiting consumer. + The consumer will hold the token until it has consumed the buffers, + and will signal the waiting producer trying to acquire the token. + }]; + + let results = (outs TensorOf<[NVWS_TokenType]>:$result); + + let arguments = (ins I32Attr:$numBuffers, NVWS_TokenLoadTypeAttr:$loadType); + + let builders = [OpBuilder<(ins "uint32_t":$numBuffers, "triton::nvws::TokenLoadType":$loadType)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def NVWS_ProducerAcquireOp : NVWS_Op<"producer_acquire"> { + let summary = "Producer acquires a token to fill buffers"; + let description = [{ The producer will try to acquire the token prior to filling + the buffers. If the buffers are not ready to be filled, the producer will wait to be + signalled by the consumer which finishes consuming the buffers and + releases the token. + }]; + + let arguments = (ins TensorOf<[NVWS_TokenType]>:$token, I32:$idx, I1:$phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def NVWS_ProducerCommitOp : NVWS_Op<"producer_commit"> { + let summary = "Producer commits the buffer changes"; + let description = [{ The producer will release the token and signal the consumer + that the buffers are ready to be consumed. + }]; + + let arguments = (ins TensorOf<[NVWS_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def NVWS_ConsumerWaitOp : NVWS_Op<"consumer_wait"> { + let summary = "Consumer awaits buffer readiness"; + let description = [{ The consumer will wait for the buffer to be ready + to be consumed. If the buffers are not ready, the consumer will wait to be + signalled by the producer which finishes filling the buffers and + releases the token. + }]; + + let arguments = (ins TensorOf<[NVWS_TokenType]>:$token, I32:$idx, I1: $phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def NVWS_ConsumerReleaseOp : NVWS_Op<"consumer_release"> { + let summary = "Consumer releases the token"; + let description = [{ The consumer will release the token and signal the producer + that the buffers are ready to be filled. + }]; + + let arguments = (ins TensorOf<[NVWS_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +def NVWS_DescriptorLoadOp : NVWS_Op<"descriptor_load", [NVWS_DescriptorLoadOpInterface]> { + let summary = "Load from descriptor and store into shared memory"; + let description = [{ + This op behaves exactly like the op with the same name in Triton Dialect, but the result of the load is stored into shared memory. + The execution is still synchronous. + }]; + let arguments = (ins + Arg]>:$desc, + Variadic:$indices, + I32Attr:$txCount, + Arg]>:$result, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let assemblyFormat = [{ + $desc `[` $indices `]` $txCount $result + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type(operands) + }]; +} + +def NVWS_DescriptorGatherOp : NVWS_Op<"descriptor_gather", [NVWS_DescriptorLoadOpInterface]> { + let summary = "gather multiple rows from a descriptor into shared memory"; + let description = [{ + This op behaves exactly like the op with the same name in Triton Dialect, but the result of the load is stored into shared memory. + The execution is still synchronous. + }]; + + let arguments = (ins + Arg]>:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + I32Attr:$txCount, + Arg]>:$result + ); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` $txCount $result + attr-dict `:` type(operands) + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSTypes.td b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSTypes.td new file mode 100644 index 0000000000..1bbce58c9d --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/IR/NVWSTypes.td @@ -0,0 +1,51 @@ +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NWVS_TYPES +#define NWVS_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "NVWSDialect.td" + +class NVWS_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def NVWS_ArefType : NVWS_TypeDef<"Aref", "aref"> { + let summary = "Asynchronous Reference"; + let description = [{ + A meta-type that holds an asynchronous reference to an underlying Type. + + Can wrap multiple underlying values simultaneously. + + Useful for syncing asynchronous operations while doing transformations such + as pipelining and warp specialization. Lowers to the underlying type, and + operations that use this should insert appropriate barriers during lowering. + }]; + let parameters = (ins "TypeArrayAttr":$baseType); + let assemblyFormat = "`<` $baseType `>`"; +} + +def NVWS_TokenType : NVWS_TypeDef<"Token", "token">; + +#endif // NVWS_TYPES diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..c3d83b30cd --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVWSTransforms) +add_public_tablegen_target(NVWSTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/Passes.h new file mode 100644 index 0000000000..be838930ce --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/Passes.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef DIALECT_NVWS_TRANSFORMS_PASSES_H_ +#define DIALECT_NVWS_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" + +// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" + +} // namespace triton +} // namespace mlir +#endif // DIALECT_NVWS_TRANSFORMS_PASSES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/Passes.td new file mode 100644 index 0000000000..06865f1d46 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/NVWS/Transforms/Passes.td @@ -0,0 +1,190 @@ +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVWS_PASSES +#define NVWS_PASSES + +include "mlir/Pass/PassBase.td" + +def NVWSLowerWarpGroup : Pass<"nvws-lower-warp-group", "mlir::ModuleOp"> { + let summary = "Convert nvws.warp_group to ttg.warp_specialize."; + + let description = [{ + Convert nvws.warp_group to ttg.warp_specialize. + + If the first group of nvws.warp_group matches the global + ttg.num_warps, it will be come the default region of ttg.warp_specialize. + If not, the ttg.warp_specialize default region will be empty, and all + warp groups will become isolated regions. + }]; + + let dependentDialects = [ + "mlir::triton::nvws::NVWSDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + +def NVWSAssignStagePhase : Pass<"nvws-assign-stage-phase", "mlir::ModuleOp"> { + let summary = "Assign buffer stage to nvws.aref.*."; + + let description = [{ + Assign buffer stage & phase to nvws.aref.* + + The pass will assign buffer stage to each aref op, and phase for enter ops. + }]; + + let dependentDialects = [ + "mlir::triton::nvws::NVWSDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def NVWSLowerAref : Pass<"nvws-lower-aref", "mlir::ModuleOp"> { + let summary = "Convert nvws.aref.* to ttng.*barrier* ops."; + + let description = [{ + Convert nvws.aref.* to ttng.*barrier* ops. + + The pass will convert each aref to a matched value and barrier set, + and will determined appropriate waits/signalling for values being + "empty" or "full" from the use/def chain of aref get/put. + + This lowering may yield non-ideal parallelism in certain cases, + which will be optimized by follow up peephole passes. + }]; + + let dependentDialects = [ + "mlir::triton::nvws::NVWSDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"numStages", "num-stages", "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def NVWSInsertAref: Pass<"nvws-insert-aref", "mlir::ModuleOp"> { + let summary = "Insert arefs between producer and consumer partitions."; + + let description = [{ + To automate barrier synchronizations between producer and consumer + partitions, arefs are introduced in the IR. This pass handles tensor, + scalar, and SMEM producers and consumers. + + Specifically, for producer partitions, a producing operation is + wrapped in an ArefPutEnterOp and ArefPutExitOp pair. A descriptor load + op is replaced with the corresponding NVWS op, to store its result + into the SMEM buffer owned by an aref. For consumer partitions, a reference + to the original SMEM buffer is replaced with an indirection via ArefGetEnterOp on + the SMEM buffer owned by an aref. ArefGetExitOp is placed after the post-dominant + consumer operation. + }]; + + let dependentDialects = [ + "mlir::triton::nvws::NVWSDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def NVWSInsertTmemAref: Pass<"nvws-insert-tmem-aref", "mlir::ModuleOp"> { + let summary = "Insert tmem arefs between producer and consumer partitions."; + + let description = [{ + Insert arefs when TMEM partition ownership changes. + + In contrast to the InsertAref pass, this pass uses ArefPut/ArefGet as ping-pong + ownership transfer between two groups. Currently, this pass limits ownership + of a specific TMEM buffer to no more than two groups. + }]; + + let dependentDialects = [ + "mlir::triton::nvws::NVWSDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def NVWSHoistTmemStore: Pass<"nvws-hoist-tmem-store", "mlir::ModuleOp"> { + let summary = "Hoist tmem store before the inner loop to the top level if possible."; + + let description = [{ + The HoistTMEMAlloc pass in TritonGPU, when applied to nested loops, puts the hoisted alloc and store inside the outer loop. + Given such input IR, this pass tries to hoist alloc and store across all loop nests, while threading the token variable appropriately. + + For example, this IR + + scf.for ... { + %result, %token = ttng.tmem_alloc {ttg.partition = array} + %16 = ttng.tmem_store %zero, %result[%token], %true {ttg.partition = array} + scf.for ... iter_args(%useD = %false, %arg9 = %16){ + ... + %28 = ttng.tc_gen5_mma %lhs, %rhs, %result[%arg9], %useD, %true {ttg.partition = array} + ... + scf.yield {ttg.partition = array} %true, %28 + } + }{tt.warp_specialize, ...} + + is transformed into + + %result, %token = ttng.tmem_alloc %zero {ttg.partition = array} + scf.for ... iter_args(%token_arg = %token) { // The token variable is threaded across loops + %res = scf.for ... iter_args(%useD = %false, %arg9 = %token_arg){ + ... + %28 = ttng.tc_gen5_mma %lhs, %rhs, %result[%arg9], %useD, %true {ttg.partition = array} + ... + scf.yield {ttg.partition = array} %true, %28 + } + yield %res#0 // Note there is now an explicit yield op + }{tt.warp_specialize, ...} + + This is valid, since the useD flag initialized to false means that the zero clear of the accumulator can be skipped. + If the inner loop does not execute at all, we would be returning the accumulator filled with zeros for all output tiles. + + This transformation is strictly an optimization. Note that the tmem_store before the inner loop is assigned to the partition 0, while the accumulator + is used by the MMA op in partition 1. This would result in an aref being created for this use of TMEM, along with put enter/exit and get enter/exit in + the two partitions, meaning an additional synchronization before the inner loop just to clear the accumulator. When the useD flag is intialized to false, + hoisting the tmem_store to the top level eliminates such unnecessary synchronization. + + Cares must be taken in such hoisting across loop nests. This transformation is valid as long as all instances of the inner loop execute + the same number of times - either at least once or none. This does not hold when the number of iterations of the inner loop depends on an outer-loop + iterator. But even in the presece of a variable iteration count, hoisting is still valid if we can statically prove that the inner loop executes + at least once. A Triton kernel can use tl.assume op to assert a certain bound on a variable. Given an inner loop with a variable iteration count, + this pass checks if there is an assumption on the bounds of the loop which allows us to prove that the loop executes at least once. + Hoisting is enabled in such cases. + }]; + + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +#endif // NVWS_PASSES diff --git a/third_party/mthreads/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 0000000000..5ed50b3a23 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,31 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td) +mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs) + +add_public_tablegen_target(TritonTableGen) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h new file mode 100644 index 0000000000..db54dcd708 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h @@ -0,0 +1,121 @@ +#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { + +struct GlobalMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +class DialectInferLayoutInterface + : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef shape, + ArrayRef order, Attribute &resultEncoding, + std::optional loc) const = 0; + + virtual LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional loc) const = 0; + + virtual LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional loc) const = 0; + + // Note: This function only verifies the operand encoding. It doesn't infer + // the result encoding. + virtual LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional loc) const = 0; + + // Tries to compute the encoding for the result of a reshape operation that + // makes the reshape a "nop", i.e. the same GPU threads contain the same + // elements as before the reshape using legacy layouts. This is not always + // possible (in which case we fallback to using LinearLayouts) + // In the future we'll always use LinearLayouts + virtual LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + // Check if two layouts are structurally the same, even if their names are + // different + virtual LogicalResult + verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, + Attribute got, std::optional loc) const = 0; + + virtual LogicalResult + inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const = 0; + + virtual LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const = 0; + + // Verify that the encoding are compatible to be used together in a dot + // operation + virtual LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const = 0; + + virtual LogicalResult + inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute inEnc, + Attribute &outEnc, bool fwdInference, + std::optional loc) const = 0; +}; + +class DialectVerifyTensorLayoutInterface + : public DialectInterface::Base { +public: + DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + verifyTensorLayout(Attribute layout, RankedTensorType type, Operation *op, + function_ref emitError) const = 0; + + virtual LogicalResult + verifyMemDescLayout(Attribute layout, Type type, Operation *op, + function_ref emitError) const = 0; +}; + +// Descriptor gather and scatter have restrictions on the tile sizes. +LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType, + ShapedType resultType, + ShapedType indicesType); +LogicalResult verifyDescriptorLoadStoreOp(Operation *op, + TensorDescInterface desc, + ShapedType tensor); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/DiscardableAttributes.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/DiscardableAttributes.h new file mode 100644 index 0000000000..68908fa926 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/DiscardableAttributes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_ + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton { + +// Filter out attributes from the given operation that are not present in +// the allowList. +[[nodiscard]] SmallVector +filterDiscardableAttrs(Operation *op, ArrayRef allowList); + +} // namespace mlir::triton +#endif // TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Interfaces.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Interfaces.h new file mode 100644 index 0000000000..fb5951fa5c --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Interfaces.h @@ -0,0 +1,45 @@ +#ifndef TRITON_IR_INTERFACES_H_ +#define TRITON_IR_INTERFACES_H_ + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/InliningUtils.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final; + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final; + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final; +}; + +} // namespace mlir::triton + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/OpInterfaces.h new file mode 100644 index 0000000000..326f876e1c --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -0,0 +1,24 @@ +#ifndef TRITON_IR_OP_INTERFACES_H_ +#define TRITON_IR_OP_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { + +namespace triton { + +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op); + +LogicalResult verifyDotOpInterface(Operation *op); + +} // namespace impl + +} // namespace triton +} // namespace mlir + +#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc" + +#endif // TRITON_IR_OP_INTERFACES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Traits.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Traits.h new file mode 100644 index 0000000000..d6e4bb523b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Traits.h @@ -0,0 +1,132 @@ +#ifndef TRITON_IR_TRAITS_H_ +#define TRITON_IR_TRAITS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +// The rationale for this trait is to prevent users from creating programs +// that would have catastrophic register pressure and cause the compiler to +// hang. +// Since H100 has 256KB registers, we should allow users to create tensors +// of size up to 256K elements. It will spill for datatypes wider than 1B, +// but we probably should limit number of elements (rather than bytes) to +// keep specs simple +int constexpr maxTensorNumElements = 1048576; + +LogicalResult verifyTensorSize(Operation *op); +LogicalResult verifyTensorLayouts(Operation *op); + +LogicalResult verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType = false); +LogicalResult verifyEquivalentType(Type typeA, Type typeB); +LogicalResult +verifySameOperandsAndResultEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult verifySameLoadStoreOperandsShape(Operation *op); + +LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op); + +} // namespace impl + +template +class TensorSizeTrait : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorSize(op); + } +}; + +// Trait applied to all Triton MLIR ops. Checks that the layouts of tensors are +// valid. +template +class VerifyTensorLayoutsTrait + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorLayouts(op); + } +}; + +template +class SameOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding(op); + } +}; + +template +class SameOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op); + } +}; + +template +class SameLoadStoreOperandsShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsShape(op); + } +}; + +template +class SameLoadStoreOperandsAndResultShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsAndResultShape(op); + } +}; + +template +class SameLoadStoreOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op, + /*allowTensorPointerType=*/true); + } +}; + +template +class SameLoadStoreOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding( + op, /*allowTensorPointerType=*/true); + } +}; + +// This trait indicates that regions in the op may execute concurrently with +// each other. +template +struct AsyncRegions : public TraitBase {}; + +// Marker trait for wait ops that thread selected operands through SSA results +// without changing their types. Utility/layout passes can rebuild these ops +// generically when operand types are rewritten. +template +struct PassthroughWaitLike + : public TraitBase {}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 0000000000..5a76a1d7b1 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,154 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + I32EnumAttrCase<"CV", 7, "cv">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_DescriptorReduceKindAttr : I32EnumAttr< + "DescriptorReduceKind", "", + [ + I32EnumAttrCase<"ADD", 1, "add">, + I32EnumAttrCase<"MIN", 2, "min">, + I32EnumAttrCase<"MAX", 3, "max">, + I32EnumAttrCase<"INC", 4, "inc">, + I32EnumAttrCase<"DEC", 5, "dec">, + I32EnumAttrCase<"AND", 6, "and">, + I32EnumAttrCase<"OR", 7, "or">, + I32EnumAttrCase<"XOR", 8, "xor">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee">, + I32EnumAttrCase<"BF16x3", 3, "bf16x3">, + I32EnumAttrCase<"BF16x6", 4, "bf16x6"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +// Type for ScaleDotElemType kind of floats. +def TT_ScaleDotElemTypeAttr : I32EnumAttr< + "ScaleDotElemType", "", + [ + I32EnumAttrCase<"E4M3", 0, "e4m3">, + I32EnumAttrCase<"E5M2", 1, "e5m2">, + I32EnumAttrCase<"E2M3", 2, "e2m3">, + I32EnumAttrCase<"E3M2", 3, "e3m2">, + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16">, + I32EnumAttrCase<"FP16", 6, "fp16"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonDialect.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonDialect.td new file mode 100644 index 0000000000..d0e25946b5 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -0,0 +1,60 @@ +#ifndef TRITON_DIALECT +#define TRITON_DIALECT + +include "mlir/IR/OpBase.td" + +def Triton_Dialect : Dialect { + let name = "tt"; + + let cppNamespace = "::mlir::triton"; + + let summary = "The Triton IR in MLIR"; + + let description = [{ + Triton Dialect. + + Dependent Dialects: + * Arith: + * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Math: + * exp, sin, cos, log, ... + * StructuredControlFlow: + * for, if, while, yield, condition + * ControlFlow: + * br, cond_br + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "math::MathDialect", + "scf::SCFDialect", + "cf::ControlFlowDialect", + "ub::UBDialect" + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + + static TritonDialect *getLoaded(MLIRContext *ctx) { + return ctx->getLoadedDialect(); + } + static TritonDialect *getLoaded(Operation *op) { + return getLoaded(op->getContext()); + } + }]; + + let discardableAttrs = (ins + "::mlir::IntegerAttr":$num_stages, + "::mlir::IntegerAttr":$latency, + "::mlir::IntegerAttr":$self_latency + ); + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "triton/Dialect/Triton/IR/TritonTypes.td" + + +#endif // TRITON_DIALECT diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonInterfaces.td new file mode 100644 index 0000000000..03fca43758 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -0,0 +1,31 @@ +#ifndef TRITON_INTERFACES +#define TRITON_INTERFACES + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" + +def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; +def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; +def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; +def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; +def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">; +def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; +def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; +def AsyncRegions : NativeOpTrait<"AsyncRegions">; +def PassthroughWaitLike : NativeOpTrait<"PassthroughWaitLike">; + +// A trait equivalent to InferTypeOpAdaptor, but that checks for structural +// equivalence of the layouts of the result rather than just layout equality. +def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{ + static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { + if (lhs.size() != rhs.size()) + return false; + return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) { + auto [lhs, rhs] = tup; + return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs)); + }); + } +}]>; + +#endif // TRITON_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td new file mode 100644 index 0000000000..5cb7f8f333 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -0,0 +1,118 @@ +#ifndef TRITON_OP_INTERFACES +#define TRITON_OP_INTERFACES + +include "mlir/IR/OpBase.td" + + +def TransposeOpInterface : OpInterface<"TransposeOpInterface"> { + let description = [{ + This interface is implemented by operations that perform a transpose. + It provides methods to access common properties such as the order attribute + and the source operand. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the source operand of the transposition.", + /*retType=*/"::mlir::Value", + /*methodName=*/"getSrc", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Get the order of the transposition.", + /*retType=*/"::mlir::ArrayRef", + /*methodName=*/"getOrder", + /*args=*/(ins)> + ]; + + let verify = [{ + return ::mlir::triton::impl::verifyTransposeOpInterface($_op); + }]; +} + +def DotOpInterface : OpInterface<"DotOpInterface"> { + let description = [{ + This interface is implemented by operations that perform a dot product. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the LHS A tensor", + /*retType=*/"::mlir::Value", + /*methodName=*/"getA", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Get the RHS B tensor", + /*retType=*/"::mlir::Value", + /*methodName=*/"getB", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Get the output tensor", + /*retType=*/"::mlir::Value", + /*methodName=*/"getD", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Verify the dimensions of the A and B DotOp operands.", + /*retType=*/"bool", + /*methodName=*/"verifyDims", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Verify the dimensions of the DotOp output.", + /*retType=*/"bool", + /*methodName=*/"verifyOutputDims", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImpl=*/ [{ + auto aTy = cast($_op.getA().getType()); + auto bTy = cast($_op.getB().getType()); + auto cTy = cast($_op->getOperand(2).getType()); + auto dTy = cast($_op.getD().getType()); + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + auto cShape = cTy.getShape(); + return cShape[cShape.size() - 2] == aShape[aShape.size() - 2] && + cShape[cShape.size() - 1] == bShape[aShape.size() - 1]; + }]> + ]; + + let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }]; +} + +def TT_DescriptorOpInterface : OpInterface<"DescriptorOpInterface"> { + let description = [{ + Common interface to get the descriptor argument from an operation on tensor descriptors. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the descriptor", + /*retType=*/"::mlir::TypedValue", + /*methodName=*/"getDesc", + /*args=*/(ins)>, + ]; +} + +def TT_DescriptorStoreLikeOpInterface : OpInterface<"DescriptorStoreLikeOpInterface", [TT_DescriptorOpInterface]> { + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get Source tensor", + /*retType=*/"::mlir::TypedValue", + /*methodName=*/"getSrc", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Get mutable source tensor", + /*retType=*/"::mlir::OpOperand&", + /*methodName=*/"getSrcMutable", + /*args=*/(ins)>, + ]; +} + + +#endif // TRITON_OP_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 0000000000..91f8aff7a6 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1413 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let hasVerifier = 1; +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. + + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatLike:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; + + let hasFolder = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. + }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; + let hasFolder = 1; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let hasFolder = 1; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + SameLoadStoreOperandsAndResultShape, + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // assemblyFormat instead of as part of attr-dict so that they get printed + // as strings rather than opaque integers. + // + // Note there's no comma between `other` and `cacheModifier` and between + // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // limitation in the MLIR custom-format parser. In oilist, the initial + // keywords of each clause have to be unique, so they can't be `,`. + // + // Even if we gave up on order-independence and used vanilla optional + // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // not match the string ", bar = 0" because after the initial comma (first + // token of the first optional clause) we expect to see "foo". + let assemblyFormat = [{ + $ptr (`,` $mask^)? (`,` $other^)? + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = (ins + Arg, "", [MemWrite]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins + TT_AtomicRMWAttr:$atomic_rmw_op, + Arg, MemWrite]>:$ptr, + TT_Type:$val, + Optional:$mask, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope + ); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"ptr type matches cmp type", "cmp", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)"> +]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins + Arg, MemWrite]>:$ptr, + TT_Type:$cmp, + TT_Type:$val, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope + ); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_UnsplatOp : TT_Op<"unsplat", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "convert a tensor with a single element to a scalar"; + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src)"; + let hasVerifier = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. + }]; + let builders = [ + OpBuilder<(ins "ArrayRef":$shape, "Value":$src, + CArg<"bool", "false">:$allowReorder)> + ]; + + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizer = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +// Cat is not pure because it may reorder elements. +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + Pure, SameTypeOperands]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; + let hasVerifier = 1; +} + +def TT_SplitOp : TT_Op<"split", [ + Pure, + InferTypeOpWithLayoutEquivalence, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + TransposeOpInterface, + InferTypeOpWithLayoutEquivalence, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. + }]; + + let arguments = ( + ins TT_Tensor:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee, bf16x3, bf16x6. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + bf16x3: implement the 3xBF16 trick. For more info see the pass in F32DotTC.cpp + bf16x6: implement the 6xBF16 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + + +// +// DotScaled Op +// +def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot_scaled"; + + let description = [{ + $d = matrix_multiply(scale($a, $a_scale), scale($b, $b_scale)) + $c. + Where scale(x, s) is a function that applies the scale per block following microscaling spec. + }]; + + let arguments = ( + ins + // inputs are floats if we have a type for them, otherwise (fp4), + // they are packed in pairs in an I8Tensor + RankedTensorOf<[TT_Float,I8]>:$a, + RankedTensorOf<[TT_Float,I8]>:$b, + TT_FloatTensor:$c, + Optional>:$a_scale, + Optional>:$b_scale, + TT_ScaleDotElemTypeAttr:$a_elem_type, + TT_ScaleDotElemTypeAttr:$b_elem_type, + BoolAttr:$fastMath, + DefaultValuedAttr:$lhs_k_pack, + DefaultValuedAttr:$rhs_k_pack + ); + + let results = (outs TT_FloatTensor:$d); + + let assemblyFormat = [{ + $a (`scale` $a_scale^)? `,` $b (`scale` $b_scale^)? `,` $c + `lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict + `:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d) + }]; + let hasVerifier = 1; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsShape, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + + // Returns the CombineOp iff this ReduceOp's region contains only + // one CombineOp other than the return, or nullptr if not applicable. + ::mlir::Operation *getSingleCombiner(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Map Elementwise op +// +def TT_MapElementwiseOp: TT_Op<"map_elementwise", [SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + RecursiveMemoryEffects]> { + let summary = "Map a scalar subregion over a tensor"; + let arguments = (ins Variadic:$srcs, I32Attr:$pack); + let results = (outs Variadic:$result); + let regions = (region AnyRegion:$scalarOp); + let hasVerifier = 1; + let hasRegionVerifier = 1; +} + +def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return", + [HasParent<"MapElementwiseOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for map elementwise operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "attr-dict ($result^ `:` type($result))?"; +} + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure, + TypesMatchWith<"mask type matches src type", + "src", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">]> { + let summary = "return a histogram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src, + Optional:$mask); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src (`,` $mask^)? attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Gather Op +// +def TT_GatherOp : TT_Op<"gather", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "local gather operation"; + let description = [{ + Gather elements from the input tensor using the indices tensor along a + single specified axis. The output tensor has the same shape as the indices + tensor. The input and indices tensors must have the same number of + dimension, and each dimension of the indices tensor that is not the gather + dimension cannot be greater than the corresponding dimension in the input + tensor. + + The `efficient_layout` attribute is set when the compiler has determined an + optimized layout for the operation, indicating that it should not be + changed. + }]; + + let arguments = (ins + TT_Tensor:$src, + TT_IntTensor:$indices, + I32Attr:$axis, + UnitAttr:$efficient_layout + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` attr-dict `:` + functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> { + let arguments = ( + ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$args, + DenseI32ArrayAttr:$isSigned + ); + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor and a message string. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// +// Make Tensor Descriptor Op +// +def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ + Pure, + SameVariadicOperandSize, +]> { + let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; + + let description = [{ + `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, + and returns a descriptor object which can be used to load/store from the tensor in global memory. + }]; + + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + DefaultValuedAttr:$padding + ); + + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; + + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape, "bool":$isSignedInteger, + "triton::PaddingOption":$padding)> + ]; + + let extraClassDeclaration = [{ + ArrayRef getTensorShape() { + return getType().getBlockType().getShape(); + } + }]; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", cast(callee)); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [ + AffineScope, AutomaticAllocationScope, CallableOpInterface, + FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, + HasParent<"ModuleOp"> +]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, mlir::ValueRange()); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc` is a tensor descriptor object. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + }]; + let arguments = (ins + Arg]>:$desc, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc)) `->` type($result) + }]; + + let hasVerifier = 1; +} + +def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpInterface]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc` is a tensor descriptor object. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + }]; + let arguments = (ins + Arg, MemWrite]>:$desc, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) + }]; + let hasVerifier = 1; +} + +def TT_DescriptorReduceOp : TT_Op<"descriptor_reduce", [TT_DescriptorStoreLikeOpInterface]> { + let summary = "performs a reducing store operation based on a descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc` is a tensor descriptor object. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + }]; + let arguments = (ins + TT_DescriptorReduceKindAttr:$kind, + Arg, MemWrite]>:$desc, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $kind `,` $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) + }]; + let hasVerifier = 1; +} + +def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [TT_DescriptorOpInterface]> { + let summary = "gather multiple rows from a descriptor into a single tensor"; + let description = [{ + The `tt.descriptor_gather` op will be lowered to NVIDIA TMA + gather operations on targets that support it. + + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The descriptor block must have 1 row and the indices must be a 1D tensor. + Accordingly, the result is a 2D tensor multiple rows. + }]; + + let arguments = (ins + Arg]>:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLikeOpInterface]> { + let summary = "scatter multiple rows to a descriptor from a single tensor"; + let description = [{ + The `tt.descriptor_scatter` op will be lowered to NVIDIA TMA + scatter operations on targets that support it. + + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The descriptor block must have 1 row and the indices must be a 1D tensor. + Accordingly, the result is a 2D tensor multiple rows. + }]; + + let arguments = (ins + Arg, MemWrite]>:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + TT_Tensor:$src + ); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` `,` $src + attr-dict `:` type(operands) + }]; + + let hasVerifier = 1; +} + + +#endif // Triton_OPS diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td new file mode 100644 index 0000000000..a12c6b6fe9 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td @@ -0,0 +1,53 @@ +#ifndef TRITON_TYPE_INTERFACES +#define TRITON_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// TensorDescInterface +//===----------------------------------------------------------------------===// + +def TT_TensorDescInterface : TypeInterface<"TensorDescInterface"> { + let cppNamespace = "::mlir::triton"; + + let description = [{ + Common interface for tensor descriptor types. + + This interface provides a unified API for different tensor descriptor + implementations (e.g., tiled TensorDescType, im2col TensorDescIm2ColType). + All tensor descriptors share the concept of a "block type" which describes + the shape and element type of the data block being accessed. + + Concrete implementations: + - TensorDescType (Triton dialect): Basic tiled tensor descriptor + - TensorDescIm2ColType (TritonNvidiaGPU dialect): Im2col tensor descriptor + with additional convolution parameters + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/"Returns the block type of the tensor descriptor", + /*retType=*/"mlir::RankedTensorType", + /*methodName=*/"getBlockType", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/"Returns the block type with signless integer element type", + /*retType=*/"mlir::RankedTensorType", + /*methodName=*/"getSignlessBlockType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImpl=*/[{ + auto resTy = $_type.getBlockType(); + if (auto intTy = llvm::dyn_cast(resTy.getElementType())) { + auto width = resTy.getElementTypeBitWidth(); + auto signlessTy = mlir::IntegerType::get($_type.getContext(), width); + resTy = resTy.clone(signlessTy); + } + return resTy; + }] + >, + ]; +} + +#endif // TRITON_TYPE_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypes.td new file mode 100644 index 0000000000..fad0ef9865 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,134 @@ +#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + +// +// Types +// +class TritonTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : RankedTensorOf<[TT_Float]>; +def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def TT_BoolTensor : RankedTensorOf<[I1]>; +def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; + +// Integer Type +def I4 : I<4>; +def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; +def TT_IntTensor : RankedTensorOf<[TT_Int]>; +def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; + +// I32 Type +// TT_I32 -> I32 +// TT_I32Tensor -> I32Tensor +def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// TT_I64 -> I64 +// TT_I64Tensor -> I64Tensor +def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class TT_PtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `TT_PtrOf`) +def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def TT_Ptr : TT_PtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>; + +// Tensor Type +def TT_FpIntTensor : RankedTensorOf<[TT_Float, TT_Int]>; +def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>; + +// Pointer Type to Tensor Type: `ptr>` +def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; + +// Any Type in Triton IR +def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>; + +// Type constraint for any type implementing TensorDescInterface +def TT_AnyTensorDescType : Type< + CPred<"::mlir::isa<::mlir::triton::TensorDescInterface>($_self)">, + "tensor descriptor type", + "::mlir::triton::TensorDescInterface" +>; + +// Result type of MakeTensorDescriptor +def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", [TT_TensorDescInterface]> { + let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system"; + + let description = [{ + A portable abstraction for TMA descriptors. + This is the base tensor descriptor type for tiled tensor memory access. + + For specialized access patterns like im2col, see TensorDescIm2ColType + in the TritonNvidiaGPU dialect. + }]; + + let parameters = (ins + "RankedTensorType":$blockType + ); + + let assemblyFormat = "`<` $blockType `>`"; + + let builders = [ + // Builder with signedness + TypeBuilder<(ins "RankedTensorType":$blockType, "bool":$isSigned), [{ + if (auto intTy = llvm::dyn_cast(blockType.getElementType())) { + auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned; + auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem); + blockType = blockType.clone(elemTy); + } + return Base::get($_ctxt, blockType); + }]>, + ]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Types.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Types.h new file mode 100644 index 0000000000..9c652a66a2 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Types.h @@ -0,0 +1,44 @@ +#ifndef TRITON_IR_TYPES_H_ +#define TRITON_IR_TYPES_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc" + +#include "triton/Dialect/Triton/IR/Types.h.inc" + +namespace mlir { + +namespace triton { + +bool isTensorPointerType(Type type); + +bool isTensorOrTensorPointerType(Type type); + +unsigned getPointeeBitWidth(Type type); + +Type getPointeeType(Type type); + +Type getPointerType(Type type, int addressSpace = 1); + +int getAddressSpace(Type type); + +Type getElementTypeOfTensorPointerType(Type type); + +Type getI1SameShape(Type type); + +Type getI32SameShape(Type type); + +Type getPointerTypeSameShape(Type type); + +Type getPointerTypeToElement(Type type); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Utility.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Utility.h new file mode 100644 index 0000000000..67f6eebe92 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Utility.h @@ -0,0 +1,214 @@ +#ifndef TRITON_IR_UTILITY_H_ +#define TRITON_IR_UTILITY_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include +#include + +namespace mlir { + +// Bitwidth of pointers +constexpr int kPtrBitWidth = 64; + +// Returns the bit width of a type, treating pointer-like types as 64-bit. +// This handles LLVM dialect pointer types. +inline int getIntOrFloatOrPtrBitWidth(Type type) { + if (isa(type)) + return kPtrBitWidth; + return type.getIntOrFloatBitWidth(); +} + +template SmallVector convertType(ArrayRef in) { + SmallVector out; + for (const auto &i : in) + out.push_back(T(i)); + return out; +} + +template +SmallVector convertType(const VecU &in) { + return convertType(ArrayRef(in)); +} + +template Int product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies()); +} +template auto product(const VecT &vec) { + return product(llvm::ArrayRef(vec)); +} + +// TODO(jlebar): Rename to ceilOfRatio. +template Int ceil(Int m, Int n) { return (m + n - 1) / n; } + +/// Get the highest power of 2 divisor of an integer. +template constexpr T highestPowOf2Divisor(T n) { + // When n is 0 or min, return the highest power of 2. The min case is handled + // separately to avoid underflow when T is a signed integer. Technically + // in that case the correct divisor is -n, but this value is outside the + // range of possible values, so we take the next best alternative. + if (n == 0 || n == std::numeric_limits::min()) { + return (static_cast(1) << (sizeof(T) * 8 - 2)); + } + return (n & (~(n - 1))); +} + +/// Get the next power of 2 for an integer (or the integer itself if it is a +/// power of 2). +template T nextPowOf2(T n) { + if (n == 0) { + return 1; + } + n--; + for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) { + n |= n >> i; + } + return n + 1; +} + +namespace triton { + +// Many functions here have two overloads, fn(ArrayRef) and fn(const VecT&). +// This is helpful because C++ won't both convert a vector to ArrayRef *and* +// infer the proper type T in one step. So without the second overload, we +// would have to explicitly convert most arguments to ArrayRef at the callsite. + +template +SmallVector applyPermutation(ArrayRef vec, ArrayRef permutation) { + static_assert(std::is_integral_v); + assert(vec.size() == permutation.size()); + + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (U i = 0; i < static_cast(sortedPerm.size()); i++) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret; + ret.reserve(vec.size()); + for (const U &i : permutation) { + ret.push_back(vec[i]); + } + return ret; +} + +template +auto applyPermutation(const VecT &vec, const PermT &permutation) { + return applyPermutation(ArrayRef(vec), ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector inversePermutation(ArrayRef permutation) { + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (int i = 0; i < sortedPerm.size(); ++i) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + ret[permutation[i]] = i; + } + return ret; +} + +template +[[nodiscard]] auto inversePermutation(const VecT &permutation) { + return inversePermutation(ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector gather(ArrayRef elems, ArrayRef indices) { + SmallVector ret; + ret.reserve(indices.size()); + for (const U &i : indices) { + ret.push_back(elems[i]); + } + return ret; +} + +template +[[nodiscard]] auto gather(const VecT &elems, const IdxT &indices) { + return gather(ArrayRef(elems), ArrayRef(indices)); +} + +// Is `vec` [0, 1, ..., n]? Returns true on empty list. +template bool isIota(ArrayRef vec) { + static_assert(std::is_integral_v); + for (size_t i = 0; i < vec.size(); ++i) { + if (vec[i] != static_cast(i)) { + return false; + } + } + return true; +} + +template bool isIota(const VecT &vec) { + return isIota(ArrayRef(vec)); +} + +// Is `vals` some permutation of the numbers 0..(vals.size()-1)? +template bool isPermutationOfIota(ArrayRef vals) { + SmallVector sorted(vals); + llvm::sort(sorted); + return isIota(sorted); +} + +template bool isPermutationOfIota(const VecT &vec) { + return isPermutationOfIota(ArrayRef(vec)); +} + +// Is `vec` [i, i+1, ..., i+n]? Returns true on empty list. +template bool isConsecutive(ArrayRef vec) { + static_assert(std::is_integral_v); + for (int i = 1; i < vec.size(); i++) { + if (vec[i] != vec[i - 1] + 1) { + return false; + } + } + return true; +} + +template bool isConsecutive(const VecT &vec) { + return isConsecutive(ArrayRef(vec)); +} + +template auto seq(T start, T end, T step) { + auto len = ceil(end - start, step); + return llvm::map_range(llvm::seq(0, len), + [=](T i) { return start + i * step; }); +} + +// Combine the current mask with the given predicate. +Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask, + Value pred); + +// Get the value of the induction variable at the end of the loop. +Value getLastInductionValue(OpBuilder &b, scf::ForOp loop); + +MakeTensorPtrOp getMakeTensorPtrOp(Value v); + +bool isHostSideDescriptor(Value v); + +bool isKernel(FunctionOpInterface funcOp); + +unsigned getBitwidth(RankedTensorType ty); + +// If the value "anchor" is compared against a statically-computed bound, return +// inclusive lower and upper bounds lb <= anchor <= ub. Depending on the +// comparison operator, one of the bounds is a computed one while the other is +// derived from the data type of anchor. +std::optional getBoundFromCmpOp(arith::CmpIOp cmpOp, + Value anchor); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/ArithTypeConversion.h b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/ArithTypeConversion.h new file mode 100644 index 0000000000..1e772f330b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/ArithTypeConversion.h @@ -0,0 +1,18 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_ +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::triton { + +/** + * @brief Provides helper patterns for converting arith operations using a type + * converter. + * + * Note at of the time of writing this isn't provided in upstream mlir. + */ +void populateArithTypeConversions(const TypeConverter &converter, + RewritePatternSet &patterns); + +} // namespace mlir::triton + +#endif // TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..372a9ec11e --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) +add_public_tablegen_target(TritonTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h new file mode 100644 index 0000000000..77940bb417 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h @@ -0,0 +1,19 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_ +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::triton { + +/** + * @brief Provides helper patterns for converting triton function operations + * using a type converter. + * + * Note we cannot use upstream passes for this because they are unaware of + * tt.call and tt.return. + */ +void populateFunctionTypeConversions(const TypeConverter &converter, + RewritePatternSet &patterns); + +} // namespace mlir::triton + +#endif // TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/LoopPeeling.h b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/LoopPeeling.h new file mode 100644 index 0000000000..38efd6b134 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/LoopPeeling.h @@ -0,0 +1,18 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" + +namespace mlir { +namespace triton { + +// Peel the single last iteration of the loop. +void peelLoopEpilogue( + scf::ForOp forOp, + function_ref + processPeeledOp = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 0000000000..5d254bf830 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,19 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 0000000000..3744f8ad07 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,93 @@ +#ifndef TRITON_PASSES +#define TRITON_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonCombineOps : Pass { + let summary = "combine ops"; + let description = [{ + This pass aims to optimize the five following patterns: + - `dot(a, b, 0) + c => dot(a, b, c)` + + - `addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))` + + - `select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other)` + + - `broadcast(constant) => reshaped_constant` + - `torch.sum(x[:,:,None].expand(-1,-1,n) * y[None,:,:].expand(m,-1,-1),1) + => dot(x,y,splat(0))` + }]; + + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +def TritonReorderBroadcast : Pass { + let summary = "Moves broadcast and splat after elementwise operations"; + let description = [{ + The purpose of this pass is to transform: + - `elementwise(broadcast(a)) => broadcast(elementwise(a))` + - `elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))` + In the event of a match, the broadcast (or splat) operation is delayed + and performed after the ElementWise operation. + }]; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonRewriteTensorPointer : Pass { + let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy + semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute + the pointer/mask/other for each load/store. + }]; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonRewriteTensorDescriptorToPointer : Pass { + let summary = "Rewrite load/stores of tensor descriptors into pointer load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_descriptor` into pointer semantics. After + this pass, `tt.make_tensor_descriptor` will disappear, and it generates logics to compute the pointer/mask/other + for each load/store. + }]; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonLoopUnroll : Pass { + let summary = "Loop unroller"; + let description = [{ + The pass unrolls a scf loop with tt.loop_unroll_factor attribute. The attribute specialises how many iterations + the loop should be unrolled. + }]; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonLoopInvariantCodeMotion : Pass { + let summary = "MLIR's LICM plus hoist load ops out of loops with masks."; + let description = [{ + This pass uses MLIR's LICM pass as base. Additionally, it hoists load ops + out of loops that consists of pure/read-only ops. For scf.for loops, it + generates a trip-count check. For scf.while loops, it clones the condition + from the before body. + }]; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonLoopAwareCSE : Pass<"triton-loop-aware-cse", "mlir::ModuleOp"> { + let summary = "CSE within loop bodies"; + + let description = [{ + The `triton-loop-aware-cse` pass performs recursive common subexpression + elimination within loop bodies. Unlike regular CSE, which is a single-pass + greedy algorithm, this pass can recursively eliminate loop iteration + arguments and subcomputations that always have the same value. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Attributes.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 0000000000..5c5e81ee73 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,13 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/CGAEncodingAttr.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" + +#include "triton/Dialect/TritonGPU/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/AttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CGAEncodingAttr.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CGAEncodingAttr.h new file mode 100644 index 0000000000..ca72ac417a --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CGAEncodingAttr.h @@ -0,0 +1,11 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_CGAENCODINGATTR_H_ +#define TRITON_DIALECT_TRITONGPU_IR_CGAENCODINGATTR_H_ + +#include "mlir/IR/Attributes.h" +#include "triton/Tools/LinearLayout.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/CGAEncodingAttr.h.inc" +#undef GET_ATTRDEF_CLASSES + +#endif // TRITON_DIALECT_TRITONGPU_IR_CGAENCODINGATTR_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CGAEncodingAttr.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CGAEncodingAttr.td new file mode 100644 index 0000000000..fb495c4d5d --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CGAEncodingAttr.td @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// CGA encoding attribute definition emitted early to break interface cycles. +//===----------------------------------------------------------------------===// + +#ifndef TRITONGPU_CGAENCODING_ATTR_TD +#define TRITONGPU_CGAENCODING_ATTR_TD + +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td" + +//===----------------------------------------------------------------------===// +// CGA Layout +//===----------------------------------------------------------------------===// + +def CGAEncodingAttr : TritonGPU_Attr<"CGAEncoding", "cga_encoding"> { + let parameters = (ins LinearLayoutParam:$linearLayout); + + let description = [{ +Describes how blocks (CTAs) in a cooperative thread array (CGA) map onto logical +tensor dimensions. The `LinearLayout` maps from `block` into `dim0`, `dim1`... + }]; + + let extraClassDeclaration = [{ + // Map with empty bases and dims [dim0, dim1, ...] + static CGAEncodingAttr get1CTALayout(MLIRContext *context, int rank); + // Map with bases = [[1,], [2,], ..., [numCTAs/2]] into dim0 + static CGAEncodingAttr get1DLayout(MLIRContext *context, int numCTAs); + // Legacy, we should kill this! Note that it is not true in general that + // fromSplitParams(enc.getCTAsPerCGA(), enc.getCTASplitNum(), enc.getCTAOrder()) == enc!! + static CGAEncodingAttr fromSplitParams(MLIRContext *context, + ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, + ArrayRef CTAOrder); + + unsigned getRank() const { return getLinearLayout().getNumOutDims(); } + SmallVector getCTAsPerCGA() const; + SmallVector getCTASplitNum() const; + SmallVector getCTAOrder() const; + }]; + + let genVerifyDecl = 1; +} + +#endif // TRITONGPU_CGAENCODING_ATTR_TD diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..8b44463001 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,40 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls) + +set(LLVM_TARGET_DEFINITIONS TritonGPUAttrImpls.td) +mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TritonGPUAttrDefsIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUEnums.td) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonGPUOpsEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS CGAEncodingAttr.td) +mlir_tablegen(CGAEncodingAttr.h.inc -gen-attrdef-decls) +add_public_tablegen_target(TritonGPUCGAAttrIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td) +mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(TritonGPUTypeInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(TritonGPUOpInterfacesIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Dialect.h new file mode 100644 index 0000000000..47f7dd0ebf --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -0,0 +1,321 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Traits.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#include + +// LinearLayoutCache Utils +using CacheKey = std::tuple, mlir::Attribute>; + +namespace llvm { +template size_t hash_value(const std::vector &vec) { + return hash_combine_range(vec.begin(), vec.end()); +} +} // namespace llvm + +namespace std { +template <> struct hash { + size_t operator()(const CacheKey &key) const noexcept { + using llvm::hash_value; + size_t seed = 0; + std::apply( + [&seed](const auto &...elems) { + ((seed = llvm::hash_combine(seed, hash_value(elems))), ...); + }, + key); + return seed; + } +}; +} // namespace std + +namespace mlir::triton::gpu { + +constexpr static char AttrMaxRegistersName[] = "ttg.maxnreg"; +constexpr static char AttrNumWarpsName[] = "ttg.num-warps"; +constexpr static char AttrNumCTAsName[] = "ttg.num-ctas"; +constexpr static char AttrTargetName[] = "ttg.target"; +constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp"; +// FIXME: rename to match above +constexpr static char kPartitionAttrName[] = "ttg.partition"; +constexpr static char kPartitionOutputsAttrName[] = "ttg.partition.outputs"; +constexpr static char kPartitionStagesAttrName[] = "ttg.partition.stages"; +constexpr static char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag"; + +// Find the contextual number of warps on which this operation is executed. +int lookupNumWarps(Operation *op); +int lookupNumWarps(Region *region); +// Try to find the contextual number of warps on which this operation is +// executed. Returns nullopt if a warp size cannot be find. This is used for +// verifiers. +std::optional maybeLookupNumWarps(Operation *op); + +// FIXME: Make this API and that of maybeLookupNumWarps consistent! +// Utility to find the number of threads per warp +int lookupThreadsPerWarp(OpBuilder &rewriter); +int lookupNumCTAs(OpBuilder &rewriter); +int lookupNumCTAs(Operation *op); + +template class Cache { +public: + std::optional get(const Key &key) { + std::shared_lock lock(mutex); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + return std::nullopt; + } + + void set(Key key, Value result) { + std::scoped_lock lock(mutex); + cache.emplace(std::move(key), std::move(result)); + } + +private: + std::unordered_map cache; + llvm::sys::SmartRWMutex mutex; +}; + +using LinearLayoutCache = Cache; +using LinearEncodingCache = Cache; +} // namespace mlir::triton::gpu + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonGPU/IR/Ops.h.inc" + +namespace mlir::triton::gpu { +struct SharedMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +// Convert a distributed layout to a linear encoding +LinearEncodingAttr toLinearEncoding(RankedTensorType type); +LinearEncodingAttr toLinearEncoding(DistributedEncodingTrait layout, + ArrayRef shape); + +unsigned getTotalElemsPerThread(Type type); + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape); + +SmallVector getElemsPerThread(Type type); + +// Returns the number of warps per CTA that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2], +// returns [1, 1], since the first warp has access to the full tensor, whereas +// the other warps have access to replicated elements. +SmallVector getWarpsPerCTA(Attribute layout, + ArrayRef tensorShape); +inline SmallVector getWarpsPerCTA(RankedTensorType type) { + return getWarpsPerCTA(type.getEncoding(), type.getShape()); +} + +// Returns the number of contiguous elements of the logical tensor that each +// thread has access to, on each dimension of the tensor. For a blocked layout +// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements +// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1, +// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be +// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4]. +SmallVector getContigPerThread(RankedTensorType tensorType); + +// Returns the number of threads per warp that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17 +// have access to the full tensor, whereas the other threads have access to +// replicated elements, so this function returns [2, 2]. +SmallVector getThreadsPerWarp(Attribute layout, + ArrayRef shape); +inline SmallVector getThreadsPerWarp(RankedTensorType type) { + return getThreadsPerWarp(type.getEncoding(), type.getShape()); +} + +// Returns the dimensions of the tensor from minor (fast-varying) to +// major (slow-varying). For distributed layouts, this represents +// the order of the elements within a thread. +// For shared Layout, the order refers to which dimension of the original tensor +// is contiguous in shared memory. +SmallVector getOrder(DistributedEncodingTrait layout, + ArrayRef shape); +inline SmallVector getOrder(RankedTensorType type) { + return getOrder(cast(type.getEncoding()), + type.getShape()); +} + +SmallVector getOrder(SharedEncodingTrait layout, + ArrayRef shape); +inline SmallVector getOrder(MemDescType type) { + return getOrder(cast(type.getEncoding()), + type.getShape()); +} +inline SmallVector getOrder(TensorOrMemDesc type) { + if (auto memDesc = dyn_cast(type)) { + return getOrder(memDesc); + } else { + auto tensorTy = cast(type); + return getOrder(tensorTy); + } +} + +// To be removed once we implement arbitrary swizzled layouts +// It chooses heuristically an order for the memory layout in which to save +// a distributed layout taking into account the order of the elements +// and the threads. +SmallVector getOrderForMemory(DistributedEncodingTrait layout, + ArrayRef shape); +inline SmallVector getOrderForMemory(RankedTensorType type) { + return getOrderForMemory(cast(type.getEncoding()), + type.getShape()); +} +inline SmallVector getOrderForMemory(TensorOrMemDesc type) { + if (auto memDesc = dyn_cast(type)) { + return getOrder(memDesc); + } else { + auto tensorTy = cast(type); + return getOrderForMemory(tensorTy); + } +} + +// Returns the dimensions along which warpId's are distributed. +// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4] +// tells there are 2 warps along dim0 and 4 warps along dim1. +// warpOrder tells the specific order when distributing warp IDs. +// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows +// [warp0 warp2 warp4 warp6] +// [warp1 warp3 warp5 warp7] +SmallVector getWarpOrder(DistributedEncodingTrait layout, + ArrayRef shape); +inline SmallVector getWarpOrder(RankedTensorType type) { + return getWarpOrder(cast(type.getEncoding()), + type.getShape()); +} + +// Returns the dimensions along which threadId's are distributed. +// Similar to warpOrder, threadOrder is necessary to tell the specific thread +// distribution in the warp. +SmallVector getThreadOrder(DistributedEncodingTrait layout, + ArrayRef shape); +inline SmallVector getThreadOrder(RankedTensorType type) { + return getThreadOrder(cast(type.getEncoding()), + type.getShape()); +} + +CGAEncodingAttr getCGALayout(Attribute layout); + +SmallVector getCTAsPerCGA(Attribute layout); + +SmallVector getCTASplitNum(Attribute layout); + +SmallVector getCTAOrder(Attribute layout); + +// Returns the "logical" shape per CTA. +// When shape and CTASplitNum have different number of dimensions, we assume +// only the last N between common dimensions are split. +// Example1: shape = [2, 4, 8], CTASplitNum = [2, 2], ret = [2, 2, 4]. +// It can be caused by pipelining. +// Example2: shape = [2, 4], CTASplitNum = [2, 2, 2], ret = [1, 2]. +// It can be caused by memory slicing. +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape); +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape); +SmallVector getShapePerCTA(Type type); + +// Returns the shape per CTA, which is "physically" allocated. +// Such shapes may be bigger than the logical one due to, for example, padding +// in shared memory. +SmallVector getAllocationShapePerCTA(Attribute layout, + ArrayRef shape); +SmallVector getAllocationShapePerCTA(Type type); + +unsigned getNumCTAs(Attribute layout); + +// Return the order that represents that the batch is in row-major or +// column-major order for a batch of matrices of shape [*, m, n] with +// len(shape) == rank. +SmallVector getMatrixOrder(unsigned rank, bool rowMajor); + +// Return the order that represents that the dot operand is in kContig +// (contiguous in the inner dimension) or it's contiguous on the outer +// dimension. +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kContig); + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); + +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + +// Return a blocked encoding where the shape is distributed contiguously amongst +// the threads, warps, CTAs with 1 element per threads. +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs); + +// Dump information about which threads/registers contain each of the tensor +// elements. +void dumpLayout(RankedTensorType tensorType); + +// Dump the layout from HW point of view and prints what tensor element is held +// by each thread and register. +void dumpHWLayout(RankedTensorType tensorType); + +// Return a string representation of the layout of the tensor. +std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView); + +// Return a string representation of the shared layout of the tensor. +std::string getSharedLayoutStr(LinearLayout &ll, bool useHWPointOfView); + +// Return a string representation of the distributed layout of the tensor. +std::string getDistributedLayoutStr(LinearLayout &ll, bool useHWPointOfView); + +template +llvm::SmallVector expandMatrixShapeWithBatch(llvm::ArrayRef s); + +llvm::SmallVector +expandMatrixOrderWithBatch(llvm::ArrayRef o); + +// Return true if the two layouts represent the exact same mapping. +bool areLayoutsEquivalent(ArrayRef shape, LayoutEncodingTrait lhs, + LayoutEncodingTrait rhs); + +// Return true if the innermost numElems are contiguous. +bool isInnermostContiguous(MemDescType type, unsigned numElems); + +LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy, + ArrayRef dstShape); + +FailureOr> +getTMABlockShape(ArrayRef shapePerCTA, int elementBitWidth, + int swizzleBytes, bool fp4Padded, bool isTransposed, + bool packedSize, function_ref emitError); +SmallVector getTMABlockShape(ArrayRef shapePerCTA, + int elementBitWidth, int swizzleBytes, + bool fp4Padded, bool isTransposed, + bool packedSize); + +// Verify the types of operations that operate on memory. +LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy, + ShapedType dstTy); +// Verify a memory allocation operation. +LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy); + +SetVector getPartitionIds(Operation *op); +SmallVector, 4> getPartitionOutputs(Operation *op); +SetVector getPartitionIds(OpOperand *use); +bool hasPartition(Operation *op); +bool hasWarpSpecializeTag(Operation *op); +std::optional getWarpSpecializeTag(Operation *op); + +} // namespace mlir::triton::gpu + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h new file mode 100644 index 0000000000..ded179b8af --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -0,0 +1,155 @@ +// Conversions from TritonGPU layouts (e.g. BlockedEncodingAttr) to +// LinearLayout. + +#ifndef TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H +#define TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H + +#include + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton { +enum class ScaleDotElemType : uint32_t; +} // namespace mlir::triton + +namespace mlir::triton::gpu { +class SwizzledSharedEncodingAttr; +class NVMMASharedEncodingAttr; +class TensorOrMemDesc; +class MemDescType; +class CGAEncodingAttr; + +// - BlockedEncodingAttrs have the following input dimensions. +// +// "register": elements in one thread +// "lane": threads in a warp +// "warp": warps in a block/CTA +// "block": blocks in a cluster +// +// - An n-dimensional SwizzledSharedEncodingAttr has the following input +// dimensions. +// +// "offset": the n'th element in the allocation, within a particular thread +// block (i.e. within a CTA). The offset is measured in elements, not +// bytes. +// "block": blocks in a cluster +// +// All layouts have the following output dimensions. +// +// "dimi" for i in 0..n-1: the location in the n'th logical dimension of the +// output tensor. These also are not reordered according to the layout's +// `order`. +// +// You can flatten the input or output dimensions into a single dimension using +// LinearLayout::flattenIns/Outs(). +// +// elemBitWidth is the bit width of one element in the layout. This is required +// to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e. +// shared layouts with nvmma_shared layout) but is otherwise unused. +LinearLayout toLinearLayout(RankedTensorType type); +LinearLayout toLinearLayout(MemDescType type); +LinearLayout toLinearLayout(TensorOrMemDesc type); +// UNSAFE OVERLOAD! +// If you call this with a SharedMemoryEncodingAttr, you should call it +// with the allocShape as the shape, otherwise the layout will be incorrect! +LinearLayout toLinearLayout(ArrayRef shape, Attribute layout); + +// Convert the shared encoding of a tensor with `nvmma_shared` layout to a +// LinearLayout that maps from a linear shared memory offset to tensor index. +// +// If `disableSwizzle` is set, then the resulting layout does not include +// swizzling. +LinearLayout nvmmaSharedToLinearLayout(ArrayRef shape, + NVMMASharedEncodingAttr shared, + bool disableSwizzle = false); + +// Given a linear layout where the input dimensions contain a "block" dimension, +// this method sets the "block" dimension to 0 and removes the corresponding +// output dimensions. +// +// Note that this behavior differs from calling +// `LinearLayout::sublayout(inDimNames, outDimNames)` when "block" is not in +// `inDimNames`. The latter does not modify the output sizes. +LinearLayout getLayoutWithinBlock(const LinearLayout &layout); + +// Combines the layout of a CTA (input dims [register, lane, warp]) with the +// layout of a CGA (i.e. a block), and ensures that the resulting layout has the +// given shape. +// +// See the nomenclature note at the top of LinearLayoutConversions.cpp for why +// the variable with type CGAEncodingAttr is called cgaLayoutAttr. +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CGAEncodingAttr cgaLayoutAttr, + ArrayRef shape); + +LinearLayout chooseWmmaCTALinearLayout(MLIRContext *ctx, unsigned rank, + ArrayRef warpsPerCTA, + ArrayRef tilesPerWarp); +// In this function, we construct a linear layout representing the +// -> mapping +// for entire `src` and `dst` tensors. We determine the shape of the +// intermediate shared memory buffer needed for a register-to-register +// conversion using the maximum size accessed in each dimension from `src`'s +// layout and `dst`'s layout. See the getRepShapeForCvt function in +// Allocation.cpp for details. Note that the buffer might be smaller than the +// tensor being converted, so we need multiple "iterations" to move a subregion +// of the `src` tensor to the corresponding subregion of the `dst` tensor. The +// pesudo code of layout conversion is as follows: +// +// for iter in 0..numIterations: +// sync threads +// for vecIdx in [0..numRegisters/storeVec]: +// registers <- get registers used in iter +// offsets <- get offsets using the intermediate linear layout +// store registers[vecIdx * storeVec, (vecIdx + 1) * storeVec)] to shared +// memory +// sync threads +// for vecIdx in [0..numRegisters/loadVec]: +// registers <- get registers used in iter +// offsets <- get offsets using the intermediate linear layout +// load registers[vecIdx * loadVec, (vecIdx + 1) * loadVec)] from shared +// memory +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order); + +// The primary goal of this function is to efficiently load 2D tiles of a +// tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs. +std::optional +chooseDsReadTrLayout(Attribute enc, ArrayRef shape, + int32_t elemBitWidth, unsigned instBitWidth, + unsigned numLanesInShuffleGroup); + +// Create LinearLayout for scale in scaled mfma. +LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef dotOperandShape, + unsigned mfmaMDim, + ArrayRef tilesPerWarp, + ArrayRef warpsPerCTA); + +LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef dotOperandShape, + unsigned wmmaMDim, + LinearLayout ctaLayout); + +LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, + ArrayRef shape, int opIdx, + ArrayRef warpsPerCTA, + CGAEncodingAttr cgaLayout); + +// Create LinearLayout for nvidia mma tile. +LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, + unsigned kWidth, ArrayRef order, + ArrayRef repOrder); + +// Create a LinearLayout similar to mfmaLayout, but changing each thread to hold +// 8 elements. This layout is useful for emitting the widest 128-bit global +// store instructions. Since it closely resembles mfmaLayout, conversion between +// the two can be done using transferWithinWarp, without involving LDS +std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType); + +// Create the core layout (atom in the PTX manual) a given nvmma shared encoding +LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, + bool disableSwizzle); +} // namespace mlir::triton::gpu +#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Traits.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Traits.h new file mode 100644 index 0000000000..03ad522236 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Traits.h @@ -0,0 +1,34 @@ +#ifndef TRITONGPU_IR_TRAITS_H_ +#define TRITONGPU_IR_TRAITS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace OpTrait { + +template +class MemDescViewTrait + : public mlir::OpTrait::TraitBase { + // Optional: Add methods or verification logic here +}; + +template +class LocalLoadTrait + : public mlir::OpTrait::TraitBase { + // Optional: Add methods or verification logic here +}; + +template +class MemWaitOpTrait + : public mlir::OpTrait::TraitBase { + // Optional: Add methods or verification logic here +}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td new file mode 100644 index 0000000000..ffc7b7e5a4 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// Base definitions shared by TritonGPU attribute TableGen files. +// Splitting these out lets us emit certain attributes (e.g. CGAEncodingAttr) +// before interface headers without creating circular dependencies. +//===----------------------------------------------------------------------===// + +#ifndef TRITONGPU_ATTRBASE_TD +#define TRITONGPU_ATTRBASE_TD + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +// Traits used across several attrs. +def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; +def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">; +def MemWaitOpTrait : NativeOpTrait<"MemWaitOpTrait">; + +// Common parameter helpers. +def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", + "linear layout"> { + let cppAccessorType = "const LinearLayout &"; +} + +// Base class for all TritonGPU attributes. +class TritonGPU_Attr traits = []> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + }]; +} + +#endif // TRITONGPU_ATTRBASE_TD diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td new file mode 100644 index 0000000000..7ad4ec84c0 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -0,0 +1,1567 @@ +#ifndef TRITONGPU_ATTRDEFS +#define TRITONGPU_ATTRDEFS + +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td" + +//===----------------------------------------------------------------------===// +// Traits, Interfaces and shared Parameters +//===----------------------------------------------------------------------===// + +def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let description = [{ + Common trait for all TTGIR layouts. + }]; + let methods = [ + InterfaceMethod<"Get the CGA layout backing this encoding.", + "CGAEncodingAttr", "getCGALayout">, + InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", + (ins), [{}], [{ + return $_attr.getCGALayout().getRank(); + }]> + ]; +} +def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods< + LayoutEncodingTrait, ["getCGALayout"]>; + +def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ + Common trait describing shared memory. + }]; + let methods = [ + InterfaceMethod<"Return the default alignment for the layout.", + "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>, + ]; +} +def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods< + SharedEncodingTrait, ["getAlignment"]>; + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// + +def SwizzledSharedEncodingAttr + : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", + [SharedEncodingTrait, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { + let mnemonic = "swizzled_shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different GPU threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. + +In order to avoid shared memory bank conflicts, elements may be swizzled. +Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. + +1. Basic swizzling + + #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // xor with 0 + [ 5, 4, 7, 6], // xor with 1 + [10, 11, 8, 9], // xor with 2 + [15, 14, 13, 12] // xor with 3 + +Here elements of row r are xor'ed with r (or more properly, in[r][c] -> +out[r][c^r]). + +2. Multiple rows per phase + + #ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14] + +Elements of row r are xor'ed with r/2. In other words, perPhase=2 +means that pairs of 2 rows get the same swizzling. + +3. Max-phase applied + + #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 5, 4, 7, 6], // phase 1 (xor with 1) + [ 8, 9, 10, 11], // phase 0 + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // ... + [21, 20, 23, 22], + [24, 25, 26, 27], + [29, 28, 31, 30] + +Elements of row r are xor'ed with (r/2) % 2. In other words, maxPhase=m has the +effect of limiting the maximum value of the xor to m-1. + +4. Max-phase and per-phase + + #ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], // phase 0 + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // phase 0 + [20, 21, 22, 23], // phase 0 + [25, 24, 27, 26], // phase 1 + [29, 28, 31, 30]] // phase 1 + +Here the xor value (the "phase", I guess?) changes every perPhase rows, up to a +maximum value of maxPhase-1. In other words, elements of row r are xor'ed with +(r/2) % 2. + +5. Adding vec + + #ttg.swizzled_shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3, 4, 5, 6, 7], + [10, 11, 8, 9, 14, 15, 12, 13], + [20, 21, 22, 23, 16, 17, 18, 19], + [30, 31, 28, 29, 26, 27, 24, 25] + +When vec=2, elements are swizzled in pairs of 2. In other words, the element at +(r,c) has value + + ((c / 2) ^ r) * 2 + (c % 2). + }]; + + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CGAEncodingAttr":$CGALayout + ); + + let builders = [ + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CGAEncodingAttr":$CGALayout, + "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CGALayout, typeWidthInBit, needTrans); + }]>, + + // TODO(jlebar): This should not be an overload of + // SwizzledSharedEncodingAttr::get(). It's misleading, because it does a bunch of + // nontrivial work based on the given dotOpEnc. + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CGAEncodingAttr":$CGALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ + + // ---- begin MFMA ---- + if (auto mfmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + return mfmaEnc.composeSharedLayoutForOperand( + CGALayout, dotOpEnc.getOpIdx(), shape, order, dotOpEnc.getKWidth(), + typeWidthInBit, needTrans); + } + + // ---- begin WMMA ---- + if (auto wmmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + return wmmaEnc.composeSharedLayoutForOperand( + CGALayout, dotOpEnc.getOpIdx(), shape, order, dotOpEnc.getKWidth(), + typeWidthInBit, needTrans); + } + + + auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + + if(!mmaEnc) + return get(context, 1, 1, 1, order, CGALayout); + + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { + return get(context, dotOpEnc.getOpIdx(), dotOpEnc.getKWidth(), shape, order, CGALayout, typeWidthInBit, needTrans); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + }]>, + + // NVIDIA constructor! + // TODO(lezcano): We should totally get rid of all these constructors... + AttrBuilder<(ins "int":$opIdx, + "unsigned":$kWidth, + "ArrayRef":$shape, + "ArrayRef":$order, + "CGAEncodingAttr":$CGALayout, + "unsigned":$bitwidth, + "bool":$needTrans), [{ + int K = getShapePerCTA(CGALayout.getCTASplitNum(), shape)[order[0]]; + // Elems necessary to cover all the banks divided by the inner dimension + // This packs a few rows together for small K + int perPhase = std::max(1024 / (bitwidth * K), 1); + + int mmaStride = 8; + int vec = 4 * kWidth; + // needsTrans is equiv. to flipping the opIdx + if (needTrans) + std::swap(vec, mmaStride); + assert(opIdx == 0 || opIdx == 1); + int rank = order.size(); + int kDim = opIdx == 0 ? rank-1 : rank-2; + if (order[0] != kDim) + std::swap(vec, mmaStride); + // Count how many vec elements are needed to cover all the banks + int maxPhase = std::max(std::min(mmaStride, 1024 / (vec * bitwidth)), 1); + // Account for the row packing from perPhase: mmaStride / perPhase + maxPhase = std::max(maxPhase / perPhase, 1); + return get(context, vec, perPhase, maxPhase, order, CGALayout); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CGAEncodingAttr":$CGALayout, + "Type":$eltTy), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CGALayout, bitwidth); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CGAEncodingAttr":$CGALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CGALayout, bitwidth, needTrans); + }]>, + ]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def PaddedSharedEncodingAttr + : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding", + [SharedEncodingTrait, DeclareLayoutEncodingMethods]> { + let mnemonic = "padded_shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different GPU threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. +Compared to SwizzledSharedEncodingAttr, this encoding combines padding with +element reordering via linear transformation (e.g. row permutation) to avoid +shared memory bank conflicts. + +Formally, given a layout: + padded_shared<[:+, :+, ...]> +We insert a padding of `` elements after every `` elements. +Multi interval-padding pairs are supported for flexibility of multi tiered +padding schemes; they compose in an additive manner. So for a 1-D tensor element +at index i, the corresponding shared memory location index is + i + \sum_{k} (i / interval_k) * pad_k = 1 +`` and `` all need to be power of two. + +Some concrete examples ignoring the linear component, using `eM` to mean tensor +elements and `pN` to mean padding: + +1. Single interval-padding pair: + + #ttg.padded_shared<[2:+2], {...}> + [e0, e1, p0, p1, + e2, e3, p2, p3, + ...] + +2. Double interval-padding pairs: + + #ttg.padded_shared<[2:+1, 4:+2], {...}> + [e0, e1, p0, + e2, e3, p1, p2, p3, + e4, e5, p4, + e6, e7, p5, p6, p7, + ...] + +Furthermore this encoding allows for a linear remapping from the 1-D shared +memory offset to logical n-D tensor elements. The remapping is given in the form +of linear bases mapping from offset to [dim0, dim1...dimN-1]. +See LinearLayout.h for more details how linear layouts are applied to remap +elements. +Some concrete examples using `xN` and `yN` to mean the logical n-D tensor elements +and `pN` to mean padding: + +1. 1D Single interval-padding with strided elements + + #ttg.padded_shared<[2:+2] {offset = [[2], [1]], block = []}> + [x0, x2, p0 p1, + x1, x3, p2, p3 + ...] + +2. 2D single interval-padding with rearranged rows. + + #ttg.padded_shared<[16:+1] {offset = [[0, 1], [0, 2], /*gap, stride by 2 rows*/[2, 0], [4, 0], [1, 0]]], block = []}> + [ + x0y0, x0y1, x0y2, x0y3, + x2y0, x2y1, x2y2, x2y3, + x4y0, x4y1, x4y2, x4y3, + x6y0, x6y1, x6y2, x6y3, + p0, + x1y0, x1y1, x1y2, x1y3, + x3y0, x3y1, x3y2, x3y3, + x5y0, x5y1, x5y2, x5y3, + x7y0, x7y1, x7y2, x7y3, + p1, + ] + +For identity mappings a short form based on order and shape is used to increase readability. The following two encodings are the same: + + #ttg.padded_shared<[2:+2] {order = [1, 0], shape = [16, 32]}> + #ttg.padded_shared<[2:+2] {offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, 0], [8, 0]], block = []}> + + + }]; + + let parameters = (ins + ArrayRefParameter<"unsigned">:$intervals, + ArrayRefParameter<"unsigned">:$paddings, + LinearLayoutParam:$linearComponent + ); + + let builders = [ + AttrBuilder<(ins "ArrayRef>":$intervalPads, + "LinearLayout":$linearComponent)>, + + // Builder to create an identity mapping as the linear component + AttrBuilder<(ins "ArrayRef>":$intervalPads, + "ArrayRef":$order, "ArrayRef":$shape, + "CGAEncodingAttr":$cgaLayout)>, + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + // Returns the order of the dimensions `dimName` of the layout. + // If more than dimension is of size one, it uses defaultOrder to determine + // the order of the dimensions of size one. + SmallVector orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const; + SmallVector getOrder() const; + + // Returns the bases of the dimensions `dimName` of the linear_component. + // If skipBroadcast is false, we count a base zero + SmallVector basesPerDim(StringAttr dimName, + bool skipBroadcast = true) const; + + unsigned getMinInterval() const { + return *llvm::min_element(getIntervals()); + } + + // Returns the total number of elements including padding given the input + // tensor shape. + int64_t getPaddedSize(ArrayRef shape) const; + }]; + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def SharedLinearEncodingAttr + : TritonGPU_Attr<"SharedLinearEncoding", "shared_linear_encoding", + [SharedEncodingTrait, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { + let mnemonic = "shared_linear"; + + let description = [{ + Linear shared encodings mirror LinearEncodingAttr but operate on shared + memory layouts. The LinearLayout parameter captures how shared memory + offsets (and optionally blocks) map to logical tensor indices. + }]; + + let parameters = (ins LinearLayoutParam:$linearLayout, "unsigned":$layoutAlignment); + + let extraClassDeclaration = [{ + SmallVector basesPerDim(StringAttr dimName, + bool skipBroadcast = true) const; + SmallVector orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const; + + SmallVector getOrder() const; + + unsigned getRank() const { return getLinearLayout().getNumOutDims(); } + + LinearLayout toLinearLayout(ArrayRef shape) const; + + int32_t getAlignment() const { return static_cast(getLayoutAlignment()); } + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", + [DeclareSharedEncodingMethods, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { + let mnemonic = "nvmma_shared"; + + let description = [{ + Represent blocked shared memory matching MMAv3/MMAv5 shared memory input. + This is meant to represent 2d tiled blocked layout. + The full layout representation is described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout + When the memdesc has more than 2 dimensions the tiling is applied to 8 rows even if the first outer dimension is smaller than 8. + In this case `transposed` means that the contiguous dimension is the most outer dimension of the memdesc. + }]; + + + // fp4Padded: Indicates that this encoding represents a mixed-precision fp4 operand in MMAv5 scaled dot, which needs + // to be in the special padded layout as described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory + let parameters = ( + ins + "unsigned":$swizzlingByteWidth, + "bool":$transposed, + "unsigned":$elementBitWidth, + "bool":$fp4Padded, + "CGAEncodingAttr":$CGALayout + ); + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CGAEncodingAttr":$CGALayout, + "Type":$eltTy, + "bool": $fp4Padded), [{ + auto shapePerCTA = getShapePerCTA(CGALayout.getCTASplitNum(), shape); + int32_t swizzlingByteWidth = 0; + unsigned eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int packingFactor = fp4Padded ? 2 : 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * packingFactor * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + swizzlingByteWidth = 128; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + swizzlingByteWidth = 64; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + swizzlingByteWidth = 32; + } else { + swizzlingByteWidth = 0; + } + int flattenOutterDim = 1; + for (int i = 1; i < shapePerCTA.size(); i++) { + flattenOutterDim *= shapePerCTA[order[i]]; + } + if (shapePerCTA.size() < 2 || flattenOutterDim < 8) { + swizzlingByteWidth = 0; + } + bool transposed = order.size() > 1 && order[0] == 0; + return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CGALayout); + }]> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + int getPerPhase() const; + int getMaxPhase() const; + int getVec() const; + }]; + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def AMDRotatingSharedEncodingAttr : + TritonGPU_Attr<"AMDRotatingSharedEncoding", "amd_rotating_shared_encoding", + [SharedEncodingTrait, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { + let mnemonic = "amd_rotating_shared"; + + let description = [{ +This shared encoding is similar to SwizzledSharedEncodingAttr, but instead of +repeating swizzling pattern every `maxPhase*perPhase` rows of the memory object, +called a block, this layout changes swizzling pattern `maxPhase` times, then +repeats the pattern. The name "rotating" comes from the fact that first tensor +element of each block is swizzled with different phase, which is equal to +current block number: 0, 1, 2.. maxPhase-1, 0, 1, 2 ... + +This layout is used to reduce bank conflicts in cases where shared memory writes +and reads are performed on layouts with different order. It's meant for hardware +without native shared memory tranpose support. + +Swizzling pattern affects only 2 fastest dimensions of a tensor. +In the following text these two dimensions are called row and column: +- row is a fastest dimension +- column is a second fastest dimension + +Elements in a row dimension are stored in memory contiguously. + +If a matrix of size [128x64] is stored in this shared layout with order [1, 0], +dim 1 (64) will be stored contiguously and called row, dim 0 (128) is will be +called column. If order of shared layout is [0, 1], dim 0 (128) is stored +contiguously becomes a row, dim 1 (64) becomes a column. + +Swizzling pattern is following: + +Let's consider an element with logical coordinates = (inRowId, inColId). +For simplicity, we do not vectorize memory in examples, +i.e. vec == 1 and layout swizzles inidividual elements. +For vec != 1 example, take a look at SwizzledSharedEncodingAttr documentation. + +Swizzled coordinates within memory object are (outRowId, outColId): + + outRowId = inRowId + phase = (inRowId / perPhase) % maxPhase + blockNo = (inRowId / (perPhase * maxPhase)) % maxPhase + combinedPhase = phase ^ blockNo + outColId = inColId ^ combinedPhase + +Actual offset in memory could be computed with following function: + +memmory_offset = (outColId + outRowId * num_of_element_in_row) * sizeof(element) + + +Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1): + + #shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + row elements + 0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0) + 1 [ 5, 4, 7, 6], // phase = 1 blockNo = 0 (xor with 1) + 2 [ 9, 8, 11, 10], // phase = 0 blockNo = 1 (xor with 1) + 3 [12, 13, 14, 15] // phase = 1 blockNo = 1 (xor with 0) + 4 [16, 17, 18, 19], // phase = 0 blockNo = 0 (xor with 0) + 5 [21, 20, 23, 22], // phase = 1 blockNo = 0 (xor with 1) + 6 [25, 24, 27, 26], // phase = 0 blockNo = 1 (xor with 1) + 7 [28, 29, 30, 31] // phase = 1 blockNo = 1 (xor with 0) + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + row elements + 0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0) + 1 [ 4, 5, 6, 7], // phase = 0 blockNo = 0 (xor with 0) + 2 [ 9, 8, 11, 10], // phase = 1 blockNo = 0 (xor with 1) + 3 [13, 12, 15, 14] // phase = 1 blockNo = 0 (xor with 1) + 4 [17, 16, 19, 18], // phase = 0 blockNo = 1 (xor with 1) + 5 [21, 20, 23, 22], // phase = 0 blockNo = 1 (xor with 1) + 6 [24, 25, 26, 27], // phase = 1 blockNo = 1 (xor with 0) + 7 [28, 29, 30, 31] // phase = 1 blockNo = 1 (xor with 0) + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + row elements + 0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0) + 1 [ 5, 4, 7, 6], // phase = 1 blockNo = 0 (xor with 1) + 2 [10, 11, 8, 9], // phase = 2 blockNo = 0 (xor with 2) + 3 [15, 14, 13, 12] // phase = 3 blockNo = 0 (xor with 3) + 4 [17, 16, 19, 18], // phase = 0 blockNo = 1 (xor with 1) + 5 [20, 21, 22, 23], // phase = 1 blockNo = 1 (xor with 0) + 6 [27, 26, 25, 24], // phase = 2 blockNo = 1 (xor with 3) + 7 [30, 31, 28, 29] // phase = 3 blockNo = 1 (xor with 2) + }]; + + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CGAEncodingAttr":$CGALayout + ); + + let hasCustomAssemblyFormat = 1; +} + + +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// + +def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ +The Distributed encoding describes the layout L with the 4-level compute hierarchy on GPU. +It is abstracted from the top to the bottom as CTAs Per CGA->Warps Per CTA->Threads Per Warp->Values Per Thread. + +For CTAs Per CGA and Warps Per CTA level, the linear id is distributed contiguously with the shape and order. +For example, for a shape/order pair defines a distribution layout +shape = [4, 4] +order = [0, 1] // The fastest-changing axis first +-> +layout = [0 4 8 12] + [1 5 9 13] + [2 6 10 14] + [3 7 11 15] + +For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + +If the layout does not completely cover the tensor, we tile it until we cover the entire tensor. +We call each individual tile "rep". + }]; + + let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrder">, + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$shape), + /*defaultImplementation=*/[{ + return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape); + }]>, + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$shape), + /*defaultImplementation=*/[{ + return toLinearEncoding($_self, shape).getElemsPerThread(shape); + }]>, + InterfaceMethod<"Convert to LinearLayout.", + "LinearLayout", + "toLinearLayout", + (ins "ArrayRef":$shape)>, + ]; +} + +class DistributedEncoding traits = []> + : TritonGPU_Attr { + + let description = [{ +Distributed encodings have a layout function L that is entirely characterized +by a d-dimensional tensor T. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in Z^d, as follows: + +\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d] + +Intuitively, when the tensor dim size T.shape[d] is larger than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "wrapped around" manner, with +each thread owning multiple values. + +OTOH, when the tensor dim size T.shape[d] is smaller than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "broadcasted" manner, with +each value owned by multiple threads. + +For example, for a tensor/layout pair +T = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of T would be distributed as follow between the 16 CUDA threads: +L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; + + code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + // Implemented in subclasses + SmallVector getRepOrder() const; + + LinearLayout toLinearLayout(ArrayRef shape) const; + }]; +} + +//===----------------------------------------------------------------------===// +// Linear Layout Encoding +//===----------------------------------------------------------------------===// + +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { + let mnemonic = "linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = (ins LinearLayoutParam:$linearLayout); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + // Generic distributed encoding methods + unsigned getTotalElemsPerThread(ArrayRef shape) const; + SmallVector getElemsPerThread(ArrayRef shape) const; + + SmallVector getContig(const char *, SmallVector) const; + SmallVector getContigPerThread() const; + SmallVector getContigPerWarp() const; + SmallVector getOrder() const; + SmallVector getWarpOrder() const; + SmallVector getThreadOrder() const; + + + // Generalizes get{Warp,Thread,CTA}Order to linear layouts. + // Returns the order of the dimensions `dimName` of the layout. + // If more than dimension is of size one, it uses defaultOrder to determine + // the order of the dimensions of size one. + SmallVector orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const; + + // Generalizes getThreadsPerWarp, getWarpsPerCTA, getCTAsPerCGA to linear layouts. + // Returns the bases of the dimensions `dimName` of the layout. + // If skipBroadcast is false, we count a base zero + SmallVector basesPerDim(StringAttr dimName, + bool skipBroadcast = true) const; + SmallVector getThreadsPerWarp() const; + SmallVector getWarpsPerCTA() const; + + unsigned getRank() const { return getLinearLayout().getNumOutDims(); } + + // [FIXME LL] Supports legacy behaviour. We should remove these functions + SmallVector getSizePerThread() const; + }]; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + let hasCustomAssemblyFormat = 1; +} + + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> { + let mnemonic = "blocked"; + + let description = [{ +An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the amount of elements owned by each CUDA thread, warp and CTA respectively. + +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + blocked = {{0, 1}} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + blocked = {{0, 1}} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + blocked = {{0, 1}, {1, 0}} +}> +}]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread, + ArrayRefParameter<"unsigned">:$threadsPerWarp, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + + // CGALayout is optional in the textual IR. If omitted, we infer it to be a + // CGA with a single CTA (i.e. the trivial map onto dim0..dimn-1) + "CGAEncodingAttr":$CGALayout + ); + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "CGAEncodingAttr":$CGALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector shapePerCTA = getShapePerCTA(CGALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; + unsigned remainingWarps = numWarps; + unsigned prevLanes = 1; + unsigned prevWarps = 1; + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, std::max(1, shapePerCTA[i] / sizePerThread[i])); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); + remainingWarps /= warpsPerCTA[i]; + remainingLanes /= threadsPerWarp[i]; + remainingThreads /= threadsPerCTA; + prevLanes *= threadsPerWarp[i]; + prevWarps *= warpsPerCTA[i]; + } + + // Expand the last dimension to fill the remaining lanes and warps + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CGALayout); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, std::max(1, shape[i] / sizePerThread[i])); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CGAEncodingAttr CGALayout = CGAEncodingAttr::fromSplitParams(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CGALayout); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// + +def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, + ]; +} + +def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_mfma"; + + let description = [{ +An encoding for tensors that have been produced by MFMA matrix core instructions, +available on AMD Instinct GPUs of CDNA architectures. + +It is characterized by the following parameters: +- `version`: The GPU architecture: + - 1: gfx908: CDNA1 + - 2: gfx90a: CDNA2 + - 3: gfx942: CDNA3 + - 4: gfx950: CDNA4 +- `warpsPerCTA`: The warp layout in the block. +- `instrShape`: The shape in the form of (M, N, K) of the matrix. +- `isTransposed`: Indicates the result tensor is transposed so that it can be converted to dotOperand layout +without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). +- `tilesPerWarp`: The tile layout within a warp. Defaults to unit tile layout, i.e., single tile on all dimensions. +- `elementBitWidth`: Bit width of the output element type. Supported values are 32 and 64. Defaults to 32. + +Example 1: +Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. +The data will be distributed between threads as follows: + + warp 0 warp 1 +-----------------/\-------------- -----------------/\-------------- +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] + +Example 2: +Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. +The data will be distributed between threads as follows: + + warp 0 warp 1 +-----------------/\------------- ------------------/\--------------- +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] + +Example 3: +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): + +M N -> warp 0 warp 2 +| --------------------------/\-------------------------- ------------------------------/\------------------------------ +V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + warp 1 warp 3 + --------------------------/\-------------------------- ------------------------------/\------------------------------ + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + +Example 4: +This example demonstrates semantics of tilesPerWarp parameter. The MFMA layout (with tilesPerWarp=[1,1]) +assumes that each warp within a CTA tile computes a single MFMA tile. When the tensor is larger than +a single CTA tile, these tiles are repeated across the tensor. In this setup, the output tiles computed +by each warp were strided by the number of warps per CTA tile in both row and column dimensions. + +For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA +tiles looked like: + +w0 w1 w0 w1 +w2 w3 w2 w3 +w0 w1 w0 w1 +w2 w3 w2 w3 + +tilesPerWarp parameter allows each warp to compute contiguous MFMA tiles in the row and/or column dimensions. +Using the same example with tilesPerWarp = [2, 2], the layout becomes: + +w0 w0 w1 w1 +w0 w0 w1 w1 +w2 w2 w3 w3 +w2 w2 w3 w3 +}]; + + let parameters = ( + ins + "unsigned": $version, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + ArrayRefParameter<"unsigned">:$instrShape, + "bool":$isTransposed, + "CGAEncodingAttr":$CGALayout, + ArrayRefParameter<"unsigned">:$tilesPerWarp, + "unsigned":$elementBitWidth + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$version, + "ArrayRef":$warpsPerCTA, + "ArrayRef":$instrShape, + "bool":$isTransposed, + "CGAEncodingAttr":$CGALayout, + CArg<"ArrayRef", "{}">:$tpw, + CArg<"unsigned", "0">:$elementBitWidth), [{ + SmallVector tilesPerWarp(tpw); + if (tilesPerWarp.empty()) + tilesPerWarp = SmallVector(warpsPerCTA.size(), 1); + if (elementBitWidth == 0) + elementBitWidth = 32; + return $_get($_ctxt, version, warpsPerCTA, instrShape, isTransposed, CGALayout, tilesPerWarp, elementBitWidth); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + + // Check if tilesPerWarp is 1 in every dimension. + bool hasUnitTilesPerWarp() const; + + // Returns a swizzled shared layout matching this MFMA layout for the + // dot operand at the given |operandIdx| with |operandShape|. + SwizzledSharedEncodingAttr composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned vectorSize, + unsigned elemBitWidth, bool needTrans) const; + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_wmma"; + + let description = [{ +An encoding for tensors that have been produced by WMMA matrix core instructions, +available on AMD Radeon GPUs of RDNA architectures. + +It is characterized by the following parameters: +- `version` indicates the GPU architecture: + - 1: RDNA3; e.g., gfx1100, gfx1101 + - 2: RDNA4; e.g., gfx1200, gfx1201 + - 3: gfx1250 +- `ctaLayout` indicates the warp layout in the block. This is a generalization + compared to previous warp layout representation using warpsPerCTA and tilesPerWarp + parameters. +- `instrShape` indicates the shape in the form of (M, N, K) of the matrix + operation performed by a single WMMA instruction. Defaults to (16, 16, 16). +- `isTransposed` indicates the layout of the result tensor is transposed. + +Example 1: +Suppose we have a tensor with shape [32, 64], `warpsPerCTA` set to [2, 2]. +Matrix elements represent which lane owns the element. Currently only wave32 mode +is supported. + +// ----------------------------------- version = 1 ----------------------------------- // + +Row | warp 0 warp 1 + |/-------------------^-------------------\ /-------------------^-------------------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +2 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +3 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +14 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + | warp 2 warp 3 +16 |/-------------------^-------------------\ /-------------------^-------------------\ +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +18 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +19 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +20 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +30 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ------------------------ version = 2/3, isTransposed = false ------------------------ // + +Row | warp 0 warp 1 + |/--------^---------\ /---------^--------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +6 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +7 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +8 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +9 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +14 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + | + | warp 2 warp 3 + |/--------^---------\ /---------^--------\ +16 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +22 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +23 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +24 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +25 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +30 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ------------------------ version = 2/3, isTransposed = true ------------------------ // + + | warp 0 warp 1 + |/----------------^----------------\ /-------^-------\ +Col>| 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 ... 32 +Row | +0 |[0 0 0 0 0 0 0 0 16 ... 16] [0 0 0 ... 16] +1 |[1 1 1 1 1 1 1 1 17 ... 17] [1 1 1 ... 17] +.. | ... ... +14 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30] +15 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31] + | + | warp 2 warp 3 + |/----------------^----------------\ /-------^-------\ +16 |[0 0 0 0 0 0 0 0 16 ... 16] [0 0 0 ... 16] +17 |[1 1 1 1 1 1 1 1 17 ... 17] [1 1 1 ... 17] +.. | ... ... +30 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30] +31 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31] + +Example 2: +This example illustrates the purpose of the ctaLayout parameter. +ctaLayout is a linear layout describing how warps are arranged across WMMA tiles. +Previously, this information was encoded using warpsPerCTA and tilesPerWarp parametes. +For instance, a configuration with 4 warps, represented as: + +warpsPerCTA = [2, 2], tilesPerWarp = [1, 1] + +would translate to: + +ctaLayout = {reg = [], warp = [[0, 1], [1, 0]]} + +By default, WMMA assumes that each warp in a CTA computes exactly one WMMA tile. +In the grid below, each w* label indicates which warp computes that tile: + +w0 w1 w0 w1 +w2 w3 w2 w3 +w0 w1 w0 w1 +w2 w3 w2 w3 + +To express more complex layouts, we must also account for repetitions within the mapping. +For example, the configuration formerly described as: + +warpsPerCTA = [2, 2], tilesPerWarp = [2, 2] + +would translate to: + +ctaLayout = {reg = [[0, 1], [1, 0]], warps = [[0, 2], [2, 0]] } + +w0 w0 w1 w1 +w0 w0 w1 w1 +w2 w2 w3 w3 +w2 w2 w3 w3 + +This parameter provides a more general way to define warp mappings than what +warpsPerCTA and tilesPerWarp alone could express. +For instance: + +ctaLayout = {reg = [[1, 0], [0, 1]], warps = [[0, 2], [2, 0]]} + +still represents a layout similar to: + +warpsPerCTA = [2, 2], tilesPerWarp = [2, 2] + +but with a different ordering of repetitions. + +The motivation for this broader formulation comes from the need to describe swizzled warp +layouts, which help avoid LDS partition conflicts on architectures such as gfx1250. +A valid example of such swizzled configuration is: + +ctaLayout = {reg = [[2, 0]], warps = [[2, 1], [1, 0]]} + +With corresponding mapping: + +w0 w1 <- second tile computed by w1 +w2 w3 +w0 w1 <- first tile computed by w1 +w2 w3 + +Note that ctaLayout naturally composes with layout definied on a single WMMA tile +to form final WMMA layout. + +wmmaLayout = tileLayout * ctaLayout + +This simplifies both WMMA and dotOperand layouts lowering to linear layout. + }]; + + let parameters = ( + ins + "unsigned": $version, + LinearLayoutParam:$ctaLayout, + "bool":$isTransposed, + "CGAEncodingAttr":$CGALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getRepOrderForOperand(int opIdx) const; + LinearLayout getTileLayout(unsigned rank) const; + static SmallVector getDefaultInstrShape() { + return {16, 16, 16}; + } + + // Returns a swizzled shared layout matching this WMMA layout for the + // dot operand at the given |operandIdx| with |operandShape|. + SwizzledSharedEncodingAttr composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned kWidth, + unsigned elemBitWidth, bool needTrans) const; + }]; +} + +def MUSAWmmaEncodingAttr : DistributedEncoding<"MUSAWmmaEncoding", "musa_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "musa_wmma"; + + let description = [{ +An encoding for tensors produced by MUSA WMMA instructions. + +It is characterized by: +- `versionMajor`/`versionMinor` identifying the MMA generation. +- `warpsPerCTA` describing how warps are arranged within the CTA. +- `instrShape` describing the (M, N, K) shape of a single MMA instruction. +- `CGALayout` describing CTA clustering. + +PH1 (capability=31) operates in wave32 mode only. + }]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CGAEncodingAttr":$CGALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isPH1() const; + SmallVector getRepOrderForOperand(int opIdx) const; + SwizzledSharedEncodingAttr composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned kWidth, + unsigned elemBitWidth, bool needTrans) const; + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def MUSASqmmaEncodingAttr : DistributedEncoding<"MUSASqmmaEncoding", "musa_sqmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "musa_sqmma"; + + let description = [{ +An encoding for tensors produced by MUSA SQMMA (squad-level MMA) instructions. + +It is characterized by: +- `versionMajor`/`versionMinor` identifying the MMA generation. +- `warpsPerCTA` describing how warps are arranged within the CTA. +- `instrShape` describing the (M, N, K) shape of a single MMA instruction. +- `CGALayout` describing CTA clustering. + +PH1 SQMMA executes in wave32 mode and uses 4 warps per squad. + }]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CGAEncodingAttr":$CGALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isPH1() const; + SmallVector getRepOrderForOperand(int opIdx) const; + SmallVector getElemsPerThread(ArrayRef shape) const; + unsigned getTotalElemsPerThread(ArrayRef shape) const; + SwizzledSharedEncodingAttr composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned kWidth, + unsigned elemBitWidth, bool needTrans) const; + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "nvidia_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +It is characterized by two parameters: +- A 'versionMajor' which specifies the generation the tensor cores + whose output is being partitioned: + - 1 for first-gen tensor cores (Volta), and + - 2 for second-gen tensor cores (Turing/Ampere). +- A 'versionMinor' which indicates the specific layout of a tensor core + generation, e.g. for Volta, there might be multiple kinds of layouts + annotated by 0,1,2 and so on. +- A `blockTileSize` to indicate how data should be partitioned between warps. + +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Note: the layout is different from the recommended in PTX ISA +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, when versionMinor=1, the matrix L corresponding to +blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] + + warp 1 = warp0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 2 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 1 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CGAEncodingAttr":$CGALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + bool isTuring() const; + bool isAmpere() const; + bool isHopper() const; + + SmallVector getRepForOperand(ArrayRef shape, + int bitwidth, int kWidth, + int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { + let mnemonic = "slice"; + + let description = [{ + Given a `parent` layout and a `dim`, squeezes the given `dim` in the `parent` + layout and distributes values in a tensor T according to the new layout. + + For example, given + + T = [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] (with 16 CUDA threads) + + With dim = 0, squeezing out dim 0, we have + L = [{0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L(T) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] + + With dim = 1, squeezing out dim 1, we have + L = [ {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L = [ {0,1,2,3}, {4,5,6,7}, ..., {12,13,14,15}, {0,1,2,3}, ..., {12,13,14,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation + during some optimization passes. + }]; + + let parameters = ( + ins + "unsigned":$dim, + "DistributedEncodingTrait":$parent + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + template + SmallVector paddedShape(ArrayRef shape) const; + }]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, given `d = tt.dot a, b, c` tt.dot's operands a and b +must be of DotOperandEncodingAttr layout, if the dot is MMA v1 or v2 (i.e. +pre-Hopper). For MMA v3, the operands are *almost always* in a regular shared +encoding, but sometimes the LHS is also a dot-operand encoding. + +a's opIdx is 0, b's opIdx is 1. + +The parent field is the layout of d. + +kWidth defines number of consecutive elements stored by one thread along k dimension. +Some layouts do not use this parameter, either because they have a fixed number of +elements along the K dim, or they use all elements of the tensor along the K dim. + +# WGMMA Notes +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. + +The encoded tensor consists of operand A for possibly multiple wgmma instructions. +For each wgmma, each warp in a warp group feeds a single "warp matrix" +Each warp matrix consists of 2x2 "quads". +Each thread holds several elements in each quad. Right before a wgmma, +the sum of bitwidth of +the elements in each quad should add up to 32. + +These values are stored unrolled in `elements`. +The ordering of dimensions is as follows by convention: +batch (only 1 batch for Hopper currently) +matM (m-index of the "warp matrix") +matK (k-index of the "warp matrix") +quadK (k-index of the "quad" in the core matrix) +quadM (m-index of the "quad" in the core matrix) +vecIdx (index of the element in the quad; this is always along the k-dim) + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent, + DefaultValuedParameter<"unsigned", "0">:$kWidth + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) + return $_get(context, opIdx, parent, 0); + // For MMAV2 and V3 + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned kWidth = std::max(32 / bitwidth, 1u); + return $_get(context, opIdx, parent, kWidth); + }]> + ]; + + let assemblyFormat = "`<` `{` struct(params) `}` `>`"; + let genVerifyDecl = 1; + let extraClassDeclaration = extraDistributedDeclaration; +} + +def TTG_SharedMemorySpace : AttrDef { + let mnemonic = "shared_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to shared memory. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrImpls.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrImpls.td new file mode 100644 index 0000000000..314cfd53e8 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrImpls.td @@ -0,0 +1,13 @@ +//===----------------------------------------------------------------------===// +// Aggregated attr definitions (including CGA) for implementation emission. +// This file exists to generate AttrDefs.cpp.inc once, without duplicating +// CGAEncodingAttr while still making CGA available before LayoutEncodingTrait. +//===----------------------------------------------------------------------===// + +#ifndef TRITONGPU_ATTRIMPLS_TD +#define TRITONGPU_ATTRIMPLS_TD + +include "triton/Dialect/TritonGPU/IR/CGAEncodingAttr.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" + +#endif // TRITONGPU_ATTRIMPLS_TD diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td new file mode 100644 index 0000000000..3169dc451f --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -0,0 +1,41 @@ +#ifndef TRITONGPU_DIALECT +#define TRITONGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonGPU_Dialect : Dialect { + let name = "ttg"; + + let cppNamespace = "::mlir::triton::gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "mlir::gpu::GPUDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + + LinearLayout toLinearLayout(ArrayRef shape, Attribute layout); + LinearEncodingAttr toLinearEncoding(ArrayRef shape, Attribute layout); + + static int getNumCTAs(ModuleOp mod); + static int getThreadsPerWarp(ModuleOp mod); + + private: + LinearLayoutCache llCache; + LinearEncodingCache leCache; + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUEnums.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUEnums.td new file mode 100644 index 0000000000..4bb47c0b5b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUEnums.td @@ -0,0 +1,22 @@ +#ifndef TRITONGPU_ENUMS +#define TRITONGPU_ENUMS + +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +// Bitmask enum describing which memory domains a barrier/fence orders. +def TTG_AddrSpace : I32BitEnumAttr< + "AddrSpace", "", + [ + I32BitEnumAttrCase<"None", 0b0000, "none">, + I32BitEnumAttrCase<"Local", 0b0001, "local">, + I32BitEnumAttrCase<"GlobalRead", 0b0010, "global_read">, + I32BitEnumAttrCase<"GlobalWrite", 0b0100, "global_write">, + I32BitEnumAttrCase<"TensorRead", 0b1000, "tensor_read">, + I32BitEnumAttrCase<"TensorWrite", 0b10000, "tensor_write">, + I32BitEnumAttrCase<"All", 0b11111, "all"> + ]> { + let cppNamespace = "::mlir::triton::gpu"; +} + +#endif // TRITONGPU_ENUMS diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h new file mode 100644 index 0000000000..490816334c --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -0,0 +1,13 @@ +#ifndef TRITON_GPU_DIALECT_INTERFACES_H +#define TRITON_GPU_DIALECT_INTERFACES_H + +#include "mlir/IR/OpDefinition.h" +#include "triton/Dialect/TritonGPU/IR/CGAEncodingAttr.h" + +// clang-format off +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc" +#include "triton/Dialect/TritonGPU/IR/OpInterfaces.h.inc" +// clang-format on + +#endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td new file mode 100644 index 0000000000..3862b7f474 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td @@ -0,0 +1,29 @@ +#ifndef TRITONGPU_OP_INTERFACES +#define TRITONGPU_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> { + let description = [{ + This interface is for operations that upcast floating-point numbers. + }]; + + let cppNamespace = "::mlir::triton::gpu"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Infer destination encoding", + /*retType=*/"mlir::Attribute", + /*methodName=*/"inferDstEncoding", + /*args=*/(ins "unsigned":$opIdx, "mlir::Attribute":$srcEnc) + >, + InterfaceMethod< + /*desc=*/"Infer operand encoding from dst encoding", + /*retType=*/"mlir::Attribute", + /*methodName=*/"inferSrcEncoding", + /*args=*/(ins "unsigned":$opIdx, "mlir::Attribute":$dstEnc) + > + ]; +} + +#endif // TRITONGPU_OP_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td new file mode 100644 index 0000000000..5f666b43a2 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -0,0 +1,741 @@ +#ifndef TRITONGPU_OPS +#define TRITONGPU_OPS + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUEnums.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" // RegionBranchOpInterface +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ViewLikeInterface.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTG_Op traits = []> : + Op { +} + +def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType, + Pure]> { + let summary = "convert layout"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTG_AsyncWaitOp : TTG_Op<"async_wait", [MemWaitOpTrait]> { + let summary = "Ensure all specified async_copy_* operations are complete."; + let description = [{ + The `async_wait` op waits until at most "num" async copy groups are outstanding without synchronising CTA execution. + It takes zero or more `asyncToken` plus an integer `num` that specifies how many async copy groups can remain + outstanding after the `async_wait` op is completed. `num = 0` waits until all groups of async copies are complete. + + This operation does not provide any syncronisation in the CTA, if syncronisation is needed use `ttg.local_barrier` + in addition to this operation. + }]; + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + + let results = (outs TTG_AsyncToken:$retToken); + + let assemblyFormat = "($asyncToken^)? attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { + let summary = "Commit pending async copies into an async group that can be waited on"; + let description = [{ + Closes the current batch of async_copy_* operations + and allows for them to be waited on with `ttg.async_wait`. + This is required in order to ensure async copy operations can be waited on. + }]; + let results = (outs TTG_AsyncToken:$asyncToken); + let arguments = (ins Variadic:$inputTokens); + + let assemblyFormat = "(`tokens` $inputTokens^)? attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ + AttrSizedOperandSegments, + OptionalTypesMatchWith<"infer mask type from src type", + "src", "mask", "getI1SameShape($_self)">, + OptionalTypesMatchWith<"infer other type from src type", + "src", "other", "getPointeeType($_self)">, +]> { + let summary = "Copy data from global memory to local memory asynchronously"; + + let hasVerifier = 1; + let description = [{ + This operation copies data from global memory to local memory asynchronously. + This is analogue to `tt.load` except the data are copied to local memory pointed + to by the memory descriptor instead of a distributed tensor. The rest of the + operands are the same as `tt.load`. + Contiguity is the maximum number of elements that can be loaded in a single vector with + the given layout and mask. + This allows op to use `async_copy_global_to_local` even if the alignment cannot be proven based on IR. + + The data will only be available in local memory after `ttg.async_wait` is issued to wait on the + completion of `async_copy_global_to_local`. The async copy operations must be committed using + `ttg.async_commit_group` to close the batch and allow for them to be waited on. + }]; + + let arguments = (ins + Arg]>:$src, + Arg]>:$result, + Optional:$mask, + Optional:$other, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile, + DefaultValuedAttr:$contiguity + ); + + let results = (outs TTG_AsyncToken:$token); + + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } + }]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between other, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $src `,` $result (`mask` $mask^)? (`other` $other^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($src) `->` type($result) + }]; +} + +// Allocate shared memory +def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor"; + let description = [{ + This operation allocates buffer in shared memory and return a descriptor + containing the address and a view of the buffer. + + Explicitly deallocating a buffer is optional; see local_dealloc. + + The `src` operand is an optional initializer for the allocated buffer. It + must have the element type as the buffer. If `src` is not specified, the + returned buffer must be mutable. + }]; + let arguments = ( + ins + Optional:$src, + OptionalAttr:$alignment + ); + + let builders = [ + OpBuilder<(ins "Type":$result), + [{ build($_builder, $_state, result, Value(), IntegerAttr()); }]>, + OpBuilder<(ins "Type":$result, "Value":$src), + [{ build($_builder, $_state, result, src, IntegerAttr()); }]>, + OpBuilder<(ins "Type":$result, "Value":$src, "int32_t":$alignment), + [{ build($_builder, $_state, result, src, $_builder.getI32IntegerAttr(alignment)); }]> + ]; + + let extraClassDeclaration = [{ + bool isSharedMemoryAlloc() { + return isa_and_nonnull(getType().getMemorySpace()); + } + int32_t getAlignmentOrDefault(); + }]; + let assemblyFormat = [{ + ($src^)? attr-dict `:` functional-type(operands, results) + }]; + + let results = (outs TTG_MemDescType:$result); + let hasFolder = 1; + let hasVerifier = 1; +} + +// Deallocate shared memory +def TTG_LocalDeallocOp : TTG_Op<"local_dealloc"> { + let summary = "dealloc buffer"; + + let description = [{ + This operation deallocates a buffer explicitly. Using the buffer after this + operation is undefined. + + This operation is optional. If you don't explicitly dealloc a buffer, the + compiler assumes it's deallocated at the first point that post-dominates all + uses of the alloc. + + Because we assume a memdesc is dead at the first point that post-dominates + its uses, ops that wait for an async operation on a memdesc to complete + (such as ttng.warp_group_dot_wait) should also take the memdesc as an + operand. + }]; + + let arguments = (ins Arg]>:$src); + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; +} +def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor pointing to the `i`-th element of the + input descriptor along the 0-th dimension. + + It doesn't affect the underlying memory. + + For example, suppose that + - the input shape is 2x4x16xf16, + - the output shape is 4x16xf16, and + - index = 1. + Then the output descriptor is equivalent to input[1], where input is the logical tensor. + }]; + + let arguments = (ins TTG_MemDescType:$src, I32:$index); + + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = [{$src `[` $index `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + + let hasVerifier = 1; +} + +def TTG_MemDescSubsliceOp : TTG_Op<"memdesc_subslice", [Pure, MemDescViewTrait]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor representing a subview of the logical tensor. + It doesn't affect the underlying memory. + + For example, suppose that + - the input shape is 32x16xf16, + - the output shape is 8x16xf16, and + - offsets = [2, 1]. + Then in Python syntax, the subview covers input[2:8+2, 1:16+1] where input is + the logical tensor. + + The offsets must be larger or equal to the tile of the tensor (or zero). + }]; + let arguments = (ins TTG_MemDescType:$src, DenseI32ArrayAttr:$offsets); + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + // Render offsets inline as %src[0, 0] via a custom directive, but keep + // the overall parse/print generated from this assemblyFormat. + let assemblyFormat = [{ + $src `[` custom($offsets) `]` attr-dict `:` qualified(type($src)) + `->` qualified(type($result)) + }]; + + let results = (outs TTG_MemDescType:$result); + + let hasFolder = 1; + let hasVerifier = 1; +} + +def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, + MemDescViewTrait, + TransposeOpInterface, + InferTypeOpWithLayoutEquivalence, + SameOperandsAndResultElementType]> { + let summary = "transpose the descriptor"; + + let description = [{ + This operation returns a new descriptor + representing a transposed view of the buffer. + }]; + + let arguments = ( + ins TTG_MemDescType:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))"; + + let hasFolder = 1; +} + +def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure, + MemDescViewTrait, + SameOperandsAndResultElementType]> { + let summary = "creates a descriptor for the new shape"; + + let description = [{ + This operation returns a new descriptor representing a reshaped view of the underlying buffer. + This doesn't affect the memory. + }]; + + let arguments = (ins TTG_MemDescType:$src); + + let builders = [ + OpBuilder<(ins "Value":$src, "ArrayRef":$shape), + [{ + MemDescType dstTy; + auto srcTy = cast(src.getType()); + auto result = inferReturnTypes($_builder.getContext(), + $_builder.getUnknownLoc(), + srcTy, shape, dstTy); + assert(succeeded(result) && "failed to infer return types"); + build($_builder, $_state, dstTy, src); + }]> + ]; + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypes(MLIRContext *context, + std::optional loc, + MemDescType srcTy, + ArrayRef dstShape, + MemDescType &inferredReturnType); + }]; + + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))"; + + let hasVerifier = 1; +} + +def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewTrait]> { + let summary = "reinterpret a memory descriptor as a different type and shape"; + + let description = [{ + The `ttg.memdesc_reinterpret` operation reinterprets a memory descriptor + as one with a different shape and element type. Because memory descriptors + lack strides, this operation is only valid if the original memory descriptor + is contiguous. + }]; + + let arguments = (ins TTG_MemDescType:$src); + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` qualified(type($result)) + }]; + + let hasFolder = 1; +} + +def TTG_LocalLoadOp : TTG_Op<"local_load", [LocalLoadTrait]> { + let summary = "Load a buffer from local memory into a distributed tensor"; + + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; + let arguments = (ins + Arg]>:$src, + Optional:$token + ); + let results = (outs TT_Tensor:$result); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src), + [{ + build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + let hasVerifier = 1; +} + +def TTG_LocalStoreOp : TTG_Op<"local_store"> { + let summary = "Store a distributed tensor into a buffer in local memory"; + + let description = [{ + Store a distributed tensor into a buffer in local memory. + }]; + let arguments = (ins + TT_Tensor:$src, + Arg]>:$dst + ); + + let hasVerifier = 1; + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{ + $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) + }]; +} + +def TTG_LocalGatherOp : TTG_Op<"local_gather", [LocalLoadTrait]> { + let summary = "Gather elements from shared memory along a specified axis"; + + let description = [{ + Gather elements from a shared memory descriptor using an indices tensor along a + single specified axis. The output tensor has the same shape as the indices tensor. + + For each output position I, the operation reads from src where the coordinate at + the gather axis is replaced by indices[I]: + result[I] = src[I[0], ..., indices[I], ..., I[n]] + where the axis dimension is replaced by the index value. + + This matches the behavior of tt.gather but operates on shared memory descriptors. + }]; + let arguments = (ins + Arg]>:$src, + TT_IntTensor:$indices, + I32Attr:$axis, + Optional:$token + ); + let results = (outs TT_Tensor:$result); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src, "Value":$indices, "IntegerAttr":$axis), + [{ + build($_builder, $_state, retType, src, indices, axis, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$src `[` $indices `]` (`token` $token^)? attr-dict `:` qualified(type($src)) `,` type($indices) `->` type($result)}]; + let hasVerifier = 1; +} + +def TTG_LocalScatterOp : TTG_Op<"local_scatter"> { + let summary = "Scatter elements to shared memory along a specified axis"; + + let description = [{ + Scatter elements to a shared memory descriptor using an indices tensor along a + single specified axis. The values tensor has the same shape as the indices tensor. + + For each input position I, the operation writes to dst where the coordinate at + the scatter axis is replaced by indices[I]: + dst[I[0], ..., indices[I], ..., I[n]] = values[I] + where the axis dimension is replaced by the index value. + + This is the inverse of local_gather and writes to shared memory at runtime-computed indices. + }]; + let arguments = (ins + Arg]>:$dst, + TT_Tensor:$values, + TT_IntTensor:$indices, + I32Attr:$axis, + Optional:$token + ); + + let builders = [ + OpBuilder<(ins "Value":$dst, "Value":$values, "Value":$indices, "IntegerAttr":$axis), + [{ + build($_builder, $_state, dst, values, indices, axis, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$dst `[` $indices `]` `,` $values (`token` $token^)? attr-dict `:` qualified(type($dst)) `,` type($indices) `,` type($values)}]; + let hasVerifier = 1; +} + +def TTG_PredicateStageOp: TTG_Op<"predicate_stage", + [Pure, AllTypesMatch<["iv", "ub", "step"]>]> { + let summary = "pipeliner stage predicate"; + let arguments = (ins AnySignlessIntegerOrIndex:$iv, + AnySignlessIntegerOrIndex:$ub, + AnySignlessIntegerOrIndex:$step, + I32Attr:$maxStage, + I32Attr:$stage); + let results = (outs I1:$result); + let assemblyFormat = "$iv `,` $ub `,` $step `maxStage` $maxStage `stage` $stage attr-dict `:` type($iv) `->` type($result)"; +} + +def TTG_MaskOp: TTG_Op<"mask", + [SingleBlock]> { + let summary = "mask op for pipelining"; + let arguments = (ins I1:$pred); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$region); +} + +def TTG_MaskReturnOp: TTG_Op<"mask.return", + [HasParent<"MaskOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for mask operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> { + let summary = "Upcast fp4 (e2m1) to fp"; + + let hasVerifier = 1; + + let description = [{ + Upcast fp4 (e2m1) represented packed as i8s to fp. + + The lower 4 bits of the i8s represent the first fp4 element, and the upper 4 bits + the second fp4 element. + + The `axis` attribute specifies the axis along which the fp4 elements are packed. + }]; + + let builders = [ + OpBuilder<(ins "TypedValue":$src, "Type":$elemType, "int32_t":$axis)> + ]; + + let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis); + let results = (outs TT_FloatTensor:$result); + + let extraClassDeclaration = [{ + static LogicalResult verifyFp4ToFp( + mlir::Operation *op, + RankedTensorType srcTy, + RankedTensorType resTy, + unsigned axis); + }]; + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// Allocate global memory +def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc"> { + let summary = "allocate a global memory buffer"; + let description = [{ + This operation allocates a buffer in global memory that is private to the current program. + The `backend` attribute specifies the backend to use for allocation. + The `default` backend is used by TritonGPU passes. + Downstream Triton tools and compilers can register a different backend and use a different allocation policy. + }]; + let arguments = ( + ins + I32Attr:$nbytes, + I32Attr:$alignment, + DefaultValuedAttr:$backend + ); + let results = (outs Arg]>:$result); + + let assemblyFormat = [{attr-dict `:` qualified(type($result))}]; +} + +def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [ + RecursiveMemoryEffects, RecursivelySpeculatable, AsyncRegions, + DeclareOpInterfaceMethods +]> { + let summary = "asynchronously execute code on multiple warpgroups"; + let description = [{ + The `ttg.warp_specialize` op represents executing different code + simultaneously on different warp groups. A warp group is a group of + power-of-2 warps, which can be a different number of warps than in the + enclosing region. + + The "default" region of the op represents the code executed by the currently + executing warp group. This region is allowed to implicitly capture. The op + contains a number of "partition" regions that are isolated from above. They + must be isolated because these regions represent different layout domains, + as the number of warps is different. + + Semantically, execution of each region starts simultaneously for each warp + group, and all warp groups are joined at the end of the op. + + Example: + + ```mlir + %0 = ttg.warp_specialize(%a, %b) + default { + %out = some_operation(%a) // implicit capture of `%a` + ttg.warp_yield %out : i32 + } + partition0(%arg0: i32, %arg1: i32) num_warps(8) { + some_async_dispatch(%arg0, %arg1) + ttg.warp_return + } + partition1(%arg0: i32, %arg1: i32) num_warps(1) { + some_async_dispatch(%arg0, %arg1) + ttg.warp_return + } : (i32, i32) -> i32 + ``` + }]; + + let arguments = (ins DenseI32ArrayAttr:$partitionNumWarps, + OptionalAttr:$warpGroupStartIds, + OptionalAttr:$requestedRegisters, + OptionalAttr:$actualRegisters); + let results = (outs Variadic:$defaultPassthrough); + + let regions = (region + MinSizedRegion<1>:$defaultRegion, + SizedRegion<1>:$partitionOpHolder + ); + + let extraClassDeclaration = [{ + RegionRange getPartitionRegions(); + WarpSpecializePartitionsOp getPartitionOp(); + + // Get the size and alignment of the capture list. + std::pair getCaptureSizeAlign(); + // Get the total number of extra warps required. + unsigned getTotalPartitionWarps(); + }]; + + let builders = [OpBuilder<(ins "TypeRange":$resultTypes, + "ArrayRef":$partitionNumWarps, + "unsigned":$numPartitionRegions)>, + OpBuilder<(ins "TypeRange":$resultTypes, + "ArrayRef":$partitionNumWarps)>, + ]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + let hasCanonicalizeMethod = 1; +} + +def TTG_WarpSpecializePartitionsOp + : TTG_Op<"warp_specialize.partitions", + [IsolatedFromAbove, RecursiveMemoryEffects, + RecursivelySpeculatable, Terminator, + HasParent<"WarpSpecializeOp">, + DeclareOpInterfaceMethods< + RegionBranchOpInterface, ["getEntrySuccessorOperands"]>]> { + let summary = "container op for `ttg.warp_specialize`"; + let description = [{ + Because MLIR requires entire operations be isolated from above, this op + contains the actual isolated from above regions of `ttg.warp_specialize`. + }]; + + let arguments = (ins Variadic:$explicitCaptures); + let regions = (region VariadicRegion>:$partitionRegions); + + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TTG_WarpYieldOp : TTG_Op<"warp_yield", [ + Pure, Terminator, ReturnLike, HasParent<"WarpSpecializeOp">, + DeclareOpInterfaceMethods +]> { + let summary = "yield from the default region of `ttg.warp_specialize`"; + let description = [{ + The `ttg.warp_yield` operation is the terminator for the "default" region of + a `ttg.warp_specialize` operation. The operands are passed transparently as + the SSA results of the `ttg.warp_specialize` operation. + + Example: + + ```mlir + ttg.warp_yield %a, %b : i32, tensor<32xbf16, #blocked> + ``` + }]; + + let arguments = (ins Variadic:$values); + + let assemblyFormat = "($values^)? attr-dict (`:` type($values)^)?"; + let hasVerifier = 1; +} + +def TTG_WarpReturnOp : TTG_Op<"warp_return", [ + Pure, Terminator, ReturnLike, HasParent<"WarpSpecializePartitionsOp"> +]> { + let summary = "implicit terminator from partition regions"; + let description = [{ + The `ttg.warp_return` operation is the implicit terminator that ends the + partition regions of a `ttg.warp_specialize` op. It has no operands as these + regions cannot return anything. + + TODO: Support returning uniform values from partition regions. + }]; + + let assemblyFormat = "attr-dict"; +} + +def TTG_BarrierOp : TTG_Op<"barrier"> { + let summary = "Synchronizes execution and reads/writes to the selected address spaces for all threads in the CTA."; + let description = [{ + The `barrier` op synchronises the execution and all operations between the selected address spaces for all + threads in the CTA. It is used to coordinate communication between threads in the CTA. + + This operation waits until all threads in the CTA have reached a `barrier` (for syncronisation) and operations + between the selected address spaces made by these threads prior to the op are visible to all threads in the CTA. + + Data hazards between threads accessing the same memory can be avoided by synchronising the + specified scope in-between these accesses with a `barrier`. + + A `barrier` operation only provides syncronisation and memory guarantees on the selected address spaces in the CTA. + + The mandatory `addrspace` attribute is a bitmask describing which address spaces will be visible when the `barrier` completes: + + * `none` control-only syncronisation (no memory ordering). + * `local` shared-memory operations are complete and visible CTA-wide. + * `global_read` global memory reads are complete and visible CTA-wide. + * `global_write` global memory writes are complete and visible CTA-wide. + * `tensor_read` tensor memory read operations are complete and visible CTA-wide. + * `tensor_write` tensor memory write operations are complete and visible CTA-wide. + * `all` convenience alias for `["local", "global_read", "global_write", "tensor_read", "tensor_write"]`. + + Multiple address spaces can be combined (e.g. `local|tensor_write`). `none` cannot be combined with other address spaces. + + Example: + + ```mlir + ttg.barrier local + ttg.barrier local|global_read|global_write + ``` + }]; + + let arguments = (ins TTG_AddrSpace:$addrSpace); + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Returns true if the barrier includes all of the given address spaces. + /// For example, hasAddrSpaces(Local | GlobalRead) returns true only if + /// both Local and GlobalRead are set. + bool hasAddrSpace(AddrSpace space) { + return bitEnumContainsAll(getAddrSpace(), space); + } + bool hasLocal() { return hasAddrSpace(AddrSpace::Local); } + bool hasGlobalRead() { return hasAddrSpace(AddrSpace::GlobalRead); } + bool hasGlobalWrite() { return hasAddrSpace(AddrSpace::GlobalWrite); } + bool hasTensorRead() { return hasAddrSpace(AddrSpace::TensorRead); } + bool hasTensorWrite() { return hasAddrSpace(AddrSpace::TensorWrite); } + }]; +} + +def TTG_WarpIdOp : TTG_Op<"warp_id", [Pure]> { + let summary = "Return the GPU warp ID"; + + let description = [{ + This operation returns the GPU warp ID. This can translate to reading + hardware registers if there are, or just thread ID divided by warp size. + + The `omitUniformHint` attribute is indicating in NVIDIA backend whether to + omit emitting nvvm.shfl.sync idx 0 for LLVM. + }]; + + let arguments = (ins UnitAttr:$omitUniformHint); + let results = (outs I32:$result); + + let assemblyFormat = "attr-dict"; +} + +#endif // TRITONGPU_OPS diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td new file mode 100644 index 0000000000..a0415b62c6 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td @@ -0,0 +1,23 @@ +#ifndef TRITON_GPU_TYPE_INTERFACES +#define TRITON_GPU_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +// Interface dynamically attached to RankedTensorType and MemDescType. +def TTG_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + InterfaceMethod<"Returns the encoding of the tensor or memory descriptor", + "mlir::Attribute", "getEncoding", (ins)>, + InterfaceMethod<"Returns element type", + "mlir::Type", "getElementType", (ins)>, + InterfaceMethod<"Returns the type shape", + "llvm::ArrayRef", "getShape", (ins)>, + InterfaceMethod<"Returns the tensor or buffer rank", + "int64_t", "getRank", (ins)>, + InterfaceMethod<"Returns the element type bit width", + "int64_t", "getElementTypeBitWidth", (ins)>, + ]; +} + +#endif // TRITON_GPU_TYPE_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td new file mode 100644 index 0000000000..b99b26ef8a --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -0,0 +1,86 @@ +#ifndef TRITONGPU_TYPES +#define TRITONGPU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +class TTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", "async.token", []> { + let summary = "async token type"; + let description = [{ + `ttg.async.token` is a type returned by an asynchronous operation. + It is used to establish an SSA-based link between async operations + and operations that group or synchronize the async operations. + }]; +} + +// Memory descriptor type. +def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::gpu::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory, + ArrayRefParameter<"int64_t">:$allocShape + ); + + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape()); + } + + bool hasRank() const { return true; } + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory, + "llvm::ArrayRef":$allocShape + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape); + }]> + + ]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Types.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Types.h new file mode 100644 index 0000000000..cfad8be199 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Types.h @@ -0,0 +1,14 @@ +#ifndef TRITONGPU_IR_TYPES_H_ +#define TRITONGPU_IR_TYPES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.h.inc" + +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..6be94d1a8a --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) +add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h new file mode 100644 index 0000000000..9ae9322841 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h @@ -0,0 +1,17 @@ + +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCINGUTILS_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCINGUTILS_H_ + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::gpu { +BlockedEncodingAttr +buildCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + triton::gpu::CGAEncodingAttr cgaLayout, + SmallVector shapePerCTA); +} // namespace mlir::triton::gpu + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCINGUTILS_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h new file mode 100644 index 0000000000..f06f85e58a --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h @@ -0,0 +1,47 @@ +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::gpu { + +class DecomposeScaledBlocked : public OpRewritePattern { +public: + DecomposeScaledBlocked(MLIRContext *context, PatternBenefit benefit) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(DotScaledOp scaledDotOp, + PatternRewriter &rewriter) const override; + +protected: + FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType, + PatternRewriter &rewriter) const; + TypedValue scaleTo16(PatternRewriter &rewriter, + TypedValue scale, + FloatType computeType) const; + TypedValue + broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp, + ModuleOp mod, TypedValue scale, + int dim) const; + TypedValue maskNan(PatternRewriter &rewriter, + DotScaledOp scaledDotOp, + TypedValue mxfp, + TypedValue scale, + int dim) const; + virtual TypedValue scaleArg(PatternRewriter &rewriter, + DotScaledOp scaledDotOp, + int opIdx, + FloatType computeType) const; + TypedValue + cvtDotOperand(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx, + TypedValue v) const; + TypedValue + extendAndBroadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp, + TypedValue &scale, + FloatType computeType, RankedTensorType dstType, + int opIdx) const; + static SmallVector getTransposeOrder(int rank); +}; + +void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns, + int benefit); + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h new file mode 100644 index 0000000000..b289de5593 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h @@ -0,0 +1,21 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Tools/LinearLayout.h" +#include + +namespace mlir::triton::gpu { + +// Given the result |dstLayout|, infer the source layout that we should use for +// global load if we propagate through op def chain of |defOp|. Returns +// std::nullopt if fails to infer or cannot reach a global load. +std::optional> +inferSourceLoadLayout(const LinearLayout &dstLayout, Operation *defOp); +std::optional> +inferSourceLoadLayout(LinearEncodingAttr dstLayout, Operation *defOp); + +} // namespace mlir::triton::gpu + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_LAYOUT_PROPAGATION_UTILITY_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h new file mode 100644 index 0000000000..58e5290c29 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h @@ -0,0 +1,83 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_ +#define TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_ + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { + +class OpBuilder; +class DominanceInfo; + +namespace scf { +class ForOp; +} // namespace scf +namespace triton::nvidia_gpu { + +//===----------------------------------------------------------------------===// +// MMA Pipeline Analysis +//===----------------------------------------------------------------------===// + +// Given an MMAv5 operation in a loop, determine if its accumulator can be +// multibuffered. +bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp); + +// Returns true if the MMA operation requires acc multi-buffering when +// pipelined. +bool requiresAccMultiBuffering(MMAv5OpInterface mma, scf::ForOp forOp); + +// Returns true if there are loads from tmem after the MMA operation. +bool hasLoadsAfterMMA(MMAv5OpInterface mma, scf::ForOp forOp); + +// Helper class to determine if the operands of an MMA operation are +// pipelineable. +class MMAv5PipelineableOperandsHelper { +public: + MMAv5PipelineableOperandsHelper( + MMAv5OpInterface mmaOp, scf::ForOp forOp, + std::function isLoadToBePipelined) + : mmaOp(mmaOp), forOp(forOp), isLoadToBePipelined(isLoadToBePipelined) { + run(); + } + + bool isPipelineable = false; + // If true, the existing operand loads are all been found and their + // pipelineability has been determined. + bool isOperandsStateDetermined = false; + SmallVector unpipelineableOperandDefs; + +private: + MMAv5OpInterface mmaOp; + scf::ForOp forOp; + std::function isLoadToBePipelined; + void run(); + bool isOperandPipelineable(Value v, Operation *&foundDef); +}; + +bool areScalesPipelineable(TCGen5MMAScaledOp scaledOp, scf::ForOp forOp); +bool isOperandPipelineableBase( + Value v, scf::ForOp forOp, Operation *&foundDef, + std::function isPipelineable = + [](Operation *) { return false; }, + std::function isLoadToBePipelined = + [](Operation *) { return false; }); + +//===----------------------------------------------------------------------===// +// MMA Pipeline Rewriters +//===----------------------------------------------------------------------===// + +// Create a new TMEMAllocOp to use for the pipelined MMA operation. It is +// optionally multi-buffered based on the number of stages. +TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp, + bool multiBufferred, int numStages); + +// Return true if the accumulator of an mma in subsequent iterations is either +// independent from the previous iteration (overwritten) or completely reused, +// without read-modify-write. +// Otherwise, we can not pipeline the MMA, as we need to insert a wait after the +// mma to read back the accumulator for RMW. +bool hasAccReadModifyWrite(MMAv5OpInterface mma, scf::ForOp forOp); + +} // namespace triton::nvidia_gpu +} // namespace mlir + +#endif // TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Partition.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Partition.h new file mode 100644 index 0000000000..6c5b287f0c --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Partition.h @@ -0,0 +1,127 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_PARTITION_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_PARTITION_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +class Operation; +class OpOperand; +class OpResult; +class Region; +namespace scf { +class ForOp; +} // namespace scf +} // namespace mlir + +//===----------------------------------------------------------------------===// +// PartitionSet +//===----------------------------------------------------------------------===// + +namespace mlir::triton::gpu { +// A partition has a stage and contains some operation. The stage of a +// partition determines how many cycles the partition's outputs are buffered +// relative to its consumers. +class Partition { +public: + Partition(int idx, int stage) : idx(idx), stage(stage) { + assert(idx >= 0 && "A partition index must be nonnegative."); + } + + int getIndex() const { return idx; } + int getStage() const { return stage; } + ArrayRef getOps() const { return ops; } + void addOp(Operation *op) { ops.push_back(op); } + bool hasOp(Operation *op) const; + bool empty() const { return ops.empty(); } + + // Iterate the inputs of the partition. Input values are those that originate + // from a different partition or a previous iteration of the current + // partition. E.g. partition B(i) may have inputs from A(i) or B(i-1). Note + // that the same value may be visited more than once. + void iterateInputs(scf::ForOp loop, + function_ref callback) const; + // Iterate the outputs of the partition. Output values are those that are + // consumed by a different partition or a future iteration of the current + // partition. E.g. partition A(i) may have outputs to B(i) or A(i+1). Note + // that the same value may be visited more than once. + void + iterateOutputs(scf::ForOp loop, + function_ref callback) const; + // Iterate the defining ops of the inputs to the partition in the current and + // previous iterations, including the distance in the past. + void iterateDefs(scf::ForOp loop, + function_ref callback) const; + // Iterate the uses of all outputs of the partition in the current iteration + // and in future iterations, including the distance in the future. + void iterateUses( + scf::ForOp loop, + function_ref callback) const; + +private: + void setIndex(int idx) { this->idx = idx; } + + // The partition number. + int idx; + // The stage of the partition. + int stage; + // The ops in the partition. + SmallVector ops; +}; + +// A partition set divides a loop into multiple partitions. Ops in a loop are +// assigned at most one partition. A partition set represents asynchronous +// execution of the loop body, where partitions may execute simultaneously. +class PartitionSet { +public: + // Get WarpSpecialization tag + int getTag() const { return tag; } + + // Create a new partition with a stage. + Partition *addPartition(unsigned stage); + + // Get the partition at the index. + Partition *getPartition(unsigned idx); + // Get the partition at the index. + const Partition *getPartition(unsigned idx) const; + // Return an iterator range over the partitions. + auto getPartitions() { return llvm::make_pointee_range(partitions); } + // Return an iterator range over the partitions. + auto getPartitions() const { return llvm::make_pointee_range(partitions); } + // Get the number of partitions. + unsigned getNumPartitions() const { return partitions.size(); } + + // Deserialize a partition set from an `scf.for` op using the attributes + // tagged on operations in its body. + static FailureOr fromLoop(scf::ForOp loop); + + // Debug dump the partition set. + LLVM_DUMP_METHOD void dump() const; + + // Utility to be used when the op is known to belong to one partition + Partition *getPartition(Operation *op); + +private: + // WarpSpecialization tag + int tag; + // Partitions are numbered [0, N). + SmallVector> partitions; +}; + +// Annotate the op with the partition index or indices, and add the op +// to the partitions it belongs to. +void setPartition(Operation *op, Partition *partition); +void setPartition(Operation *op, const SetVector &partitions); +// Annotate the op with the partition indices. It should only be used in a pass +// which does not work with Partition instances and iterate* functions, since +// it does not keep the op attributes and the op list of a partition in sync. +void setPartition(Operation *op, const SetVector &partitionIds); +void setPartitionOutputs(Operation *op, + ArrayRef> partitionOutputsIds); +void setWarpSpecializeTag(Operation *op, int tag); + +} // namespace mlir::triton::gpu + +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_PARTITION_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h new file mode 100644 index 0000000000..baa16421c1 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h @@ -0,0 +1,49 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H +#define TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H + +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton::gpu { + +class Partition; + +using StageCluster = std::optional>; + +// Get the stage and cluster for an operation, if it has one assigned. +void setStageCluster(OpBuilder &b, Operation *op, StageCluster stageCluster); +StageCluster getStageCluster(Operation *op); + +struct PartitionBuilder : public ImplicitLocOpBuilder { + using ImplicitLocOpBuilder::ImplicitLocOpBuilder; + + Value intCst(int value, unsigned width = 32); + Value boolCst(bool value); + + void assignPartition(Operation *op, Partition &partition); + + template + auto createInto(Partition &partition, StageCluster stageCluster, + Args &&...args) { + auto op = create(std::forward(args)...); + assignPartition(op, partition); + setStageCluster(*this, op, stageCluster); + return op; + } +}; + +template +OpT createInto(OpBuilder &b, Location loc, + std::optional> partitionSet, + StageCluster stageCluster, Args &&...args) { + auto op = OpT::create(b, loc, std::forward(args)...); + if (partitionSet) { + setPartition(op, *partitionSet); + setStageCluster(b, op, stageCluster); + } + return op; +} + +} // namespace mlir::triton::gpu + +#endif // TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PartitionSchedulingUtility.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PartitionSchedulingUtility.h new file mode 100644 index 0000000000..830d7d1151 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PartitionSchedulingUtility.h @@ -0,0 +1,460 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_PARTITION_SCHEDULING_UTILITY_H_ +#define TRITON_TRITONGPU_TRANSFORMS_PARTITION_SCHEDULING_UTILITY_H_ + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir::triton::gpu::partition_scheduling_detail { + +namespace tt = triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +class Graph; +class Node; + +enum Flags : uint8_t { + NONE = 0, + MANUAL = 1 << 0, + LOAD = 1 << 1, + STORE = 1 << 2, + MMA = 1 << 3, + TMEM = 1 << 4, + SFU = 1 << 5, + VIEW = 1 << 6, +}; + +inline Flags &operator|=(Flags &lhs, Flags rhs) { + return lhs = static_cast(lhs | rhs); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, Flags flags); + +Flags getNodeFlags(Node *node); + +size_t computeCost(Operation *op); + +inline bool isViewOp(Operation *op) { + return isa(op) || + op->hasTrait(); +} + +class Partition { +public: + explicit Partition(Graph *graph) : graph(graph) {} + void add(Node *node); + void remove(Node *node) { nodes.remove(node); } + void addFlag(Flags flag) { flags |= flag; } + Flags getFlags() const { return flags; } + const SetVector &getNodes() const { return nodes; } + bool empty() const { return nodes.empty(); } + + size_t getStage() const { + if (flags & Flags::MMA) + return 1; + return 0; + } + size_t getCost() const { return cost; } + + static void merge(Partition *lhs, Partition *rhs); + + void dump() const; + + std::optional id; + +private: + Graph *graph; + Flags flags = Flags::NONE; + size_t cost = 0; + SetVector nodes; +}; + +class Port { +public: + Port() = default; + Port(Node *node, size_t idx) : node(node), idx(idx) {} + Node *getNode() const { return node; } + size_t getIdx() const { return idx; } + + bool operator==(const Port &other) const { + return node == other.node && idx == other.idx; + } + +private: + Node *node = nullptr; + size_t idx = 0; +}; + +} // namespace mlir::triton::gpu::partition_scheduling_detail + +namespace llvm { +template <> +struct DenseMapInfo { + static inline mlir::triton::gpu::partition_scheduling_detail::Port + getEmptyKey() { + return {}; + } + + static inline mlir::triton::gpu::partition_scheduling_detail::Port + getTombstoneKey() { + return mlir::triton::gpu::partition_scheduling_detail::Port(0, 1); + } + + static unsigned getHashValue( + const mlir::triton::gpu::partition_scheduling_detail::Port &port) { + return std::hash()( + port.getNode()) ^ + std::hash()(port.getIdx()); + } + + static bool + isEqual(const mlir::triton::gpu::partition_scheduling_detail::Port &lhs, + const mlir::triton::gpu::partition_scheduling_detail::Port &rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +namespace mlir::triton::gpu::partition_scheduling_detail { + +using InputPort = Port; +using OutputPort = Port; + +class Edge { +public: + Edge() = default; + Edge(OutputPort from, InputPort to) : from(from), to(to) {} + + OutputPort getFrom() const { return from; } + InputPort getTo() const { return to; } + + Node *getFromNode() const { return from.getNode(); } + size_t getFromIdx() const { return from.getIdx(); } + + Node *getToNode() const { return to.getNode(); } + size_t getToIdx() const { return to.getIdx(); } + + bool isDataValue() const; + bool crossesPartitions() const; + Type getType() const; + size_t getSize() const; + +private: + OutputPort from; + InputPort to; +}; + +class Node { +public: + explicit Node(Operation *op) : op(op), cost(computeCost(op)) {} + + Node(Node *parent, Operation *op, size_t numInputs, size_t numOutputs) + : parent(parent), op(op), cost(computeCost(op)) { + inputs.resize(numInputs); + outputs.resize(numOutputs); + dataOutputs.resize(numOutputs); + } + + Node(Node *parent, Value value, size_t numInputs, size_t numOutputs) + : parent(parent), value(value) { + inputs.resize(numInputs); + outputs.resize(numOutputs); + dataOutputs.resize(numOutputs); + } + + Node *addNode(Operation *op, size_t inputs, size_t outputs) { + return nodes.emplace_back(new Node(this, op, inputs, outputs)).get(); + } + + Node *addNode(Value value, size_t inputs, size_t outputs) { + return nodes.emplace_back(new Node(this, value, inputs, outputs)).get(); + } + + void walk(const std::function &fn) { + std::function do_walk = [&](Node *node) { + for (auto &child : node->getNodes()) { + fn(child.get()); + do_walk(child.get()); + } + }; + do_walk(this); + } + + static void addEdge(OutputPort from, InputPort to) { + from.getNode()->addOutputEdge(from.getIdx(), to); + to.getNode()->addInputEdge(to.getIdx(), from); + } + + static void removeEdge(Edge edge) { + edge.getFromNode()->removeOutputEdge(edge.getFromIdx(), edge.getTo()); + edge.getToNode()->removeInputEdge(edge.getToIdx(), edge.getFrom()); + } + + void addDefines(Node *node) { defines.push_back(node); } + + void addInputEdge(size_t idx, OutputPort port) { + assert(idx < inputs.size()); + inputs[idx] = port; + } + + void removeInputEdge(size_t idx, OutputPort port) { + assert(idx < inputs.size()); + inputs[idx] = {}; + } + + void addOutputEdge(size_t idx, InputPort port) { + assert(idx < outputs.size()); + outputs[idx].push_back(port); + } + + void removeOutputEdge(size_t idx, InputPort port) { + assert(idx < outputs.size()); + for (auto it = outputs[idx].begin(); it != outputs[idx].end(); it++) { + if (*it == port) { + outputs[idx].erase(it); + break; + } + } + } + + Node *getParent() const { return parent; } + bool isOp() const { return op; } + bool isValue() const { return !op; } + Operation *getOp() { return op; } + Value &getValue() { + assert(isValue()); + return value; + } + const SmallVector &getDefines() const { return defines; } + + const SmallVector> &getNodes() const { return nodes; } + + size_t getNumInputs() const { return inputs.size(); } + size_t getNumOutputs() const { return outputs.size(); } + + const SmallVector &getInputs() const { return inputs; } + const SmallVector> &getOutputs() const { + return outputs; + } + SmallVector getOutputsFromPort(size_t idx) const { + return outputs[idx]; + } + + SmallVector getInEdges() { + SmallVector result; + size_t idx = 0; + for (auto input : inputs) { + result.push_back(Edge(input, InputPort(this, idx))); + idx++; + } + return result; + } + + size_t getNumInDataEdges() { + size_t count = 0; + size_t idx = 0; + for (auto input : inputs) { + Edge edge(input, InputPort(this, idx)); + if (edge.isDataValue()) + count++; + idx++; + } + return count; + } + + SmallVector getOutEdges() { + SmallVector result; + size_t idx = 0; + for (auto outputs : this->outputs) { + for (auto output : outputs) + result.push_back(Edge(OutputPort(this, idx), output)); + idx++; + } + return result; + } + + size_t getNumOutDataEdges() { + size_t count = 0; + size_t idx = 0; + for (auto output : dataOutputs) { + if (output) + count += outputs[idx].size(); + idx++; + } + return count; + } + + void setDataValue(size_t idx) { + assert(idx < dataOutputs.size()); + dataOutputs[idx] = true; + } + + bool isDataValue(size_t idx) { + assert(idx < dataOutputs.size()); + return dataOutputs[idx]; + } + + bool isData() { + // node is data if it consumes/produces a data value + if (std::any_of(dataOutputs.begin(), dataOutputs.end(), + [](bool x) { return x; })) { + return true; + } + for (auto input : inputs) + if (input.getNode() && input.getNode()->isDataValue(input.getIdx())) + return true; + return false; + } + + bool containsData() { + // node contains data if a data op appears in its region + for (auto &node : getNodes()) { + if (node->isData()) + return true; + if (node->containsData()) + return true; + } + return false; + } + + bool inLoopBody() { + if (op) + return op->getParentOfType(); + if (auto blockArg = dyn_cast(value)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + return isa(parentOp) || + parentOp->getParentOfType(); + } + auto result = cast(value); + auto op = result.getOwner(); + return isa(op) || op->getParentOfType(); + } + + bool containsLoopBody() { + for (auto &node : getNodes()) { + if (node->inLoopBody()) + return true; + if (node->containsLoopBody()) + return true; + } + return false; + } + + std::string getLabel() { + if (op) + return op->getName().getStringRef().str(); + if (auto blockArg = dyn_cast(value)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) + return "arg " + std::to_string(blockArg.getArgNumber()); + if (isa(parentOp)) { + if (blockArg.getArgNumber() == 0) + return "ind var"; + return "iter arg " + std::to_string(blockArg.getArgNumber() - 1); + } + return "?"; + } + auto result = cast(value); + return "result " + std::to_string(result.getResultNumber()); + } + + void setPartition(Partition *partition) { + for (auto current_partition : partitions) + current_partition->remove(this); + partitions.clear(); + partitions.insert(partition); + partition->add(this); + } + + void addPartition(Partition *partition) { + partitions.insert(partition); + partition->add(this); + } + + void addPartitions(const SetVector &partitions) { + this->partitions.insert(partitions.begin(), partitions.end()); + for (auto partition : partitions) + partition->add(this); + } + + bool hasPartition() const { return !partitions.empty(); } + + Partition *getPartition() const { + assert(partitions.size() == 1); + return *(partitions.begin()); + } + + const SetVector &getPartitions() const { return partitions; } + + bool hasCost() const { return cost > 0; } + size_t getCost() const { + assert(hasCost()); + return cost; + } + + void dump() { llvm::errs() << "node '" << getLabel() << "'\n"; } + +private: + Node *parent = nullptr; + Operation *op = nullptr; + Value value; + size_t cost = 0; + + SmallVector> nodes; + SmallVector defines; + + SmallVector inputs; + SmallVector> outputs; + SmallVector dataOutputs; + + SetVector partitions; +}; + +class Graph { +public: + explicit Graph(Operation *op) : root(new Node(op)) {} + + Node *getRoot() { return root.get(); } + + Partition *addPartition() { + auto partition = partition_storage.emplace_back(new Partition(this)).get(); + partitions.insert(partition); + return partition; + } + + void erasePartition(Partition *partition) { + assert(partition->empty()); + partitions.remove(partition); + } + + const SetVector &getPartitions() const { return partitions; } + + void walk(const std::function &fn) { + std::function do_walk = [&](Node *node) { + for (auto &child : node->getNodes()) { + fn(child.get()); + do_walk(child.get()); + } + }; + do_walk(root.get()); + } + +private: + std::unique_ptr root; + SetVector partitions; + SmallVector> partition_storage; +}; + +struct VisualizationInfo { + DenseMap partition_ids; + DenseMap partition_colors; +}; + +void visualize(std::string key, std::string filename, std::string title, + Graph *graph, VisualizationInfo &info); + +} // namespace mlir::triton::gpu::partition_scheduling_detail + +#endif // TRITON_TRITONGPU_TRANSFORMS_PARTITION_SCHEDULING_UTILITY_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.h new file mode 100644 index 0000000000..242a23b7ba --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +} // namespace gpu +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.td new file mode 100644 index 0000000000..6cb5c963f4 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -0,0 +1,343 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { + let summary = "pipeline"; + + let description = [{ + Applies software pipelining to loops in the module based on number of stages. + This may convert some load into asynchronous loads, and multi-buffer the data. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages">, + Option<"dumpIntermediateSteps", "dump-intermediate-steps", + "bool", /*default*/"false", + "Dump intermediate steps"> + ]; +} + +def TritonGPUAssignLatencies : Pass<"tritongpu-assign-latencies", "mlir::ModuleOp"> { + let summary = "assign latencies to interesting ops ahead of pipelining"; + + let description = [{ + The `tritongpu-assign-latencies` pass assigns latencies to latency ops based + on the number of stages. + }]; + + let options = [ + Option<"numStages", "num-stages", "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUScheduleLoops : Pass<"tritongpu-schedule-loops", "mlir::ModuleOp"> { + let summary = "software pipeline loop scheduling"; + + let description = [{ + The `tritongpu-schedule-loops` pass performs scheduling for loop pipelining + for loops with latency ops. + }]; +} + +def TritonGPUHoistTMEMAlloc : Pass<"tritongpu-hoist-tmem-alloc", "mlir::ModuleOp"> { + let summary = "Hoist TMEM allocations out of the loop. This is a preparation for the loop lowering."; + + let description = [{ + Hoist TMEM allocations out of the loop. Keep the values in the TMEM as much as possible. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + let options = [ + Option<"hoistOutOfIf", "hoist-out-of-if", + "bool", /*default*/"false", + "Hoist TMEM allocations out of if statements"> + ]; +} + +def TritonGPUTestPipelineLowerLoop : Pass<"tritongpu-test-pipeline-lower-loop", "mlir::ModuleOp"> { + let summary = "test lowering a loop for software pipelining"; + + let description = [{ + This is a test pass that tests `lowerLoop` method of `TritonGPUPipeline`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUFuseNestedLoops : Pass<"tritongpu-fuse-nested-loops", "mlir::ModuleOp"> { + let summary = "fuse nested loops for pipelining"; + + let description = [{ + The `tritongpu-fuse-nested-loops` pass will analyze loop nests in the module + that need to be pipelined and fuse them into a single loop. This composes + with the pipeliner to pipeline loop nests. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::arith::ArithDialect", + "mlir::ub::UBDialect", + ]; +} + +def TritonGPUAutomaticWarpSpecialization : Pass<"tritongpu-automatic-warp-specialization", "mlir::ModuleOp"> { + let summary = "automatic warp specialization of loops"; + + let description = [{ + The `tritongpu-automatic-warp-specialization` pass applies automatic + warp specialization to eligible loops in the module. The pass will analyze + the loops in the kernel and attempt to create a partition schedule, which + if successful lowers the loop by duplicating it into `ttg.warp_specialize` + partition regions. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "triton::nvws::NVWSDialect" + ]; + + let options = [ + Option<"numStages", "num-stages", "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUPartitionLoops : Pass<"tritongpu-partition-loops", "mlir::ModuleOp"> { + let summary = "split scheduled loops into `ttg.warp_specialize`"; + + let description = [{ + The `tritongpu-partition-loops` pass will analyze the loops in the module + that have been scheduled for warp specialization and split them into + `ttg.warp_specialize` partition regions. This requires no SSA dependencies + between any of the partitions. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "triton::nvws::NVWSDialect" + ]; +} + +def TritonGPUOptimizePartitionWarps : Pass<"tritongpu-optimize-partition-warps", "mlir::ModuleOp"> { + let summary = "optimize the number of warps assigned to partitions"; + + let description = [{ + The `tritongpu-optimize-partition-warps` pass will analyze the partitions + of `ttg.warp_specialize` ops and attempts to reduce the number of warps + assigned to them and optimize the register usage of the partitions. + }]; +} + +def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir::ModuleOp"> { + let summary = "warp specialization partitioning pass"; + + let description = [{ + The `tritongpu-partition-scheduling` analyzes the loads, MMAs, and other + operations in a loop that is meant to be warp specialized and determines + which partitions to assign to each operation. + }]; +} + +def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { + let summary = "Emulate dot-product tensor core precision using TF32s or BF16s"; + + let description = [{ + Generic pass to emulate/decompose f32 `DotOp` instructions. + * Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s + to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385. + * Decompose fp32 `DotOp` instructions into BF16 operations. + See https://arxiv.org/abs/1904.06376 + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; + let options = [ + Option<"emuTF32", "emu-tf32", + "bool", /*default*/"false", + "whether to handle InputPrecision TF32xN for Nvidia GPUs"> + ]; +} + +def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { + let summary = "prefetch"; + + let description = [{ + This pass attempts to prefetch from shared memory the operands (A and B) + of a `tt.dot`, when this operation is located in a loop. + Decompose `DotOp` instructions in loops into several finer-grained `DotOp` + that may have their operands constructed at the end of the previous + iteration. + Transformations are performed in five different places: + 1. The pass emits a prologue to the loop where the data for the first + loop iteration are prefetched. + 2. The loop arguments are extended with the new prefetched values. + 3. The dotOp parameters is updated with the new args. + 4. The prefetch operations for the next iteration are added to the loop. + 5. The yieldOp is updated by adding the prefetched values for the next + iteration. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., Nvidia tensor cores) + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { + let summary = "fuse transpositions"; + + let description = [{ + Re-arranged layouts of tensors used as matrix multiplication operands so as to promote the use of + hardware-accelerated transpositions. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"hoistLayoutConversion", "hoist-layout-conversion", + "bool", /*default*/"true", + "whether to move conver to dot operand earlier pass elementwise ops"> + ]; +} + +def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { + let summary = "coalesce"; + + let description = [{ + The pass analyses loads/stores with type `tensor>` or + `tt.ptr>` and replaces the layouts of these operations with + coalesced layouts, i.e. cache friendly access patterns. + Layout conversions are inserted before and after the load/store op + to maintain consistency with the rest of the program. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + + +def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> { + let summary = "remove superfluous layout conversions"; + + let description = [{ + The purpose of this pass is to rewrite the `ConvertLayoutOps` to reduce + the number of operations and to prefer favorable layouts like + `BlockedEncodingAttr` layout for "expensive" loads and stores + (good for coalescing) and `NvidiaMmaEncodingAttr` otherwise + (good for tensor ops). + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + +} + +def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", "mlir::ModuleOp"> { + let summary = "Reduce the cost of synchronization between threads in an SM"; + + let description = [{ + The aim of this pass is to reduce cross-thread communication for certain + operations, like reductions, reshapes, and gathers. + + For reduction operations, this pass attempts to adjust the reduction size + (or layout) to avoid splitting the reduction operation between multiple + threads. Currently, this pass only optimizes reduction yielded by loop to be + thread-local until after the loop completes. + + For gathers, this pass will attempt to pick an optimized layout for gather + operations in the module. This is determined based on the shapes of the + gather operands as well as their existing layouts. The pass applies + heuristics to determine when it is appropriate to assign specific layouts + and trigger their respective codegen paths. For now, the pass only attempts + to apply layouts that result in warp-synchronous gathers. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { + let summary = "Reorder instructions"; + + let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving " + "conversions from shared memory before their first use) and (2) promote LLVM instruction " + "order more friendly to `ptxas`."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReduceDataDuplication: Pass<"tritongpu-reduce-data-duplication", "mlir::ModuleOp"> { + let summary = "Reduce data duplication in register by decomposing convert[distributed -> dotOperand] " + "into convert[distributed -> shared -> dotOperand]"; + + let description = "Decomposing conversions this way makes it possible to use CSE and reuse #shared tensors"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and-if", "mlir::ModuleOp"> { + let summary = "Combine tensor select and if"; + + let description = "For select instruction that uses the same condition as the if instruction in the same block " + "this pass combines the select into the if instruction, making the select operands returned by the " + "then/else yields."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> { + let summary = "Replace accumulator zero-initialization with the flag indicating first use of the accumulator"; + + let description = "For the dot operations that support accumulator-use flag this pass replaces the zero-initialization " + "of the accumulator with the flag indicating the first use of the accumulator."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> { + let summary = "Improve coalescing for async global to local copies"; + + let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than " + "the blocked encoding's sizePerThread, this pass improves coalescing by clipping the " + "sizePerThread value"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h new file mode 100644 index 0000000000..4851bfe001 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h @@ -0,0 +1,111 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operations in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the prologue/epilogue. + bool supportDynamicLoops = false; + + /// If set, use this function to emit the predicate stage ops instead of the + /// default one. + using EmitPredicateStageFnType = std::function; + EmitPredicateStageFnType emitPredicateStageFn = nullptr; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +Value emitPredicateForStage(RewriterBase &rewriter, Value inductionVar, + Value upperBound, Value step, uint64_t maxStage, + uint64_t stage); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h new file mode 100644 index 0000000000..5700a366fc --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -0,0 +1,189 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ +#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include +#include + +namespace mlir { +class DominanceInfo; +class ImplicitLocOpBuilder; +namespace triton { + +static const char *kNumStagesAttrName = "tt.num_stages"; +static const char *kDisallowAccMultiBufferAttrName = + "tt.disallow_acc_multi_buffer"; +static const char *kWarpSpecializeAttrName = "tt.warp_specialize"; +static const char *kLoopStageAttrName = "loop.stage"; +static const char *kLoopClusterAttrName = "loop.cluster"; +static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage"; +class CoarseSchedule; +class ModuleAxisInfoAnalysis; +//===----------------------------------------------------------------------===// +// Hoisting Utilities +//===----------------------------------------------------------------------===// + +// By default, an operation can be hoisted if it is pure scalar operation. +bool isPureScalarOp(Operation *op); + +// Given a set of values and a reference operation, return true if all of the +// values dominate the reference operation OR a set of "trivial" operations can +// be moved before the reference operation such that the value set dominates the +// reference operation. +// +// Returns false if it is not possible to make the values dominate the reference +// operation. The function determines "trivial"-ness with the given callback. +// By default, it determines that memory-effect-free and scalar operations are +// trivial. +bool getDominatingValueSetOpsToHoist( + DominanceInfo &domInfo, Operation *refOp, ArrayRef valueSet, + llvm::SetVector &toHoist, + function_ref canHoist = isPureScalarOp, + function_ref canUseArg = [](BlockArgument) { + return false; + }); + +// Hoist the given set of operations above the reference operation. +void hoistOpsBefore(Operation *refOp, + const llvm::SetVector &toHoist); +// Hoist the given set of operations before the iterator. +void hoistOpsBefore(Block *block, Block::iterator it, + const llvm::SetVector &toHoist); + +//===----------------------------------------------------------------------===// +// Sinking Utilities +//===----------------------------------------------------------------------===// + +// Sink a value redefinition into a block, provided that the block is dominated +// by `in` and postdominated by `out`. +Value sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out, + Block *block); + +//===----------------------------------------------------------------------===// +// Loop Pipelining Utilities +//===----------------------------------------------------------------------===// + +bool loopHasDistGreaterThanOne(scf::ForOp forOp); +bool isOuterLoop(scf::ForOp forOp); + +/// Function to mask operations during scheduling. +Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); + +/// Wrap the operation into a MaskOp using the provided predicate, enabling high +/// level predication abstraction during pipelining. +Operation *wrapInMaskOp(RewriterBase &rewriter, Operation *op, Value pred); + +// Utilize high level predication abstraction to perform optimizations before +// lowering to predicated operations +void resolveMaskOp(ModuleOp moduleOp); + +// Return true if the given ForOp has the attribute +// `tt.disallow_acc_multi_buffer` set to true. +bool getDisallowAccMultiBuffer(scf::ForOp forOp); + +// Return the definition of the given value. If the value is a loop-carried +// dependency, return the definition and the distance to it. +std::pair getDefinitionAndDistance(scf::ForOp forOp, + Value value); +// Return the defining op of the given value, if the Value is an argument of the +// loop return the associated defining op in the loop and its distance to the +// Value. +std::pair getDefiningOpAndDistance(scf::ForOp forOp, + Value value); + +// Return maximum length of the vectorized copy between registers and shared +// memory for the given tensor type and shared encoding. +int getCopyVecBytes(RankedTensorType registerTy, + gpu::SharedEncodingTrait sharedEnc); + +bool canBeConvertedToAsyncLoad( + triton::LoadOp loadOp, triton::ModuleAxisInfoAnalysis &axisInfoAnalysis); + +// Serialize the latencies of the operations in the loops into the latency +// attribute. +void serializeLatencies(ModuleOp module, DenseMap &opLatency); + +// Serialize the self latencies of the operations in the loops into the +// self_latency attribute. +void serializeSelfLatencies(ModuleOp module, + DenseMap &opSelfLatency); + +// Deserialize the latencies of the operations in the loops from the attribute. +DenseMap deserializeLatencies(Operation *op); + +// Create an allocation for multibuffered scalars. +Value createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type, + unsigned numBuffers); +// Create an allocation and init the mbarriers. +Value createBarrierAlloc(Operation *op, int numBarriers, int arriveCount = 1); +// Create an allocation that can hold distance number of tensor shapes. +Value createAlloc(Operation *insertBefore, RankedTensorType ty, Location loc, + gpu::SharedEncodingTrait sharedEnc, unsigned distance); + +// Determine if the operation is a TMA load. +bool isTMALoad(Operation *op); + +// Determine if the operation can be lowered to an async load. +bool canBeAsyncLoad(Operation *op); + +// Look for consecutive wait ops and combine them into a single wait op. +void combineRedundantWaitOps( + llvm::SmallSetVector &waitOps); + +// Get the type of the view of a multi-buffered tensor value. +gpu::MemDescType getBufferViewType(gpu::MemDescType allocTy, + bool mutableMemory = true); + +// Get a mutable, multi-buffered version of the given memdesc type, with +// multiplicity "depth". +gpu::MemDescType getMultiBufferedType(gpu::MemDescType memDescType, + int32_t depth); + +// Get a generic shared encoding for a tensor. +gpu::SharedEncodingTrait getSharedEncoding(RankedTensorType ty); +// Get a shared encoding for a tensor based on its uses. +gpu::SharedEncodingTrait getSharedEncoding(Operation *loadOp); + +// Get the number of stages to pipeline the loop with, if it is explicitly +// specified. +int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages); + +// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a +// single buffer slice (leading dimension equal to 1), at the given index. +TypedValue +createSingleBufferView(OpBuilder &builder, Value alloc, Value idx); +// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a +// single buffer slice (leading dimension equal to 1), at the given index. +TypedValue +createSingleBufferView(OpBuilder &builder, Value alloc, int idx); + +Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter, + Value modulus, Value zero, Value one, + Value *outWrapCond = nullptr); + +scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule); + +DenseSet +getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp, + std::function filter = nullptr); + +// Return the "first" op in terms of the stage and cluser ordering +Operation * +getFirstUseOfPipelinedOp(ArrayRef ops, scf::ForOp forOp, + CoarseSchedule &schedule, + std::function filterUse = nullptr); + +// Return the "last" op in terms of the stage and cluser ordering +Operation * +getLastUseOfPipelinedOp(ArrayRef ops, scf::ForOp forOp, + CoarseSchedule &schedule, + std::function filterUse = nullptr); + +// Clean up attributes passing over schedules across stages in pipelining +void removePipeliningAttributes(ModuleOp moduleOp); +} // namespace triton +} // namespace mlir + +#endif // TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Schedule.h new file mode 100644 index 0000000000..1c277257eb --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -0,0 +1,285 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include + +namespace mlir { +namespace triton { + +namespace gpu { + +/// Lower the loops to prepare them for pipeline expansion. +void lowerLoops(ModuleOp moduleOp); + +bool hasGpuBarriers(scf::ForOp forOp); +bool isSafeToPipeline(scf::ForOp forOp); +llvm::MapVector> +loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis, + int numStages, bool filterSmall = true); + +}; // namespace gpu + +/// Pipeline the TMA stores in the loop. +bool pipelineTMAStores(scf::ForOp forOp); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); + +/// Post process the pipelined loop by updating the wait ops with the right +/// number of groups in flight. +void updateWaits(ModuleOp module); + +class CoarseSchedule { +public: + class ClusterList { + std::list orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + using const_iterator = decltype(orderClusters)::const_iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + const_iterator begin() const { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + const_iterator end() const { return orderClusters.end(); } + size_t size() const { return orderClusters.size(); } + void clear() { orderClusters.clear(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + + bool isBefore(iterator a, iterator b) const { + if (a == b) + return false; + for (auto it = begin(); it != end(); ++it) { + if (it == a) + return true; + if (it == b) + return false; + } + llvm::report_fatal_error( + "One or both clusters not found in clusters list!"); + } + }; + + CoarseSchedule() = default; + CoarseSchedule(int numStages) : numStages(numStages) {} + ClusterList clusters; + using Cluster = ClusterList::iterator; + using ClusterHash = size_t; + + llvm::MapVector> opToStageAndCluster; + + void setNumStages(int numStages) { this->numStages = numStages; } + int getNumStages() const { return numStages; } + + void insert(Operation *op, int stage, Cluster cluster) { + if (stage >= numStages) { + numStages = stage + 1; + } + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + bool insertMinimum(Operation *op, int stage, Cluster cluster); + + bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg, bool insertIfEarlier = false); + + // Remove empty stages and clusters from the schedule, adjusting the maximum + // number of stages as appropriate. + void shrinkToFit(); + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) const { return opToStageAndCluster.count(op); } + + std::pair operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + auto find(Operation *op) const { return opToStageAndCluster.find(op); } + + // Split the cluster containing op into two clusters, one containing all + // operations before the op and one containing op and all operations after the + // op. Return the cluster containing op and all operations after the op. + Cluster splitClusterBefore(Operation *op, scf::ForOp forOp); + + // Check if op a will show up before op b in the final unrolled code. + bool isOpBefore(Operation *a, Operation *b) const; + + // Check if op a is in earlier cluster than op b. + bool isOpInEarlierCluster(Operation *a, Operation *b) const; + + // Check if op a is in the same cluster as op b. + bool isOpInSameCluster(Operation *a, Operation *b) const; + + SmallVector> + getOpsInOrder(scf::ForOp forOp) const; + std::vector> + createFinalSchedule(scf::ForOp forOp) const; + + bool empty() const { return opToStageAndCluster.size() == 0; } + auto end() const { return opToStageAndCluster.end(); } + auto begin() const { return opToStageAndCluster.begin(); } + + // Set based on CoarseSchedule. + void serialize(scf::ForOp &forOp) const; + // Create a CoarseSchedule based on forOp's . + // If normalizeClusterId is true, clusters [minClusterId, maxClusterId] will + // be remapped to [0, maxClusterId - minClusterId]. + // If false, it won't remap and clusters [0, maxClusterId] will be created. + LogicalResult deSerialize(scf::ForOp &forOp, bool normalizeClusterId = true); + + static ClusterHash hashCluster(Cluster cluster) { + return reinterpret_cast(&*cluster); + } + + LLVM_DUMP_METHOD void dump(); + + // ============================================================ + // Linearized Schedule Iterator API + // ============================================================ + + /// A stateful iterator over operations in linearized schedule order. + /// Operations are yielded lazily in order: (stage, cluster, + /// IR-order-within-cluster). + /// + /// The iterator is circular and stage-aware: it starts from initialOp at its + /// stage, traverses to the end of clusters, wraps around to the beginning, + /// and when it reaches initialOp again, increments the stage limit. An op is + /// only yielded if its stage <= currStageLimit. The iterator stops when it + /// reaches initialOp and currStageLimit >= numStages. + class LinearizedIterator { + public: + /// Construct an iterator for the given forOp and schedule. + /// The iterator starts at initialOp and wraps around circularly with + /// stage-based filtering. + LinearizedIterator(scf::ForOp forOp, const CoarseSchedule &schedule, + Operation *initialOp); + + // Standard iterator operations + LinearizedIterator &operator++(); + LinearizedIterator operator++(int); + Operation *operator*() const; + bool operator==(const LinearizedIterator &other) const; + bool operator!=(const LinearizedIterator &other) const; + + bool isEnd() const { return atEnd; } + + /// Advance the iterator to the next operation that satisfies the optional + /// predicate. Returns the found operation, or std::nullopt if not found. + /// The iterator position is updated to the found operation (or end). + std::optional + findNext(std::function predicate = nullptr) { + while (!isEnd()) { + Operation *op = *(*this); + ++(*this); + if (!predicate || predicate(op)) { + return op; + } + } + return std::nullopt; + } + + private: + /// Advance to the next valid operation in the schedule. + void advanceToNextScheduledOp(); + + scf::ForOp forOp; + const CoarseSchedule *schedule; + ClusterList::const_iterator clusterIt; + ClusterList::const_iterator clusterBegin; + ClusterList::const_iterator clusterEnd; + Block::iterator opIt; + Block::iterator opEnd; + Operation *currentOp = nullptr; + Operation *initialOp = nullptr; + int currStageLimit = 0; + int maxStages = 0; + bool atEnd = false; + }; + + /// Get a circular iterator over the linearized schedule starting from + /// initialOp. The iterator will traverse from initialOp to the end, wrap + /// around to the beginning, and stop when it reaches initialOp again. + LinearizedIterator linearized(scf::ForOp forOp, Operation *initialOp) const { + return LinearizedIterator(forOp, *this, initialOp); + } + +private: + int numStages = 0; +}; + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule); + +class OpBuilderForStage : public mlir::ImplicitLocOpBuilder, + public OpBuilder::Listener { +public: + explicit OpBuilderForStage(Location loc, Operation *op, + CoarseSchedule &schedule) + : ImplicitLocOpBuilder(loc, op, this), schedule(schedule) { + if (auto it = schedule.find(op); it != schedule.end()) + std::tie(stage, cluster) = it->second; + } + + void setStageCluster(std::pair stageCluster) { + stage = stageCluster.first; + cluster = stageCluster.second; + } + + void notifyOperationInserted(Operation *op, InsertPoint previous) { + if (stage && cluster) + schedule.insert(op, *stage, *cluster); + } + +private: + std::optional stage; + std::optional cluster; + CoarseSchedule &schedule; +}; + +namespace gpu { +void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule); +void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue); +} // namespace gpu + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h new file mode 100644 index 0000000000..1d51f170c3 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonGPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonGPUTypeConverter : public TypeConverter { +public: + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs, bool enableSourceRemat); + int getNumWarps() const { return numWarps; } + int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } + +private: + MLIRContext *context; + int numWarps; + int threadsPerWarp; + int numCTAs; +}; + +class TritonGPUConversionTarget : public ConversionTarget { +public: + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); + + // Determine whether the operation is currently legal. I.e. it has layouts + // assigned to its tensor operands and results. + static bool isDynamicallyLegal(Operation *op, + const TypeConverter &typeConverter); +}; + +namespace impl { +LogicalResult convertGatherScatterOp(Operation *op, ValueRange operands, + OpOperand &xOffsetsMutable, + const TypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); +} // namespace impl + +// Generic pattern for converting a TMA gather or scatter operation. +template +struct GatherScatterOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpT op, typename OpT::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return impl::convertGatherScatterOp(op, adaptor.getOperands(), + op.getXOffsetsMutable(), + *this->getTypeConverter(), rewriter); + } +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Utility.h new file mode 100644 index 0000000000..8085febca3 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -0,0 +1,303 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { +class DominanceInfo; +class PostDominanceInfo; + +namespace triton { +class ModuleAxisInfoAnalysis; +class LoadOp; +class StoreOp; +class FuncOp; +namespace gpu { +class SwizzledSharedEncodingAttr; +} +} // namespace triton + +// Return a tuple of two or three entries representing the shape of the +// instruction used to perform a matrix multiplication operation. +// Version = 1: +// Version = 2: <1, m, n> +// Version = 3: +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + Type type, int numWarps); + +// Return true if the Load uses block pointer. +bool isLoadFromTensorPtr(triton::LoadOp op); + +// Gets the order of a tensor from its contiguity. Places the dimensions with +// the largest contiguity as the inner most dimension. If the contiguity is +// all ones, returns the order {dim - 1, dim - 2, ..., 0} +SmallVector +getOrderFromContiguity(const SmallVector &contiguity); + +// Return the operand used to access the memory in the operation +Value getMemAccessPtr(Operation *op); + +// Return bitwidth of tensor element +unsigned getElementBitWidth(RankedTensorType type); + +// Calculate the optimal number of elements per thread for a given operation +// along an axis with greatest continuity. +unsigned +getNumElementsPerThread(Operation *op, SmallVector order, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis, + ArrayRef shape); + +// Returns whether the op is a "view op", i.e. doesn't move any data +bool isView(Operation *op); + +// Returns whether the op is a "noop op", i.e. has one input and one output +// and lowers to llvm as the identity function (returns the input) +bool isNoop(Operation *op); + +/* Dump Triton IR in graphviz dot format. + * + * You can override `onValue` and `onOperation` in a subclass to mark + * specific Values and Operations. The below subclass + * GraphLayoutMarker is an example. + * + * Default NodeInfo for Value nodes: + * {{"shape": "box"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", shapeStr}} + * + * Default NodeInfo for Operation nodes: + * {{"shape": "ellipse"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", operationName}} + * + * If the key "label" is not set by `onValue` or `onOperation`, default labels + * will be generated. For Value node, the default label is the shape string and + * for Operation node, it is the operation name. + * + * Reference: + * https://graphviz.org/doc/info/shapes.html + * https://graphviz.org/doc/info/colors.html + * + * Usage: + * C++: GraphDumper().dumpToFile(func, "func.dot"); + * Shell: dot -Tjpg func.dot -o func.jpg + */ +class GraphDumper { +public: + using NodeInfo = std::map; + + // Override this function to mark specific Values + virtual NodeInfo onValue(Value value) const; + // Override this function to mark specific Operations + virtual NodeInfo onOperation(Operation *op) const; + + std::string dump(triton::FuncOp func) const; + void dumpToFile(triton::FuncOp func, const std::string &filename) const; + +protected: + std::string getShapeStr(const Type &type) const; + + std::string getUniqueId(Value value) const; + std::string getUniqueId(Operation *op) const; + + std::string emitNode(const std::string &id, const NodeInfo style) const; + std::string emitEdge(const std::string &srcId, + const std::string &destId) const; + + std::string emitValueNode(Value value) const; + std::string emitOperationNode(Operation *op) const; +}; + +/* A subclass of GraphDumper that marks different layout kinds in different + * colors.*/ +class GraphLayoutMarker : public GraphDumper { +public: + NodeInfo onValue(Value value) const override; + +protected: + std::string getColor(const Type &type) const; +}; + +// Infers the encoding of the result of op given the source encoding. +Attribute inferDstEncoding(Operation *op, Attribute encoding); + +// Infers the encoding of the source of op given the result encoding. +Attribute inferSrcEncoding(Operation *op, Attribute encoding); + +bool isExpensiveLoadOrStore(Operation *op); + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); + +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::ForOp replaceForOpWithNewSignature( + OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements); +scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands); +[[nodiscard]] scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands); + +// Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::WhileOp replaceWhileOpWithNewSignature( + OpBuilder &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements); +scf::WhileOp replaceWhileOpWithNewSignature(OpBuilder &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes); + +// Replace IfOp with a new IfOp with extra results operands. The YieldOp is not +// updated and needs to be updated separately for the bodies to be correct. +scf::IfOp replaceIfOpWithNewSignature( + OpBuilder &rewriter, scf::IfOp loop, TypeRange newResultTypes, + SmallVectorImpl> &replacements); +scf::IfOp replaceIfOpWithNewSignature(OpBuilder &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes); + +// Append the given |newOperands| to the |forOp|'s yield op. +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands); + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping); + +/// For a given \p root value with desired layout \p rootEncoding, get the +/// backward slice of values that would have to be recreated to produce the +/// value of \p root with that layout (without an intervening layout +/// conversion). The traversal stops once we reach an operand that meets one of +/// the following: +/// 1. has the desired layout +/// 2. \p getExistingConversion returns an existing converted value +/// 3. \p stopPropagation returns true for an op. +/// The slice is returned in \p slice, and the desired layout of each value in +/// the slice is stored in \p layouts. +LogicalResult getConvertBackwardSlice( + OpOperand &root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr, + std::function getExistingConversion = + nullptr, + unsigned maxSliceSize = 0); + +// Populate pattern to remove dead cycles in ForOp. +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(OpBuilder &b, Location loc, unsigned linear, + ArrayRef shape); + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape); +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape); + +// Return true if the op is a pure elementwise_inline_asm op with a single +// operand and single result. +bool isPureUnaryInlineAsm(Operation *op); + +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + +// Read the amd target from the module attributes +std::optional getAMDArch(Operation *module); + +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); + +// Convert \param op to use \param encoding attribute. +// Skips operands if they're in shared encoding. +Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op); + +// Returns the original memory allocation for a memdesc value +triton::gpu::LocalAllocOp findShmemAlloc(Value operand); + +// Returns MMAs inside a for loop that are multi-buffered for pipeline analysis +SmallVector +getMMAsWithMultiBufferredOperands(scf::ForOp forOp, + SmallVector &mmaOps); + +// Given a list of ops, find the naerest common dominator of all ops or return +// null if one could not be found. The ops are allowed to be in different +// regions. The result op is not necessarily one of the ops in the list. +Operation *findNearestCommonDominator(ArrayRef ops, + DominanceInfo &domInfo); +// Given a list of ops, find the naerest common postdominator of all ops or +// return null if one could not be found. The ops are allowed to be in different +// regions. The result op is not necessarily one of the ops in the list. +Operation *findNearestCommonPostDominator(ArrayRef ops, + PostDominanceInfo &postDomInfo); + +/// Visit the operands of `op` and the operands of any nested ops defined +/// outside of `op`. +void visitNestedOperands(Operation *op, + function_ref visitor); +/// Visit the operands of `op` and the operands of any nested ops defined +/// outside of `op`. +void visitNestedOperands(Operation *op, function_ref visitor); +/// Get the operands of `op` and the operands of any nested ops defined outside +/// of `op`. +SetVector getNestedOperands(Operation *op); + +// Erase the given loop carried values from the loop, where `loop` is replaced +// with a new loop. +void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices); +} // namespace mlir + +namespace mlir::triton { +/// Replace all uses of `oldUse` with `val` and propagate the type if needed. +/// This is useful when we need to change a memory descriptor from immutable to +/// mutable. +/// The callback is invoked for each pair of an old and a cloned memdesc op +/// as the type is propagated. +void replaceUsesAndPropagateType( + OpBuilder &builder, Operation *oldUse, Value val, + std::function callback = nullptr); + +/// Replace all uses of `old` with a local load from `alloc` unless the use is a +/// `ttg.local_alloc` with a matching shared encoding, in which case the shared +/// memory is forwarded directly into the use. Returns the `ttg.local_load` if +/// it created one. +triton::gpu::LocalLoadOp +replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old, + TypedValue alloc, + TypedValue token = {}); + +// Return true if the value comes from a load or a block argument. +// This will skip convert layouts and memdesc views. +// This is a helper useful to know if value is likely to come from shared memory +// after converting loads into async loads. +bool comesFromLoadOrBlockArg(Value v); + +// For structured control flow ops, returns the values associated with the +// `resultIdx`th result. +SmallVector getTiedArgs(Operation *op, int resultIdx); + +// Verifies the provided memory descriptor type used for barrier allocation +LogicalResult verifyBarrierType(Operation *op, + mlir::triton::gpu::MemDescType barrierType); + +// Get a boolean if the Value is an arith::ConstantOp +std::optional getBoolFromConstant(Value cst); + +} // namespace mlir::triton + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h new file mode 100644 index 0000000000..afb7dde2c1 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h @@ -0,0 +1,24 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_WARPSPECIALIZATION_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_WARPSPECIALIZATION_H_ + +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace scf { +class ForOp; +} // namespace scf +namespace triton::gpu { +// This is the final step to prepare a loop for warp specialization. This takes +// a loop with a partition schedule and rewrites the loop such that all SSA +// dependencies between partitions are passed through shared memory and +// multibuffers them according to partition stages. +LogicalResult rewritePartitionDependencies(scf::ForOp &loop); +// Given a loop where the partitions' inputs and outputs have been fully +// rewritten to be reference semantic, partitiong the loop into a +// `ttg.warp_specialize` by duplicating the loop for each partition and +// rematerializing, as necessary, operations in the root partition. +LogicalResult partitionLoop(scf::ForOp loop); +} // namespace triton::gpu +} // namespace mlir + +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_WARPSPECIALIZATION_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonInstrument/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt new file mode 100644 index 0000000000..2af09f9046 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonInstrumentDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=tti) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=tti) +add_mlir_doc(TritonInstrumentDialect TritonInstrumentDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonInstrumentOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonInstrumentOps TritonInstrumentOps dialects/ -gen-op-doc) + +add_public_tablegen_target(TritonInstrumentTableGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/Dialect.h new file mode 100644 index 0000000000..e0fcf61b44 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/Dialect.h @@ -0,0 +1,14 @@ +#ifndef TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_ + +// TritonInstrument depends on Triton and TritonGPU +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "triton/Dialect/TritonInstrument/IR/OpsEnums.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonInstrument/IR/Dialect.h.inc" +#include "triton/Dialect/TritonInstrument/IR/Ops.h.inc" + +#endif // TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h new file mode 100644 index 0000000000..7325447516 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -0,0 +1,227 @@ +#ifndef TRITONINSTRUMENT_FUNCTIONBUILDER_H +#define TRITONINSTRUMENT_FUNCTIONBUILDER_H + +#include "triton/Dialect/TritonInstrument/IR/Utility.h" + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +class ImplicitLocOpBuilder; +class ModuleOp; +class Operation; +class RankedTensorType; +class Type; +class Value; +} // namespace mlir + +namespace mlir::triton { +class FuncOp; + +namespace instrument { + +class ManglingArgs { +public: + using Arg = std::variant; + + ManglingArgs() = default; + ManglingArgs(const ManglingArgs &) = default; + ManglingArgs(ManglingArgs &&) = default; + ManglingArgs &operator=(const ManglingArgs &) = default; + ManglingArgs &operator=(ManglingArgs &&) = default; + + ManglingArgs(std::initializer_list args) : args(args) {} + + ~ManglingArgs() = default; + + template void append(T arg) { args.push_back(arg); } + + template void append(ArrayRef arg) { + for (auto &a : arg) { + args.push_back(a); + } + } + + void append(ManglingArgs &other) { + args.append(other.args.begin(), other.args.end()); + } + + std::string mangleArg(Arg arg) const { + if (auto type = std::get_if(&arg)) { + auto hash = static_cast(mlir::hash_value(*type)); + return std::string("_T") + llvm::utohexstr(hash); + } else if (auto intVal = std::get_if(&arg)) { + return std::string("_I") + std::to_string(*intVal); + } else if (auto stringVal = std::get_if(&arg)) { + return *stringVal; + } + llvm_unreachable("Unsupported argument type"); + } + + std::string mangle(std::string baseName, int numWarps) const { + std::string name = "__triton_consan_"; + name += baseName; + name += "_nw" + std::to_string(numWarps); + for (auto arg : args) + name += mangleArg(arg); + return name; + } + +private: + SmallVector args; +}; + +/// Utility to mangle helper function names produced by the instrumentation +/// passes. The mangled name encodes the base name, number of warps and the +/// participating types. +std::string mangleInstrumentHelperName(const std::string &baseName, + int numWarps, + llvm::ArrayRef types); + +class FunctionBuilder { +public: + FunctionBuilder(ModuleOp module, AuxDataMap &auxData) + : module(module), auxData(auxData) {} + + // setWaiting: mark the base thread as waiting on the given barrier phase and + // record that phase for deadlock detection. + void createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread, + Value phase, Value pred, Operation *insertPoint); + // clearWaiting: clear the waiting flag and stored phase for the base thread. + void createClearWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread, + Value pred, Operation *insertPoint); + // checkAllActiveWaiting: assert that not all active threads are waiting on + // matching barrier phases. + void createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, int activeMask, + Value pred, Operation *insertPoint); + // initBarrierState: Initialize the tracked barrier state to phase 0 and set + // both the initial and current arrival counts. + void createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar, + int count, Operation *insertPoint); + // verifyBarrierArrive: Check that applying the arrive count would not drive + // the tracked current count negative. Triggers an assertion on failure. + void createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbar, + int count, Value pred, + Operation *insertPoint); + // updateBarrierState: Apply an arrive count to the tracked barrier state, + // toggling the phase when the count reaches zero and reloading the current + // count from the initial count. + void createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar, + int count, Value pred, + Operation *insertPoint); + // setWriteVisibility: Set the write visibility for a buffer. Marks the buffer + // as visible to the threads set in threadMask. Clears out any other threads + // from the visibility bitmask. We know this is safe because there cannot be + // outstanding writes to this buffer at this point. + void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, uint64_t threadMask, + Value pred, MemType memType, + Operation *insertPoint); + // setReadVisibility: add the threads set in threadMask to the buffer's read + // visibility bitmask. + void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, uint64_t threadMask, + Value pred, MemType memType, + Operation *insertPoint); + // clearWriteTracking: clear all the information about threads writing to a + // buffer. + void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, Value pred, + MemType memType, Operation *insertPoint); + // clearReadVisibility: clear the read visibility for a buffer. + void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, Value pred, + MemType memType, Operation *insertPoint); + // clearReadTracking: clear the read tracking for a buffer. + void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, Value pred, MemType memType, + Operation *insertPoint); + // trackVisibleWrites: snapshot buffers currently visible to the thread into + // the tracking table for a barrier. + void createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar, + int thread, Value pred, MemType memType, + Operation *insertPoint); + // trackVisibleReads: snapshot buffers currently visible to the thread into + // the read tracking table for a barrier. + void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar, + int thread, Value pred, MemType memType, + Operation *insertPoint); + // transferVisibleWrites: transfer write visibility tracked by a barrier to + // all threads in threadMask. + void createTransferVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar, + uint64_t threadMask, Value pred, + MemType memType, Operation *insertPoint); + // transferVisibleReads: transfer read visibility tracked by a barrier to all + // threads in threadMask. + void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar, + uint64_t threadMask, Value pred, + MemType memType, Operation *insertPoint); + // verifyWriteVisibility: ensure the thread either sees the latest write or no + // other thread is writing the buffer. + void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, int thread, + StringRef operandName, Value pred, + MemType memType, Operation *insertPoint); + // verifyReadVisibility: ensure all reads from the buffer are visible to the + // thread. + void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, int thread, + StringRef operandName, Value pred, + MemType memType, Operation *insertPoint); + // copyWriteVisibility: replicate the write visibility bit of sourceThread to + // every destination thread in destMask. + void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread, + uint64_t destMask, Value pred, + MemType memType, Operation *insertPoint); + // copyReadVisibility: replicate the read visibility row of sourceThread to + // every destination thread in destMask. + void createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread, + uint64_t destMask, Value pred, + MemType memType, Operation *insertPoint); + // stageAccessForCommit: mark the buffer as staged (value -1) in the + // outstanding commit table for this thread. + void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, int thread, Value pred, + MemType memType, + CommitKind::Kind commitKind, + Operation *insertPoint); + // commitAccesses: convert staged entries to 1 and increment outstanding + // commits greater than zero for the committing thread. + void createCommitAccessesCall(ImplicitLocOpBuilder &b, int thread, Value pred, + CommitKind::Kind commitKind, + Operation *insertPoint); + // clearOutstandingCommitsTransferWrites: clear entries farther than + // outstandingNum from the thread and set write visibility for threads in + // transferThreadMask. + void createClearOutstandingCommitsTransferWritesCall( + ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask, + int outstandingNum, Value pred, CommitKind::Kind commitKind, + MemType memType, Operation *insertPoint); + // clearOutstandingCommitsTransferReads: clear entries farther than + // outstandingNum from the thread and set read visibility for threads in + // transferThreadMask. + void createClearOutstandingCommitsTransferReadsCall( + ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask, + int outstandingNum, Value pred, CommitKind::Kind commitKind, + MemType memType, Operation *insertPoint); + // checkOutstandingCommits: assert that the outstanding commit row for the + // buffer is zero before the access described by pendingAccessType. + void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf, + uint32_t length, int thread, + StringRef pendingAccessType, + Value pred, MemType memType, + CommitKind::Kind commitKind, + Operation *insertPoint); + +private: + ModuleOp module; + AuxDataMap &auxData; +}; + +} // namespace instrument +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md new file mode 100644 index 0000000000..c7e05eef1d --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md @@ -0,0 +1,86 @@ +# Triton Instrument Dialect and Concurrency Sanitizer (ConSan) + +### Overview + +ConSan instruments Triton IR to detect illegal concurrent accesses to shared and Tensor Core memory under warp specialization. It tracks per-buffer visibility of reads and writes across threads, models barrier-based synchronization, and models commit-count–based synchronization (cp.async, wgmma). + +Auxiliary state is kept in distributed tensors and global scratch memory, with types created on-demand per warp-specialization partition. + +### Thread model + +- Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions). +- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads to model lack of ordering with base threads. +- Total logical threads: 48. Bitmasks are sized to the next power of two: 64. + +Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience. + +## Auxiliary data structures + +All types are generated on-demand (per partition) based on: + +- B: number of tracked buffers (power-of-two padded) +- K: number of mbarriers (power-of-two padded) +- T_bits: 64 (bitmask width) +- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers) + +“tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts. + +- buffers (tensor, ): Base pointers of all (sub)buffers per memory space +- barriers (tensor, ): Pointers of all mbarriers +- writeVisibility (scratch, ): Per-buffer bitmask. Bit i set ⇒ thread i can see latest completed write to that buffer +- readVisibility (scratch, ): Per-buffer, per-thread lanes. Each lane stores a 64-bit mask of other threads whose reads are visible to that lane’s thread +- writeTracking (scratch, ): Map buffers → barriers tracking writes (boolean stored in i8) +- readTracking (scratch, ): Map buffers → barriers tracking reads (bitmask of threads) +- barrierStates (scratch, ): Packed barrier metadata. Bit 0 stores the current phase, bits [1..8] the initial arrival count, bits [9..16] the current arrival count. The verifier checks underflow before updating, and flips the phase when the current count reaches zero. +- waiting (scratch, ): Per-barrier bitfield describing waiting threads. Each base thread gets two bits: bit (2 * thread + 0) is the waiting flag, bit (2 * thread + 1) stores the phase the thread is waiting on. +- outstandingCommits (scratch, ): Per-buffer, per-base-thread commit counters for cp.async and wgmma + +## Visibility and legality rules + +- Reads are legal iff the reading thread sees the most recent write to the buffer (writeVisibility). There can be only one write in-flight. +- Writes are legal iff the writing thread sees both all prior writes and all reads completed for that buffer. + +ConSan enforces these via two checks emitted before memory ops: + +- experimental_verify_write_visibility: “no one else is writing, or I can see the write” +- experimental_verify_read_visibility: “my read-visibility lane is a superset of the OR of all lanes” + +## Barrier-based synchronization + +ConSan separates “tracking” from “visibility transfer”: + +- At memory ops that are tracked by a barrier (loads/stores, some TMEM ops): + - experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer. + - experimental_track_visible_reads / experimental_track_visible_writes snapshots current per-buffer visibility into readTracking/writeTracking for the given barrier. +- At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes. +- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent. + +### Barrier phase/count tracking + +- experimental_init_barrier_state(barrier, count, barrierStates) initializes the per-barrier state with phase = 0 and both initial/current arrival counts = `count`. +- experimental_verify_barrier_arrive(barrier, count, barrierStates) checks that subtracting `count` from the current arrival count would not underflow. The codegen emits an assert if it would. +- experimental_update_barrier_state(barrier, count, barrierStates) applies the arrive: subtracts `count`, flips the phase when the count reaches zero, and reloads the current count from the initial count. + +### Deadlock detection + +ConSan records which phase each thread is waiting on: + +- experimental_set_waiting(barrier, baseThread, phase, barriers, waiting) sets the waiting flag for `baseThread` and stores the requested `phase`. The flag/phase bits share the waiting bitfield (two bits per base thread). +- experimental_check_all_active_waiting(activeMask, barriers, waiting, barrierStates) filters waiting threads to those whose stored phase matches the current barrier phase. If all active threads are waiting on matching phases, it raises a deadlock assert. +- experimental_clear_waiting(barrier, baseThread, barriers, waiting) clears the waiting bits for `baseThread`. Each wait clears its own state after the wait completes. + +## Commit-count–based synchronization + +Some hardware ops synchronize via “number of outstanding commits” rather than mbarriers. + +- Stage: experimental_stage_access_for_commit marks the current thread’s buffer lane with -1 (staged) in outstandingCommits[B x 16]. +- Commit: experimental_commit_accesses turns -1 into 1 and increments positive entries for the committing thread column. +- Wait (cp.async): experimental_clear_outstanding_commits_set_write(thread, commits, writeVisibility, N) clears entries with count > N for the current thread, and sets the writeVisibility bit for rows where any thread’s entry was cleared. +- Wait (wgmma): experimental_clear_outstanding_commits_set_read(thread, commits, readVisibility, N) clears entries with count > N for the current thread, and sets the readVisibility bit for rows where any thread’s entry was cleared. + +Legality checks for commit-count flows: + +- For writes to shared memory affected by cp.async: experimental_check_outstanding_commits(buffer, commits, "async_copy_global_to_shared") asserts the row for the buffer is all zeros (no pending writes), across all base-thread columns. +- For reads of wgmma operands in shared memory: experimental_check_outstanding_commits(buffer, commits, "warpgroup_mma operand read") asserts the row is all zeros (no pending reads). + +Note: The check op has no “thread” operand; it inspects the whole row for the buffer. diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td new file mode 100644 index 0000000000..ab8702defb --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td @@ -0,0 +1,15 @@ +#ifndef TRITONINSTRUMENT_ATTR_DEFS +#define TRITONINSTRUMENT_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +def TT_MemTypeAttr : I32EnumAttr< + "MemType", "", + [ + I32EnumAttrCase<"SHARED_MEM", 0, "shared_mem">, + I32EnumAttrCase<"TENSOR_MEM", 1, "tensor_mem">, + ]> { + let cppNamespace = "::mlir::triton::instrument"; +} + +#endif // TRITONINSTRUMENT_ATTR_DEFS diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td new file mode 100644 index 0000000000..6a7f3eed62 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td @@ -0,0 +1,11 @@ +#ifndef TRITONINSTRUMENT_DIALECT +#define TRITONINSTRUMENT_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonInstrument_Dialect : Dialect { + let name = "tti"; + let cppNamespace = "::mlir::triton::instrument"; +} + +#endif // TRITONINSTRUMENT_DIALECT diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td new file mode 100644 index 0000000000..b74c45c33e --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -0,0 +1,96 @@ +#ifndef TRITONINSTRUMENT_OPS +#define TRITONINSTRUMENT_OPS + +include "triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Ops +// + +class TTI_Op traits = []> : + Op { +} + +def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [MemoryEffects<[MemWrite]>]> { + let summary = "assert the condition within the current thread"; + let description = [{ + Assert that the condition is true given all the values are available in the current thread. + If the condition is false, the message is printed, and the program is aborted. + If check_any is true, any of the values in the condition must be true. Otherwise, all the + values in the condition must be true. + }]; + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message, BoolAttr:$check_any); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; +} + + +def TTI_ExperimentalBufferDescriptorsOp + : TTI_Op<"experimental_buffer_descriptors", [Pure]> { + let summary = "define an array of buffer descriptors"; + let description = [{ + Create a tensor of buffer descriptors packing 32-bit pointer offsets and + 32-bit lengths into 64-bit elements. + }]; + let arguments = (ins DenseI32ArrayAttr:$offsets, DenseI32ArrayAttr:$lengths, + TT_MemTypeAttr:$memType); + let results = (outs TT_Tensor:$result); + let assemblyFormat = [{ + $offsets `,` $lengths `,` $memType attr-dict `:` type($result) + }]; +} + +def TTI_ExperimentalMemDescToI32Op : TTI_Op<"experimental_memdesc_to_i32", [Pure]> { + let summary = "Convert a memdesc into its base pointer as i32"; + let description = [{ + Extract the base pointer from the given memdesc and return it as a 32-bit + integer. This can be used to compare the memdesc against tensors of barrier + pointers maintained by the concurrency sanitizer. + }]; + let arguments = (ins TTG_MemDescType:$memdesc); + let results = (outs I32:$result); + let builders = [ + OpBuilder<(ins "Value":$memdesc), [{ + build($_builder, $_state, $_builder.getI32Type(), memdesc); + }]> + ]; + let assemblyFormat = "$memdesc attr-dict `:` type($memdesc)"; +} + + +// ===== Critical section lock ops ===== + + +def TTI_ExperimentalLockAcquireOp : TTI_Op<"experimental_lock_acquire", [MemoryEffects<[MemWrite]>]> { + let summary = "Acquire a lock."; + let description = [{ + Enter a critical section by acquiring a lock with single thread. + }]; + let arguments = (ins TT_PtrLike:$lock, Optional:$pred); + let assemblyFormat = [{ + $lock (`,` $pred^)? attr-dict `:` type($lock) + }]; +} + + +def TTI_ExperimentalLockReleaseOp : TTI_Op<"experimental_lock_release", [MemoryEffects<[MemWrite]>]> { + let summary = "Release a lock."; + let description = [{ + Leave a critical section by releasing a lock with single thread. + }]; + let arguments = (ins TT_PtrLike:$lock, Optional:$pred); + let assemblyFormat = [{ + $lock (`,` $pred^)? attr-dict `:` type($lock) + }]; +} + +#endif // TRITONINSTRUMENT_OPS diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/Utility.h b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/Utility.h new file mode 100644 index 0000000000..1ceba9db1b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/IR/Utility.h @@ -0,0 +1,101 @@ +#ifndef TRITONINSTRUMENT_UTILITY_H +#define TRITONINSTRUMENT_UTILITY_H + +#include "triton/Analysis/BufferRegion.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" + +#include + +namespace mlir::triton::instrument { + +constexpr int numMemTypes = getMaxEnumValForMemType() + 1; + +constexpr int NUM_THREADS = 16; +constexpr int TMA_THREAD_OFFSET = NUM_THREADS; +constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS; +constexpr int TOTAL_NUM_THREADS = TC_THREAD_OFFSET + NUM_THREADS; +constexpr int THREADS_BITMASK_SIZE = llvm::NextPowerOf2(TOTAL_NUM_THREADS); + +namespace CommitKind { +enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds }; +} + +Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc, + Value tensor, RankedTensorType tensorType); +Value createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc, + RankedTensorType tensorType); +Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor); +TypedValue createConstIntTensor(OpBuilder &builder, + Location loc, int64_t val, + RankedTensorType tensorType, + bool isSigned = false); +FuncOp getEntryPoint(ModuleOp module); +gpu::DistributedEncodingTrait +getSingleDimSliceEncoding(gpu::BlockedEncodingAttr encoding, int dim); + +struct ValueType { + Value value; + Type type; + + ValueType() = default; + ValueType(Value value, Type type) : value(value), type(type) {} + ValueType(std::pair value) + : value(value.first), type(value.second) {} +}; + +// Map from IR region to ConSan auxiliary data. Auxiliary data is a value +// and an optional type, for values that are stored in the scratch memory. +struct AuxDataMap { + struct RegionToValueMap { + DenseMap values; + ValueType at(Region *region) { + if (values.find(region) == values.end()) { + assert(false && "Region not found in AuxDataMap"); + } + return values[region]; + } + ValueType at(Operation *op) { + return at(getEnclosingParitionOrFunctionRegion(op)); + } + void insert(Region *region, ValueType value) { values[region] = value; } + bool empty() const { return values.empty(); } + + private: + Region *getEnclosingParitionOrFunctionRegion(Operation *op); + }; + + // Please see TritonInstrumentOps.td for more information on the auxiliary + // data structures. + RegionToValueMap buffers[numMemTypes]; + RegionToValueMap barriers; + RegionToValueMap barrierStates; + + RegionToValueMap writeVisibility[numMemTypes]; + RegionToValueMap writeTracking[numMemTypes]; + RegionToValueMap readVisibility[numMemTypes]; + RegionToValueMap readTracking[numMemTypes]; + RegionToValueMap commits[CommitKind::NumCommitKinds]; + RegionToValueMap aliasMatrices[numMemTypes]; + RegionToValueMap lock; + RegionToValueMap waiting; + std::array hasNonTrivialAliasing{}; + + void populateAndPassToWarpSpecialize(ModuleOp module); + +private: + void getBuffersAndBarriers( + ModuleOp module, + SmallVector, 2> &bufRegions, + SmallVector &barrierRegions); + void passToWarpSpecialize(triton::FuncOp func, ValueType value, + RegionToValueMap &map); + void createInWarpSpecialize( + triton::FuncOp func, RegionToValueMap &map, + std::function createFn); +}; + +} // namespace mlir::triton::instrument + +#endif // TRITONINSTRUMENT_UTILITY_H diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..672815ac4b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonInstrument) +add_public_tablegen_target(TritonInstrumentTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/Passes.h new file mode 100644 index 0000000000..c96c618e68 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/Passes.h @@ -0,0 +1,22 @@ +#ifndef TRITON_DIALECT_TRITONINSTRUMENT_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONINSTRUMENT_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace instrument { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonInstrument/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonInstrument/Transforms/Passes.h.inc" + +} // namespace instrument +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/Passes.td new file mode 100644 index 0000000000..cfd860e991 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonInstrument/Transforms/Passes.td @@ -0,0 +1,16 @@ +#ifndef TRITONINSTRUMENT_PASSES +#define TRITONINSTRUMENT_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonInstrumentConcurrencySanitizer: Pass<"tritoninstrument-concurrency-sanitizer", "mlir::ModuleOp"> { + let summary = "Add runtime verification of asynchronous operations"; + + let description = "Instrument the program with runtime verification of asynchronous operations."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::instrument::TritonInstrumentDialect"]; +} + +#endif // TRITON_INSTRUMENT_PASSES diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..b93aad2ba0 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,27 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonNvidiaGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(TritonNvidiaGPUTypesIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td) +mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOpInterfaces.td) +mlir_tablegen(TritonNvidiaGPUOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TritonNvidiaGPUOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(TritonNvidiaGPUOpInterfacesIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h new file mode 100644 index 0000000000..679d506d81 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +// TritonNvidiaGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc" + +namespace mlir::triton::nvidia_gpu::impl { +LogicalResult verifyMMAv5Op(Operation *op); +} // namespace mlir::triton::nvidia_gpu::impl + +#include "triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc" + +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc" + +namespace mlir::triton::nvidia_gpu { + +constexpr static char AttrTwoCTAsName[] = "ttng.two-ctas"; + +inline bool getModuleTwoCTAs(ModuleOp mod) { + auto attr = mod->getAttrOfType(AttrTwoCTAsName); + return attr ? attr.getValue() : false; +} + +inline bool getModuleTwoCTAs(Operation *op) { + return getModuleTwoCTAs(op->getParentOfType()); +} + +struct TensorMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +struct TMemAllocation { + TMemAllocation(int numRows, int numCols) + : numRows(numRows), numCols(numCols) {} + int numRows; + int numCols; +}; + +// Used to describe the layout of the TMEM load/store instructions +enum class TMemAccessAtom { I32x32b, I16x64b, I16x128b, I16x256b, I16x32bx2 }; + +inline int getElementsPerThread(TMemAccessAtom atom) { + switch (atom) { + case TMemAccessAtom::I32x32b: + case TMemAccessAtom::I16x64b: + case TMemAccessAtom::I16x32bx2: + return 1; + case TMemAccessAtom::I16x128b: + return 2; + case TMemAccessAtom::I16x256b: + return 4; + } + llvm_unreachable("Unknown TMemAccessAtom"); +} + +inline const char *getOpShape(TMemAccessAtom atom) { + switch (atom) { + case TMemAccessAtom::I32x32b: + return "32x32b"; + case TMemAccessAtom::I16x64b: + return "16x64b"; + case TMemAccessAtom::I16x128b: + return "16x128b"; + case TMemAccessAtom::I16x256b: + return "16x256b"; + case TMemAccessAtom::I16x32bx2: + return "16x32bx2"; + } + llvm_unreachable("Unknown TMemAccessAtom"); +} + +LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, bool unpacked, + bool withWarp); + +TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType); + +SmallVector +getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps, + ArrayRef ctaSplit = {1, 1}); + +std::optional +getTmemLoadLayoutSplitLongM(RankedTensorType tensorType, + gpu::MemDescType memType, int numWarps); + +SmallVector +getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, + gpu::MemDescType memType); + +bool isDistributedLayoutTMemCompatible(Operation *op, + RankedTensorType tensorType, + gpu::MemDescType memType); + +gpu::DistributedEncodingTrait +getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps, + gpu::CGAEncodingAttr cgaLayout); + +std::optional +getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, + unsigned numWarps, + gpu::CGAEncodingAttr cgaLayout); + +} // namespace mlir::triton::nvidia_gpu + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h new file mode 100644 index 0000000000..3ae002a597 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h @@ -0,0 +1,37 @@ +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" + +#include +#include +#include + +namespace mlir::triton::nvidia_gpu { + +// Get the maximum number of registers per thread based on the context. This is +// by default 256, but it can be overridden by `ttg.maxnreg` set on the module +// or a contextual register limit set by the compiler on partitions. +int getContextualMaxNReg(Operation *op); +struct TMemLdStEncodingInfo { + TMemAccessAtom atom; + LinearLayout reps; + ColumnAction perm; + int numRegsPerMessage; + std::optional secondHalfOffset; + std::optional broadcast = std::nullopt; + bool unpacked = false; + unsigned vec = 1; + bool padding = false; +}; + +FailureOr +computeTMemLdStEncodingInfo(RankedTensorType regTy, gpu::MemDescType memTy, + int maxnreg, + std::function emitError = {}); + +} // namespace mlir::triton::nvidia_gpu + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_TENSORMEMORYUTILS_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td new file mode 100644 index 0000000000..d84a600173 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td @@ -0,0 +1,92 @@ +#ifndef TRITONNVIDIAGPU_ATTRDEFS +#define TRITONNVIDIAGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/EnumAttr.td" + +def TTG_TensorMemorySpace : AttrDef { + let mnemonic = "tensor_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to tensor memory. + The memory is laid out in blocks of size blockM x blockN. Each block is distributed + across TMEM 128 rows. + + Blocks are distributed along M dimension first and then N dimension. This is an arbitrary + convention that needs to be followed by operations reading/writing to TMEM. + + a tensor <128x128xf32> with blockM = 64 and blockN = 32 will be distributed as follows: + + \ col 0 1 31 32 64 96 127 + rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) ( 0, 32) ... ( 0, 64) ... ( 0, 96) ... ( 0, 127) + 1 + ... + 15 (15, 0) (15, 1) ... (15, 31) (15, 32) ... (15, 64) ... (15, 96) ... (15, 127) + 16 (64, 0) (64, 1) ... (64, 31) (64, 32) ... (64, 64) ... (64, 96) ... (64, 127) + ... + 31 (79, 0) (79, 1) ... (79, 31) (79, 32) ... (79, 64) ... (79, 96) ... (79, 127) + 32 (16, 0) (16, 1) ... (16, 31) (16, 32) ... (16, 64) ... (16, 96) ... (16, 127) + .. + 127 (127, 0) (127, 1) ... (127, 31) (127, 32) ... (127, 64) ... (127, 96) ... (127, 127) + }]; +} + +def TTNG_TMEMLoadReduceModifierAttr : I32EnumAttr< + "TMEMLoadReduceModifier", "", + [ + I32EnumAttrCase<"MIN", 1, "min">, + I32EnumAttrCase<"MAX", 2, "max">, + ]> { + let cppNamespace = "::mlir::triton::nvidia_gpu"; + let genSpecializedAttr = 0; +} +def TTNG_TMEMLoadReduceModifierEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TTG_TensorMemoryEncodingAttr : AttrDef { + let mnemonic = "tensor_memory_encoding"; + let attrName = "triton.gpu.tensor_memory_encoding"; + let description = [{ + An encoding to represent the different way the tensor memory is laid out. + `colStride` describes the stride in elements along the column dimension, + that is, the stride between two elements in the same row. + When colStride is 1 the tensor memory is packed. When colStride > 1, the + tensor memory between elements is undefined. + `twoCTAs` indicates that the tensor memory is laid out for twoCTA mode, + i.e., `cta_group::2`. + }]; + let parameters = ( + ins + "unsigned":$blockM, + "unsigned":$blockN, + "unsigned":$colStride, + DefaultValuedParameter<"unsigned", "1">:$CTASplitM, + DefaultValuedParameter<"unsigned", "1">:$CTASplitN, + DefaultValuedParameter<"bool", "false">:$twoCTAs + ); + let genVerifyDecl = 1; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TTG_TensorMemoryScalesEncodingAttr : AttrDef { + let mnemonic = "tensor_memory_scales_encoding"; + let attrName = "triton.gpu.tensor_memory_scales_encoding"; + let description = [{ + An encoding to represent the layout of tensor memory scales. + As described in the PTX doc, blocked scales in TMEM must be in a special layout. They are organized + as a multiple copies of "chunk", each of which having the size 32x4x4B. Moreover, such chunks are duplicated + over 4 warps to fill entire 128 rows of TMEM. This encoding indicates that a tensor in TMEM is in such a special + layout. + }]; + let parameters = ( + ins + DefaultValuedParameter<"unsigned", "1">:$CTASplitM, + DefaultValuedParameter<"unsigned", "1">:$CTASplitN + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td new file mode 100644 index 0000000000..a185e18071 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -0,0 +1,49 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_DIALECT +#define TRITONNVIDIAGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonNvidiaGPU_Dialect : Dialect { + let name = "ttng"; + + let cppNamespace = "::mlir::triton::nvidia_gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton Nvidia GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + "mlir::gpu::GPUDialect", + ]; + + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td new file mode 100644 index 0000000000..9003fa303b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td @@ -0,0 +1,74 @@ +#ifndef TRITON_NVIDIAGPU_OP_INTERFACES +#define TRITON_NVIDIAGPU_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> { + let description = [{ + This interface is implemented by MMAv5 dot and dot scaled ops. + }]; + + let cppNamespace = "::mlir::triton::nvidia_gpu"; + + // We can add more methods as needed. + let methods = [ + InterfaceMethod<"Return the A operand.", + "::mlir::TypedValue<::mlir::triton::gpu::MemDescType>", + "getA">, + InterfaceMethod<"Return the B operand.", + "::mlir::TypedValue<::mlir::triton::gpu::MemDescType>", + "getB">, + InterfaceMethod<"Return the accumulator init flag.", + "::mlir::Value", + "useAccumulator">, + InterfaceMethod<"Set the accumulator init flag.", + "void", + "setUseAccumulator", + (ins "::mlir::Value":$flag)>, + InterfaceMethod<"Return the completion barriers of this MMAv5 op.", + "::mlir::ValueRange", + "getCompletionBarriers">, + InterfaceMethod<"Return the completion barrier predicates of this MMAv5 op.", + "::mlir::ValueRange", + "getCompletionBarrierPreds">, + InterfaceMethod<"Associate a new completion barrier to this MMAv5 op.", + "void", + "addCompletionBarrier", + (ins "::mlir::Value":$barrier, "::mlir::Value":$pred)>, + InterfaceMethod<"Return the accumulator.", + "::mlir::TypedValue<::mlir::triton::gpu::MemDescType>", + "getAccumulator">, + InterfaceMethod<"Set the accumulator.", + "void", + "setAccumulator", + (ins "::mlir::Value":$accum)>, + InterfaceMethod<"Return the predicate of this op.", + "::mlir::Value", + "getPredicate">, + InterfaceMethod<"Set the predicate of this op.", + "void", + "setPredicate", + (ins "::mlir::Value":$pred)>, + InterfaceMethod<"Get the memory dependencies of the accumulator.", + "::mlir::Value", + "getAccDep">, + InterfaceMethod<"Get the mutable memory dependencies of the accumulator.", + "::mlir::MutableOperandRange", + "getAccDepMutable">, + InterfaceMethod<"Get the produced write dependency of the accumulator.", + "::mlir::Value", + "getToken">, + InterfaceMethod<"Indicate that this MMA op executes asynchronously.", + "void", + "setIsAsync", + (ins "bool":$isAsync)>, + InterfaceMethod<"Return true if this MMA op executes asynchronously.", + "bool", + "isAsync"> + ]; + + let verify = [{ + return ::mlir::triton::nvidia_gpu::impl::verifyMMAv5Op($_op); + }]; +} +#endif // TRITON_NVIDIAGPU_OP_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td new file mode 100644 index 0000000000..26b985f798 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -0,0 +1,933 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_OPS +#define TRITONNVIDIAGPU_OPS + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; +def TensorMemory : Resource<"::mlir::triton::nvidia_gpu::TensorMemory">; + +class TTNG_Op traits = []> : + Op { +} + +def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { + let arguments = (ins BoolAttr:$bCluster); + + let summary = "fence proxy async"; + + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + +def TTNG_FenceMBarrierInitReleaseClusterOp : TTNG_Op< + "fence_mbarrier_init_release_cluster"> { + let summary = "fence mbarrier init release.cluster"; + + let assemblyFormat = "attr-dict"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + +def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { + let arguments = (ins I1Attr:$relaxed); + let assemblyFormat = "attr-dict"; + let hasVerifier = 1; +} + +def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> { + let assemblyFormat = "attr-dict"; + let hasVerifier = 1; +} + +// +// WarpGroupDot Op +// +def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self"> +]> { + let summary = "warp group dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp + }]; + + let arguments = (ins + TTG_TensorOrMemDesc:$a, + TTG_MemDescType:$b, + TT_FpIntTensor:$c, + Optional:$useC, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc, + DefaultValuedAttr:$isAsync + ); + + let results = (outs TT_FpIntTensor:$d); + + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` $useC^)? attr-dict + `:` type($a) `*` qualified(type($b)) `->` type($d) + }]; + + let extraClassDeclaration = [{ + bool needsPartialAccumulator(); + }]; + + let hasVerifier = 1; +} + +def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>, + PassthroughWaitLike]> { + let summary = "warp group dot wait"; + let arguments = (ins Variadic:$inputs, I32Attr:$pendings); + let results = (outs Variadic:$outputs); + let description = [{ + Waits until there are $pendings or fewer outstanding async dot operations. + + $inputs must be the tensors corresponding to the async dot ops that we're + waiting on. For example, if there are N pending async dot ops and we call + `warp_group_dot_wait 1`, then $inputs must be the result of the first dot op. + }]; + + let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; + let hasVerifier = 1; +} + +def TTNG_InitBarrierOp : TTNG_Op<"init_barrier"> { + let summary = "Initialize a barrier in the given shared memory allocation."; + + let description = [{ + Initializes a shared memory allocation with mbarrier information. + `alloc` is a descriptor to the shared memory allocation. `count` is the + number of arrives expected by the barrier. + + This lowers to PTX mbarrier.init.shared::cta.b64. + }]; + + let arguments = (ins + Arg]>:$alloc, + I32Attr:$count + ); + let assemblyFormat = "$alloc `,` $count attr-dict `:` qualified(type($alloc))"; + let hasVerifier = 1; +} + +def TTNG_InvalBarrierOp : TTNG_Op<"inval_barrier"> { + let summary = "Invalidate a barrier allocation."; + + let description = [{ + Invalidate a barrier allocation so that it can be re-used. According to PTX + spec this has to be done before any reuse of the memory used by mbarrier. + + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval + }]; + + let hasVerifier = 1; + let arguments = (ins Arg]>:$alloc); + let assemblyFormat = "$alloc attr-dict `:` qualified(type($alloc))"; +} + +def TTNG_BarrierExpectOp : TTNG_Op<"barrier_expect"> { + let summary = "Signal a barrier of an expected number of bytes to be copied."; + + let description = [{ + This signal the barrier that `size` bytes are expected to be copied. The + associated barrier wait will block until the expected number of bytes are copied. + }]; + + let hasVerifier = 1; + let arguments = (ins + Arg]>:$alloc, + I32Attr:$size, + I1:$pred + ); + + let assemblyFormat = [{ + $alloc `,` $size attr-dict `,` $pred `:` qualified(type($alloc)) + }]; +} + +def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [AttrSizedOperandSegments]> { + let summary = "wait until the mbarrier phase completes."; + + let description = [{ + Blocks the program progress until the mbarrier object in `alloc` completes + its current phase. + + This lowers a waitloop using PTX instruction + mbarrier.try_wait.parity.shared::cta.b64. + + Accepts optional list of memory. If present, it is assumed that any of the + dependencies may be accessed until the barrier completes. + + The barrier behavior is described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms + }]; + + let arguments = (ins + Arg, MemWrite]>:$alloc, + I32:$phase, + Optional:$pred, + Variadic:$deps + ); + + let builders = [ + OpBuilder<(ins "Value":$alloc, "Value":$phase), + [{ + build($_builder, $_state, alloc, phase, /*pred=*/static_cast(nullptr), /*deps=*/{}); + }]>, + OpBuilder<(ins "Value":$alloc, "Value":$phase, "Value":$pred), + [{ + build($_builder, $_state, alloc, phase, pred, /*deps=*/{}); + }]>, + OpBuilder<(ins "Value":$alloc, "Value":$phase, "ValueRange":$deps), + [{ + build($_builder, $_state, alloc, phase, /*pred=*/static_cast(nullptr), deps); + }]>, + ]; + + let assemblyFormat = [{ + $alloc `,` $phase (`,` $pred^)? (`deps` $deps^)? + attr-dict `:` qualified(type($alloc)) (`,` type($deps)^)? + }]; + let hasVerifier = 1; +} + +def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> { + let summary = "perform the arrive operation on an mbarrier"; + let description = [{ + The `ttng.arrive_barrier` operation performs the "arrive" operation on an + mbarrier object in shared memory. The operation requires a `count` attribute + of at least 1, and decreasing the pending arrival count of the mbarrier by + the specific count. + + The operation accepts an optional predicate. + + Example: + + ```mlir + ttng.arrive_barrier %barrier, 2 : !ttg.memdesc<1xi64, #shared, #smem, mutable> + ttng.arrive_barrier %barrier, 1, %pred : !ttg.memdesc<1xi64, #shared, #smem, mutable> + ``` + }]; + + let arguments = (ins + Arg, MemWrite]>:$alloc, + I32Attr:$count, + Optional:$pred + ); + + let assemblyFormat = [{ + $alloc `,` $count (`,` $pred^)? attr-dict `:` qualified(type($alloc)) + }]; + + let builders = [ + OpBuilder<(ins "Value":$alloc, "uint32_t":$count), [{ + return build($_builder, $_state, alloc, count, /*pred=*/Value()); + }]> + ]; + + let hasVerifier = 1; +} + +def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive"> { + let summary = "arrive on mbarrier once all previously issued copies are completed"; + let arguments = (ins + Arg]>:$barrier, + UnitAttr:$noIncrement + ); + let assemblyFormat = "$barrier attr-dict `:` qualified(type($barrier))"; +} + + +def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [AttrSizedOperandSegments]> { + let summary = "copy data based on descriptor from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory + asynchronously. This is analogue to tt.load except the data are copied to + local memory pointed by the memory descriptor instead of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc`. + + The tensor mode is determined by the descriptor type: + - tt.tensordesc: TILED mode - Regular tiled tensor memory access + - See: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode + - ttng.tensordesc_im2col: IM2COL mode - Im2col mode for convolution-friendly access patterns + - In IM2COL mode, 'coord' is the coordinates in the input tensor + - For example, for a 4D tensor (NHWC), 'coord' is [batch_idx, channel_idx, h, w] + - In IM2COL mode, additional `offsets` must be provided (uint16 values) + - For 3D tensors (NWC): 1 offset (offset_w) + - For 4D tensors (NHWC): 2 offsets (offset_w, offset_h) + - For 5D tensors (NDHWC): 3 offsets (offset_w, offset_h, offset_d) + - General rule: number of offsets = coord.size() - 2 + - See: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode + + }]; + + let hasVerifier = 1; + let arguments = (ins + Arg]>:$desc, + Variadic:$coord, + Variadic:$offsets, + Arg]>:$barrier, + Arg]>:$result, + I1:$pred, + UnitAttr:$multicast, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let builders = [ + // Builder for TILED mode (no offsets required, attributes default to standard values) + OpBuilder<(ins "Value":$desc, "ValueRange":$coord, "Value":$barrier, + "Value":$result, "Value":$pred, + CArg<"bool", "false">:$multicast, + CArg<"triton::CacheModifier", "triton::CacheModifier::NONE">:$cache, + CArg<"triton::EvictionPolicy", "triton::EvictionPolicy::NORMAL">:$evict, + CArg<"bool", "false">:$isVolatile), [{ + build($_builder, $_state, desc, coord, /*offsets=*/ValueRange{}, barrier, + result, pred, multicast, cache, evict, isVolatile); + }]> + ]; + + let assemblyFormat = [{ + $desc `[` $coord `]` (`offsets` `=` `[` $offsets^ `]`)? $result `,` $barrier `,` $pred + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` qualified(type($desc)) `,` qualified(type($barrier)) `->` qualified(type($result)) + }]; +} + +def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global"> { + let summary = "copy data based on descriptor from local memory to global memory asynchronously"; + + let description = [{ + This operation copies data from local memory to global memory + asynchronously. This is analogue to tt.store except the data are copied from + local memory pointed by the memory descriptor instead of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc`. + }]; + + let arguments = (ins + Arg, MemWrite]>:$desc, + Variadic:$coord, + Arg]>:$src + ); + + let assemblyFormat = [{ + $desc `[` $coord `]` $src + attr-dict `:` qualified(type($desc)) `,` qualified(type($src)) + }]; + let hasVerifier = 1; +} + +def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "reduce result in gmem based on a TMA descriptor"; + + let description = [{ + This operation copies data from local memory to global memory + asynchronously, and atomically performs the specified reduction kind. + Atomicity is at the granularity of individual elements, and only relaxed + semantics are implied. + }]; + + let arguments = (ins + TT_DescriptorReduceKindAttr:$kind, + Arg]>:$desc, + Variadic:$coord, + Arg]>:$src + ); + + let assemblyFormat = [{ + $kind `,` $desc `[` $coord `]` $src + attr-dict `:` qualified(type($desc)) `,` qualified(type($src)) + }]; + let hasVerifier = 1; +} + +def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather"> { + let summary = "gather data based on descriptor from global memory to local memory asynchronously"; + + let description = [{ + This operation gathers multiple rows of data from global memory matrix to + local memory asynchronously. This is similar to + async_tma_copy_global_to_local except that each row is indexed independently. + }]; + + let arguments = (ins + Arg]>:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + Arg]>:$barrier, + Arg]>:$result, + I1:$pred + ); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` $result `,` $barrier `,` $pred + attr-dict `:` type(operands) + }]; + + let hasVerifier = 1; +} + +def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> { + let summary = "scatter data from local memory into global memory based on a descriptor asynchronously"; + + let description = [{ + The `ttng.async_tma_scatter` operation scatters multiple separately-indexed + rows of data from local memory into global memory asynchronously. The + operation scatters a 2D tensor in shared memory, laid out by core tensor + tiles nvmma_shared layout into separately indexed rows in global + memory at a given `y` offset. + }]; + + let arguments = (ins + Arg, MemWrite]>:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + Arg]>:$src + ); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` $src + attr-dict `:` type(operands) + }]; + + let hasVerifier = 1; +} + +def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait", [MemWaitOpTrait]> { + let summary = "wait until all the inputs are read."; + let arguments = (ins I32Attr:$pendings); + let description = [{ + Wait until all the read operations are done from the associated store operations. + This is needed before the shared memory can be written to. + }]; + + let assemblyFormat = "attr-dict"; +} + +def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + AttrSizedOperandSegments +]> { + let summary = "block level op mapping to tensorcore gen5 mma"; + + let description = [{ + $d += matrix_multiply($a, $b). + if is_async is false, the op executes synchronously. The barrier operands must not be present in that case. + Otherwise, if a barrier is given, the op will trigger a commit/arrive on it. The result will be safe to read after a barrier wait. + If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs. + and syncronize both CTAs if the op is synchronous. + + This operation takes and produces an optional token to indicate TMEM read + and write on its accumulator operand. When the tokens are present, they can + be used to check aliasing and modref on the accumulator memory. + }]; + + let arguments = (ins + TTG_MemDescType:$a, + TTG_MemDescType:$b, + TTG_MemDescType:$d, + Optional:$acc_dep, + I1:$useD, + I1:$pred, + Variadic:$barriers, + Variadic:$barrier_preds, + UnitAttr:$is_async, + UnitAttr:$two_ctas, + UnitAttr:$multicast + ); + let results = (outs Optional:$token); + + let builders = [ + OpBuilder<(ins "Type":$token, + "Value":$a, "Value":$b, "Value":$d, "Value":$acc_dep, "Value":$useD, + "Value":$pred, CArg<"bool", "false">:$two_ctas, + CArg<"bool", "false">:$multicast, + CArg<"ValueRange", "{}">:$barriers, + CArg<"ValueRange", "{}">:$barrier_preds, + CArg<"bool", "false">:$is_async)> + ]; + + let assemblyFormat = [{ + $a `,` $b `,` $d `` custom($acc_dep, type($token)) `,` $useD`,` + $pred `` custom($barriers, $barrier_preds) + attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,` + qualified(type($d)) (`,` qualified(type($barriers))^)? + }]; + + let hasVerifier = 1; +} + +def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + AttrSizedOperandSegments +]> { + let summary = "block level op mapping to tensorcore gen5 mma"; + + let description = [{ + $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale)) + if is_async is false, the op executes synchronously. The barrier operands must not be present in that case. + Otherwise, if a barrier is given, the op will trigger a commit/arrive on it. + The result will be safe to read after a barrier wait. + + This operation takes and produces an optional token to indicate TMEM read + and write on its accumulator operand. When the tokens are present, they can + be used to check aliasing and modref on the accumulator memory. + }]; + + let arguments = (ins + TTG_MemDescType:$a, + TTG_MemDescType:$b, + TTG_MemDescType:$d, + Optional:$acc_dep, + TTG_MemDescType:$a_scale, + TTG_MemDescType:$b_scale, + TT_ScaleDotElemTypeAttr:$a_type, + TT_ScaleDotElemTypeAttr:$b_type, + I1:$useD, + I1:$pred, + Variadic:$barriers, + Variadic:$barrier_preds, + UnitAttr:$is_async + ); + let results = (outs Optional:$token); + + let extraClassDeclaration = [{ + int64_t getBlockM(); + int64_t getBlockN(); + int64_t getBlockK(); + }]; + + let builders = [ + // Namespaces need to be prefixed so ODS prefers our + // custom builder signature over the default-generated one. + OpBuilder<(ins "::mlir::Type":$token, + "::mlir::Value":$a, "::mlir::Value":$b, "::mlir::Value":$d, + "::mlir::Value":$acc_dep, "::mlir::Value":$a_scale, + "::mlir::Value":$b_scale, "::mlir::triton::ScaleDotElemType":$a_type, + "::mlir::triton::ScaleDotElemType":$b_type, + "::mlir::Value":$useD, "::mlir::Value":$pred, + CArg<"::mlir::ValueRange", "{}">:$barriers, + CArg<"::mlir::ValueRange", "{}">:$barrier_preds, + CArg<"bool", "false">:$is_async)> + ]; + + let assemblyFormat = [{ + $a `,` $b `,` $d `` custom($acc_dep, type($token)) `,` $a_scale `,` + $b_scale `,` $useD `,` $pred `lhs` `=` $a_type `rhs` `=` $b_type + `` custom($barriers, $barrier_preds) + attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,` + qualified(type($d)) `,` qualified(type($a_scale)) `,` + qualified(type($b_scale)) (`,` qualified(type($barriers))^)? + }]; + + let hasVerifier = 1; +} + +def TTNG_TCGen5CommitOp : TTNG_Op<"tc_gen5_commit", [AttrSizedOperandSegments]> { + let summary = "make an mbarrier track completion of all prior async tcgen5 ops"; + + let description = [{ + The `ttng.tc_gen5_commit` is an asynchronous operation that makes the + mbarrier object track the completion of all prior asynchronous tcgen5 + operations. Upon completion of all asynchronous operations, the mbarrier + arrive operation is performed on the mbarrier with a count of 1. + + If `descs` are provided, the commit will be multicast across the CTA cluster + based on the shared layouts of those descriptors. This should be used when + the inputs to the tcgen5 MMA come from TMA descriptors using multicast. + + Note that the completion mechanisms are guaranteed to occur sequentially in + the order the commit operations were issued. This means, for example: + + ```mlir + ttng.tmem_copy + ttng.tc_gen5_mma + ttng.tc_gen5_commit %barrierA + ttng.tc_gen5_commit %barrierB + ``` + + `%barrierA` tracks the completion of the previous TMEM copy and MMA + operations, but since the commit groups are sequential, the arrive-on + operation on `%barrierA` is guaranteed to be performed before the arrive-on + operation on `%barrierB`, even though its commit group is empty. + }]; + + let arguments = (ins + Arg]>:$barrier, + Optional:$pred, + Variadic:$descs + ); + + let assemblyFormat = [{ + $barrier (`,` $pred^)? (`descs` $descs^)? attr-dict `:` + qualified(type($barrier)) (`,` qualified(type($descs))^)? + }]; + + let hasVerifier = 1; +} + +def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load", [AttrSizedResultSegments]> { + let summary = "Load a buffer from tensor memory into a distributed tensor"; + + let description = [{ + This is similar to ttg.local_load except the result layout is restricted to only few possibility. + Therefore we cannot combine this op with any convert layout like local_load. + + This operation takes and produces an optional token to indicate TMEM read + on its source operand. When the tokens are present, they can + be used to check aliasing and modref on the TMEM buffer. + + Optional reduction modifier: + When `redOp` is specified, the load operation additionally performs an + element-wise reduction along the N-dimension of the input and produces a + second result tensor `red`. For a input of shape `[M, N]`, the + reduced result has shape `[M]`, containing one reduced value per "slice" + of the N-dimension. + + Currently restricted to f32 element type. + + - redOp: Specifies the reduction operation (MIN or MAX) to apply along + the N-dimension. When set, the `red` result must be present. + - abs: When true, applies absolute value to each element before performing + the reduction. Only valid when `redOp` is specified. + - NaN: When true, the reduction propagates NaN values (if any input element + in a slice is NaN, the corresponding reduced value is NaN). + When false, NaN values are ignored during reduction. + Only valid when `redOp` is specified. + + Example: + Input in TMEM of shape[M=2, N=4]: + [[ 1.0, 3.0, 2.0, 4.0], + [-5.0, 1.0, 8.0, 2.0]] + + With redOp=MAX: + result = [[ 1.0, 3.0, 2.0, 4.0], // unchanged + [-5.0, 1.0, 8.0, 2.0]] + red = [4.0, 8.0] // max along N per row + + With redOp=MIN, abs=true: + red = [1.0, 1.0] // min of |values| per row + + This operation lowers to hardware-accelerated reduction via the PTX + tcgen05.ld.red instruction on supported architectures, e.g. Blackwell Ultra. + }]; + let arguments = (ins + Arg]>:$src, + Optional:$dep, + OptionalAttr:$redOp, + OptionalAttr:$abs, + OptionalAttr:$NaN + ); + let results = (outs + TT_Tensor:$result, + Optional:$token, + Optional:$red + ); + + let assemblyFormat = [{ + $src `` custom($dep, type($token)) + attr-dict `:` qualified(type($src)) `->` type($result) (`,` type($red)^)? + }]; + + let builders = [ + // Basic builder: result type, optional token type, src, optional dep + OpBuilder<(ins "Type":$result, "Type":$token, "Value":$src, "Value":$dep), [{ + build($_builder, $_state, result, token, /*red=*/Type(), src, dep, + /*redOp=*/nullptr, /*abs=*/nullptr, /*NaN=*/nullptr); + }]>, + // Builder without token + OpBuilder<(ins "Type":$result, "Value":$src), [{ + build($_builder, $_state, result, /*token=*/Type(), /*red=*/Type(), src, + /*dep=*/Value(), /*redOp=*/nullptr, /*abs=*/nullptr, /*NaN=*/nullptr); + }]>, + // Builder with reduction - infers red type from result type + OpBuilder<(ins "Type":$result, "Type":$token, "Value":$src, "Value":$dep, + "::mlir::triton::nvidia_gpu::TMEMLoadReduceModifierAttr":$redOp, + "BoolAttr":$abs, "BoolAttr":$NaN), [{ + Type redTy; + if (redOp) { + auto tensorTy = ::mlir::cast(result); + SmallVector redShape = {tensorTy.getShape()[0]}; + auto parentEnc = ::mlir::cast<::mlir::triton::gpu::DistributedEncodingTrait>( + tensorTy.getEncoding()); + auto sliceEnc = ::mlir::triton::gpu::SliceEncodingAttr::get( + $_builder.getContext(), 1, parentEnc); + redTy = RankedTensorType::get(redShape, tensorTy.getElementType(), sliceEnc); + } + build($_builder, $_state, result, token, redTy, src, dep, redOp, abs, NaN); + }]>, + ]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + RankedTensorType getType() { return getResult().getType(); } + operator TypedValue() { return getResult(); } + }]; +} + +def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> { + let summary = "Store a distributed tensor into a buffer in tensor memory"; + + let description = [{ + This is similar to ttg.local_store except the source layout is restricted to only few possibility. + + This operation takes and produces an optional token to indicate TMEM write + on its source operand. When the tokens are present, they can + be used to check aliasing and modref on the TMEM buffer. + }]; + let arguments = (ins + Arg]>:$dst, + Optional:$dep, + TT_Tensor:$src, + I1:$pred + ); + let results = (outs Optional:$token); + + let builders = [ + OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$pred), [{ + build($_builder, $_state, Type(), dst, Value(), src, pred); + }]> + ]; + + let assemblyFormat = [{ + $src `,` $dst `` custom($dep, type($token)) `,` $pred + attr-dict `:` type($src) `->` qualified(type($dst)) + }]; + let hasVerifier = 1; +} + +def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor memory"; + let description = [{ + This operation allocates buffer in tensor memory and return a descriptor + containing the address and a view of the buffer. + This is similar to ttg.local_alloc except the buffer is allocated in tensor memory. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = (ins Optional:$src); + let results = (outs + TTG_MemDescType:$result, + Optional:$token + ); + + let assemblyFormat = [{ + ($src^)? attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + triton::gpu::MemDescType getType() { return getResult().getType(); } + operator TypedValue() { return getResult(); } + }]; +} + +def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> { + let summary = "Take a subslice of a tensor memory allocation"; + let description = [{ + This operation takes a subslice of a tensor memory allocation and returns a new descriptor + containing the address and a view of the subslice. + This is similar to ttg.memdesc_subslice except we can only slice along the inner dimension + of a 2D memdesc as this is the only one we can do for TMem. + }]; + let arguments = (ins TTG_MemDescType:$src, I32Attr:$N); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` qualified(type($result)) + }]; + + let builders = [ + OpBuilder<(ins "Value":$alloc, "int":$offset, "int":$size)>, + ]; + let results = (outs TTG_MemDescType:$result); + let hasVerifier = 1; +} + +def TTNG_TMEMCopyOp : TTNG_Op<"tmem_copy"> { + let summary = "Initiate an asynchronous copy operation from shared memory to the Tensor Memory."; + + let description = [{ + 2D blocks stored contiguously in SMEM are copied into TMEM as specified by the destination address. + The completion of the copy can be observed by waiting on the optional barrier. If this op is used + together with an MMA op, one barrier can be used to wait for both copy and MMA. We do not need to wait + for the completion of the copy before MMA, since tcgen05.cp followed by tcgen05.mma is guaranteed to + execute in that order. + + This op lowers to the PTX instruction tcgen05.cp. This supports writing either to scales tmem layout as well as default tmem layout. + Currently the semantic is different when writing to tmem scale layout. + + In case of default layout the copy doesn't change the logical elements between the source and destination memdesc. + + In case of scale layout: + Each 32x128b block in SMEM is duplicated over 4 warps and stored into 128 rows + and 4 columns of TMEM. The primary use case of this op is to copy blocked scales from SMEM to TMEM. + + The shape of the input SMEM can be flexibily chosen depending on use cases. In the simplest case (e.g. unit test), + the source SMEM can be of shape (32 x num_blocks, 16), and the destination TMEM should be of shape (128, 16 x num_blocks), + for copying 8 bit values. For scaled GEMM, rep_m x rep_k copies of a 32x128b block need to be stored in SMEM, where + rep_m = BLOCK_M / 128, rep_k = BLOCK_K / scale_vec_size / 4, and scale_vec_size = 32 for MXFP. + Conceptually, the SMEM is organized in a high-dimensional layout, (rep_m, rep_k, 32, 4, 4B). + Some of axes can be flattened into one, to reduce the rank of the load. For example, the following patterns are supported: + * (rep_m, rep_k * 32 x 4 x 4B), 2D scale load with cp.async + * (rep_m, rep_k, 32, 16B), 4D scale load with TMA + * (rep_m, rep_k, 32, 4, 4B), 5D scale load with cp.async + Since rep_m blocks are not contiguous in SMEM, this axis cannot be flattened into inner ones. + + In Triton, the TMEM memdesc for blocked scales must be of the following form: + * Its shape must be (BLOCK_MN, BLOCK_K / scale_vec_size), representing the logical shape of blocked scales. + * It must be attached with `tensor_memory_scales_encoding` to indicate the chunk-based layout and its duplication over 4 warps. + + In contrast, the src SMEM must be in the explicit chunk-based layout as described above. So the IR might look like this: + + %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory> + ttng.tmem_copy %1, %0 : (!ttg.memdesc<1x1x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>) -> () + + We interpret the semantics of this copy operation as follows. The chunk-based layout in SMEM implies that + the logical shape (BLOCK_MN, BLOCK_K / scale_vec_size) in TMEM is the result of certain reshape and transpose operations. + In practice, to take an advantage of the native scale layout and the TMEM copy op, users need to do + `scales5D.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // scale_vec_size)` before feeding scales into dot_scaled. + When we use tmem_copy in the IR, such reshape and transpose operations are removed. But the change in the logical shape they have caused on + registers is now understood to be incorporated into tmem_copy itself. Ideally, we would lift reshape / transpose done on registers onto + the SMEM memdesc, making tmem_copy a straightforward 2D copy operation: (BLOCK_MN, BLOCK_K / scale_vec_size) -> (BLOCK_MN, BLOCK_K / scale_vec_size). + In the absence of such operations on memdesc, we resort to implicitly encoding the reshape/transpose semantics in tmem_copy. + + }]; + let arguments = (ins + Arg]>:$src, + Arg]>:$dst, + Optional:$barrier + ); + + let assemblyFormat = [{$src `,` $dst (`,` $barrier^)? attr-dict `:` qualified(type(operands))}]; + let hasVerifier = 1; +} + +def TTNG_ReinterpretTensorDescOp : TTNG_Op<"reinterpret_tensor_descriptor", [Pure]> { + let summary = "Reinterpret a pointer as a tensor descriptor"; + + let description = [{ + This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects. + Ideally, we can remove this once the APIs are fully fleshed out. + }]; + + let arguments = (ins TT_Ptr:$rawDesc); + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = [{ + $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result)) + }]; +} + +def TTNG_TensormapCreateOp: TTNG_Op< + "tensormap_create", + [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments, + ] +> { + let summary = "Create a new TMA descriptor on device"; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_PtrType:$global_address, + Variadic:$box_dim, + Variadic:$global_dim, + Variadic:$global_stride, + Variadic:$element_stride, + ConfinedAttr]>:$elem_type, + ConfinedAttr]>:$interleave_layout, + ConfinedAttr]>:$swizzle_mode, + ConfinedAttr]>:$fill_mode + ); + let extraClassDeclaration = [{ + int32_t getRank() { + return getBoxDim().size(); + } + }]; + let assemblyFormat = [{ + $desc_ptr `,` $global_address `,` + `[` $box_dim `]` `,` + `[` $global_dim `]` `,` + `[` $global_stride `]` `,` + `[` $element_stride `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def TTNG_TensormapFenceproxyAcquireOp: TTNG_Op< + "tensormap_fenceproxy_acquire", + [MemoryEffects<[MemWrite]>] +> { + let summary = "Acquire fence on a tensormap object"; + let arguments = (ins TT_PtrType:$desc_ptr); + let assemblyFormat = [{ + $desc_ptr attr-dict `:` qualified(type($desc_ptr)) + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td new file mode 100644 index 0000000000..edf0f27908 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td @@ -0,0 +1,90 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_TYPES +#define TRITONNVIDIAGPU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + +//===----------------------------------------------------------------------===// +// TritonNvidiaGPU Type Definitions +//===----------------------------------------------------------------------===// + +class TTNG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +//===----------------------------------------------------------------------===// +// TensorDescIm2ColType +//===----------------------------------------------------------------------===// + +def TTNG_TensorDescIm2ColType : TTNG_TypeDef<"TensorDescIm2Col", "tensordesc_im2col", + [TT_TensorDescInterface]> { + let summary = "Im2col tensor descriptor type for NVIDIA TMA operations"; + + let description = [{ + Tensor descriptor type for im2col (image-to-column) tensor memory access. + This is used for convolution-friendly access patterns with TMA on NVIDIA GPUs. + + Im2col mode transforms a multi-dimensional tensor into a 2D matrix format + suitable for matrix multiplication, which is commonly used in convolution + operations. + + Parameters: + - blockType: The shape and element type of the data block being accessed + + This type implements TensorDescInterface, sharing common operations with + the tiled TensorDescType in the base Triton dialect. + + See NVIDIA PTX documentation for im2col tensor mode: + https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode + }]; + + let parameters = (ins + "RankedTensorType":$blockType + ); + + let assemblyFormat = [{ + `<` $blockType `>` + }]; + + let builders = [ + // Builder with signedness for integer types + TypeBuilder<(ins + "RankedTensorType":$blockType, + "bool":$isSigned + ), [{ + if (auto intTy = llvm::dyn_cast(blockType.getElementType())) { + auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned; + auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem); + blockType = blockType.clone(elemTy); + } + return Base::get($_ctxt, blockType); + }]> + ]; + + let genVerifyDecl = 1; +} + +#endif // TRITONNVIDIAGPU_TYPES diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..d4b5c097f4 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU) +add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h new file mode 100644 index 0000000000..b11a3f653e --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +std::unique_ptr createTritonNvidiaGPUPlanCTAPass(); + +#define GEN_PASS_DECL +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td new file mode 100644 index 0000000000..a41b2e8914 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -0,0 +1,187 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_PASSES +#define TRITONNVIDIAGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> { + let summary = "plan CTA"; + + let description = [{ + This pass computes and applies "optimized" CTA tilings to DotOp, ReduceOp + and StoreLikeOps operations. + }]; + + let constructor = "mlir::triton::nvidia_gpu::createTritonNvidiaGPUPlanCTAPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> { + let summary = "Insert fences across generic and async proxy."; + + let description = [{ + This pass is to insert memory fences to ensure that memory operations are + properly ordered across generic and async operations. + This pass inserts fences at optimized location. + There is a pass later to handle all the functional requirements + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonGPUProxyFenceInsertion : Pass<"triton-nvidia-gpu-proxy-fence-insertion", "mlir::ModuleOp"> { + let summary = "Insert fences across generic and async proxy"; + + let description = [{ + This pass is to insert memory fences to ensure that memory operations are + properly ordered across generic and async operations. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::ModuleOp"> { + let summary = "lower to TMA load/store operations"; + + let description = [{ + Lower Triton descriptor load to TMA load/store operations in TritonNvidiaGPUDialect. + }]; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> { + let summary = "Assign tensor memory allocation"; + + let description = [{ + Decide on tensor memory allocation and assign attributes to each allocation. + }]; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonNvidiaGPUMMALoweringPass : Pass<"triton-nvidia-mma-lowering", "mlir::ModuleOp"> { + let summary = "lower mma operations if needed"; + + let description = [{ + Lower MMA ops to prepare for conversion to LLVM. + }]; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonNvidiaGPUPromoteLHSToTMemPass : Pass<"tritongpu-promote-lhs-to-tmem", "mlir::ModuleOp"> { + let summary = "Promote LHS operand of MMAv5 op to Tensor Memory"; + + let description = [{ + Promote LHS operand of MMAv5 op to Tensor Memory. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonNvidiaGPUOptimizeDescriptorEncodingPass : Pass<"triton-nvidia-optimize-descriptor-encoding", "mlir::ModuleOp"> { + let summary = "Set encodings on tensor descriptor types"; + + let description = [{ + Set shared memory encoding on tensor descriptors, which decides the swizzling mode and message size of the tma descriptor. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonNvidiaGPUOptimizeTMemLayoutsPass : Pass<"triton-nvidia-optimize-tmem-layouts", "mlir::ModuleOp"> { + let summary = "Optimize TMEM layouts."; + + let description = [{ + Optimize TMEM layouts by selecting a layouts to enable better subtiling, + reduction performance, etc. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonNvidiaGPUInterleaveTMemPass : Pass<"triton-nvidia-interleave-tmem", "mlir::ModuleOp"> { + let summary = "Interleave TMEM loads/stores."; + + let description = [{ + The `triton-nvidia-interleave-tmem` pass attempts to sink TMEM loads and + hoist TMEM stores, and potentially interleave them, to reduce register + pressure. + }]; +} + +def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-tokens", "mlir::ModuleOp"> { + let summary = "remove TMEM tokens"; + + let description = [{ + The `triton-nvidia-gpu-remove-tmem-tokens` pass removes TMEM memory + dependency tokens from the IR, after they are no longer needed. + }]; +} + +def TritonNvidiaGPUCheckMatmulTwoCTAPass : Pass<"triton-nvidia-check-matmul-two-cta", "mlir::ModuleOp"> { + let summary = "Verify consistent two_ctas usage across matmuls"; + + let description = [{ + Inspect all matmul operations and ensure they agree on the `two_ctas` + setting. Propagate the chosen value to the module so later lowering steps + can access it. Compilation fails if mixed configurations are detected. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h new file mode 100644 index 0000000000..7697be2746 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -0,0 +1,57 @@ +#pragma once +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "llvm/Support/Casting.h" + +namespace mlir::triton::nvidia_gpu { + +constexpr inline int TMA_SIZE_BYTES = 128; +constexpr inline int TMA_ALIGN = 128; + +inline bool isFp4Padded(Attribute encoding) { + auto mmaEnc = dyn_cast(encoding); + return mmaEnc && mmaEnc.getFp4Padded(); +} + +gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout, + ArrayRef shape); + +gpu::SharedEncodingTrait +updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding, + RankedTensorType tensorType); + +triton::gpu::SharedEncodingTrait +getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType, + Value desc); + +inline SmallVector getTMABlockShape(Attribute encoding, + ArrayRef shapePerCTA, + bool packedSize) { + auto mmaEnc = cast(encoding); + return triton::gpu::getTMABlockShape( + shapePerCTA, mmaEnc.getElementBitWidth(), mmaEnc.getSwizzlingByteWidth(), + mmaEnc.getFp4Padded(), mmaEnc.getTransposed(), packedSize); +} + +inline SmallVector getTMABlockShape(RankedTensorType ty, + bool packedSize) { + auto shapePerCTA = gpu::getShapePerCTA(ty); + return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize); +} + +inline SmallVector getTMABlockShape(triton::gpu::MemDescType ty, + bool packedSize) { + auto shapePerCTA = gpu::getShapePerCTA(ty); + return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize); +} + +FailureOr getTMASwizzleMode(Location loc, TensorDescType ty); +FailureOr getTMAElementType(Location loc, TensorDescType ty); + +LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op, + OpBuilder &builder); + +} // namespace mlir::triton::nvidia_gpu diff --git a/third_party/mthreads/include/triton/Target/CMakeLists.txt b/third_party/mthreads/include/triton/Target/CMakeLists.txt new file mode 100644 index 0000000000..39d31dc9b5 --- /dev/null +++ b/third_party/mthreads/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/mthreads/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/mthreads/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000000..1f6c1b3511 --- /dev/null +++ b/third_party/mthreads/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/mthreads/include/triton/Target/LLVMIR/Passes.h b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 0000000000..87da907e14 --- /dev/null +++ b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,18 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Target/LLVMIR/Passes.h.inc" + +// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/mthreads/include/triton/Target/LLVMIR/Passes.td b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 0000000000..854d753342 --- /dev/null +++ b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,21 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; +} + +def LLVMDILocalVariable: Pass<"extract-variable-info", "mlir::ModuleOp"> { + let summary = "Pull out source variable info from Location to DILocalVariable"; + let description = [{ + This pass pulled out source vararible's debuginfo from LLVM IR dialect's Location + into LLVM's DILocalVariable and fused it into previous Location so it can be passed to LLVM IR later in debugging mode. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Tools/GenericSwizzling.h b/third_party/mthreads/include/triton/Tools/GenericSwizzling.h new file mode 100644 index 0000000000..e1b3b3e2cc --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/GenericSwizzling.h @@ -0,0 +1,56 @@ +#ifndef TRITON_GENERIC_SWIZZLING_H +#define TRITON_GENERIC_SWIZZLING_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include +#include + +namespace mlir::triton { +class LinearLayout; +class TargetInfoBase; +} // namespace mlir::triton + +namespace mlir::triton::gpu { +// Store the lane indices that are used in the contiguous part +// of an operation and in the address part. +// The laneAddr part just represents the indices used in one wavefront +// For now we just represent tiles with full vectorisation, meaning +// ld.shared.b32.v4/st.shared.b32.v4 +// ldmatrix.v4 / stmatrix.v4 +// ldmatrix.trans.v4 / stmatrix.trans.v4 +struct LocalMemOpTile { + // If laneContig.size() < log2(128/bitwidth), we assume that + // the first log2(128/bitwidth) - laneContig.size() bases are registers + llvm::SmallVector laneContig; + // If laneAddr.size() < 3, we assume that the first + // 3 - laneAddr.size() bases are registers + llvm::SmallVector laneAddr; +}; + +// Given a set of possible instructions given by +// targetInfo.laneIdTiles(bitwidth) returns the optimal swizzling given these +// instructions and a pair of indices into the ldStTiles that's needed to lower +// this swizzling +std::pair> +optimalSwizzling(const LinearLayout &src, const LinearLayout &dst, + llvm::ArrayRef srcTiles, + llvm::ArrayRef dstTiles, int32_t bitwidth); + +LinearLayout optimalSwizzlingLdSt(const LinearLayout &src, + const LinearLayout &dst, int32_t bitwidth); + +std::pair bankConflictsLdSt(const LinearLayout &src, + const LinearLayout &dst, + const LinearLayout &smem, + int32_t bitwidth); + +int bankConflictsMemDesc(const LinearLayout ®, const LinearLayout &smem, + int32_t bitwidth); + +std::pair bankConflicts(llvm::ArrayRef tileSrc, + llvm::ArrayRef tileDst, + const LinearLayout &smem); +} // namespace mlir::triton::gpu + +#endif // TRITON_GENERIC_SWIZZLING_H diff --git a/third_party/mthreads/include/triton/Tools/LayoutUtils.h b/third_party/mthreads/include/triton/Tools/LayoutUtils.h new file mode 100644 index 0000000000..7ea612fb02 --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/LayoutUtils.h @@ -0,0 +1,190 @@ +#ifndef TRITON_TOOLS_LAYOUTUTILS_H +#define TRITON_TOOLS_LAYOUTUTILS_H + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton { +// Is the sublayout defined from dimNames to dimNames the identity? +// In particular, is the input and output size in these dimensions +// the same, and are the bases the identity? +bool squareSublayoutIsIdentity(const LinearLayout &ll, + ArrayRef dimNames); + +// For each output dimension d, ensure that the layout's output size (i.e., its +// codomain) does not exceed shape[d]. Do this without changing the size of the +// layout's inputs (i.e., leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +// +// We achieve this by setting the largest value in each output dimension d to 0 +// because bases that map to a location larger than shape[d] +// effectively duplicate along that dimension. For example, consider a layout +// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to +// shrink the output dimension size to 8: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16 +// +// In the first step, we shrink the output dimension size to 16 by setting +// L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// This means that lane=2 has the same data as lane=0. +// +// Now the output dimension of this layout has a size of 16, which is still +// larger than 8. We find the current largest value in the output dimension, +// which is L(register=1) = 8, and we set L(register=1) to 0: +// +// L(register=1) = 0 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// Now the output dimension of this layout has a size of 8, which is the desired +// size. Note that this method works only because the bases are powers of two, +// which is the case for DistributedLayouts If broadcastRegisters is false, we +// remove any register that's larger than the desired shape. In the example +// above we would have +// L(register=1) = 4 +// L(register=2) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters = true); + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape); + +inline LinearLayout +ensureLayoutNotSmallerThan(const LinearLayout &layout, + const llvm::ArrayRef dimNames, + const llvm::ArrayRef shape) { + llvm::SmallDenseMap namedDims; + for (auto [dimName, length] : llvm::zip_equal(dimNames, shape)) + namedDims[dimName] = length; + assert(namedDims.size() == shape.size() && "duplicate dimension names given"); + return ensureLayoutNotSmallerThan(layout, namedDims); +} + +// Return a vector of the standard out dimension names for tensor layouts. These +// are "dim0", "dim1", etc. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank); + +// Return a vector of the standard out dimension name/value pairs, i.e. +// ("dim0", dstShape[0]), ("dim1", dstShape[1]), etc. +SmallVector> +standardOutDimPairs(MLIRContext *ctx, ArrayRef dstShape); + +// Return an identity mapping from `inDimName` to the standard out dimensions, +// with the dimensions sized according to the shape. The bases are sorted +// according to `order`, with the most minor dimension first. +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order); + +// Return a layout with the same in/out dimensions as `layout` but with all +// bases set to 0. +LinearLayout zerosLike(const LinearLayout &layout); + +// For a layout A with A.hasInDim(kReg), find a permutation of registers action +// such that action.apply(A) may be divisible by B +// It's not always true that the action returned by this function will +// allow us to divideLeft (resp. divideRight), but it is true that if it if +// there exists one, it is the one returned by this function. +std::optional regPermForDivide(const LinearLayout &A, + const LinearLayout &B, bool left); + +// For a layout A with A.hasInDim(kReg), find a permutation of registers action +// such that action.apply(A) has the broadcasted registers removed +ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout); + +std::pair +actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout, + uint64_t maskSpanOffsets); + +// For a layout A with A.hasInDim(kReg), repeat the values so that they have +// the same broadcasting as layout +SmallVector broadcastAs(const SmallVector &values, + const LinearLayout &layout); + +// Compute the supremum of two lists. +// Error out if the supremum does not exist (e.g. [a, b] and [b, a]). +// If the supremum is not unique, we return the first list first +// (e.g. [a, b], [a, c] -> [a, b, c]). +SmallVector supremum(const SmallVector &x, + const SmallVector &y); + +// Return a new layout reshaped to the given shape. +LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout, + ArrayRef shape); + +// Return a new layout with the dimensions transposed according to the given +// order. +LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef order); + +// Given a distributed into shmem layout, return the largest vectorisation +// that can be used to lower the layout via ld/st. +std::pair +largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth, + std::optional maybeMaxVecElems = std::nullopt); + +// Close cousin of doing zerosLike(tile) * divideLeft(cvt, tile) +// This one is a tad more general in the sense that it allows to divide +// cvt: +// - register=1 -> (0, 1) +// register=2 -> (8, 0) +// register=4 -> (0, 8) +// register=8 -> (0, 16) +// register=16 -> (0, 32) +// register=32 -> (0, 64) +// register=64 -> (16, 0) +// - lane=1 -> (0, 2) +// lane=2 -> (0, 4) +// lane=4 -> (1, 0) +// lane=8 -> (2, 0) +// lane=16 -> (4, 0) +// - warp=1 -> (32, 0) +// warp=2 -> (64, 0) +// - block is a size 1 dimension +// where out dims are: [row (size 128), col (size 128)] +// tile: +// - register=1 -> (0, 1) +// register=2 -> (8, 0) +// - lane=1 -> (0, 2) +// lane=2 -> (0, 4) +// lane=4 -> (1, 0) +// lane=8 -> (2, 0) +// lane=16 -> (4, 0) +// - warp=1 -> (32, 0) +// warp=2 -> (64, 0) +// where out dims are: [row (size 128), col (size 8)] +// which would not be possible to lower via the divideLeft approach as we +// cannot divide by the tile given the `register=64 -> (16, 0)` basis. +std::optional getReps(const LinearLayout &cvt, + const LinearLayout &tile); + +// Given a layout mapping onto dim0..dimn, remove a dimension `dim` +// and rename the rest as dim0..dimn-1 +LinearLayout removeStandardDim(const LinearLayout &layout, int dim); +} // namespace mlir::triton + +#endif // TRITON_TOOLS_LAYOUTUTILS_H diff --git a/third_party/mthreads/include/triton/Tools/LinearLayout.h b/third_party/mthreads/include/triton/Tools/LinearLayout.h new file mode 100644 index 0000000000..5c1788816e --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/LinearLayout.h @@ -0,0 +1,904 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. +// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout (i.e., L) is the function that, given a "hardware location" tuple of +// (thread-id, warp-id), returns an index (x,y) into T. In other words, if +// L(t,w) = (x,y) is our linear layout func, then a register in thread t in warp +// w contains the value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. +// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs and similar operations, where the layout is logically +// flattened according to the dimension order and then chopped up again. +// +// ## Surjectivity and injectivity +// +// Most LLs are surjective, i.e. all output values are covered by some input +// value. But occasionally you might create a non-surjective layout, usually +// via invertAndCompose. We aggressively assert that LLs are surjective unless +// you explicitly create one that's not. +// +// LLs are not, in general, injective. There might exist multiple input values +// that map to the same output value. This represents the idea that the same +// logical tensor elements can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... + aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comparison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) +// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composable. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. (Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::MapVector outDims; + int32_t rank = 0; + +public: + using BasesT = decltype(bases); + + LinearLayout() = default; + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return {}; } + + // Creates a 1D -> 1D layout that's the function L(x) = stride * x + // for x in [0, size). + static LinearLayout strided1D(int32_t size, int32_t stride, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim) { + return strided1D(size, /*stride=*/1, inDim, outDim); + } + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). By default this creates a surjective layout where + // `outDim` has size 1 (the only element is 0). If `outDimSize` is specified + // to be greater than 1, then this creates a non-surjective layout with a + // specific size for `outDim`. + static LinearLayout zeros1D(int32_t size, StringAttr inDim, StringAttr outDim, + int32_t outDimSize = 1); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + // + // Calculates the out-dim sizes according to the bases. Consider the + // following example. + // + // L(in1=1) = (out1=1, out2=0) + // L(in1=2) = (out1=5, out2=1) + // L(in1=4) = (out1=2, out2=2) + // + // To calculate the out-dim sizes, we first find the largest values for out1 + // and out2, namely 5 and 2, then round these up to the next power of 2, + // namely 8 and 4. These are the out-dim sizes. + // + // Assert-fails if the layout is not surjective given these out-dim sizes. + // That is, every possible out-dim in range [0, size) must be produced by + // xor'ing some combination of bases. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Creates a LinearLayout given a list of bases and the explicit out-dimension + // sizes. Allows the layout to be non-surjective. + // + // To see why we need to explicitly pass out-dim sizes when creating a + // non-surjective layout, consider the following example. + // + // L(in1=1) = 1 + // L(in1=2) = 4 + // + // If we naively infer the out-dim sizes from these bases, we'd infer a size + // of nextPow2(4) = 8. But given that the layout is non-surjective, who is to + // say that the codomain is not (say) [0,32)? We can't tell, thus we need to + // be explicit about the sizes. + explicit LinearLayout(BasesT bases, + ArrayRef> outDims, + bool requireSurjective); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) + // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + // + // The overload that infers out-dim sizes assert-fails if the layout is not + // surjective. + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef> outDims, bool requireSurjective); + + bool isSurjective() const { return rank == getTotalOutDimSizeLog2(); } + bool isInjective() const { return rank == getTotalInDimSizeLog2(); } + + bool isInvertible() const { + return isSurjective() && getTotalInDimSize() == getTotalOutDimSize(); + } + + // Remove a dimension of size 1 from the layout. + [[nodiscard]] LinearLayout unsqueezeIn(StringAttr dim) const; + [[nodiscard]] LinearLayout unsqueezeOut(StringAttr dim) const; + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). + ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos >= 0); + assert(static_cast(pos) < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + auto getOutDimNames() const { return llvm::make_first_range(outDims); } + auto getOutDimSizes() const { return llvm::make_second_range(outDims); } + + // Relevant for reshaping + + SmallVector> getInDims() const { + SmallVector> inDims; + inDims.reserve(bases.size()); + for (auto [inDim, inDimBases] : bases) { + inDims.push_back({inDim, getInDimSize(inDim)}); + } + return inDims; + } + SmallVector> getOutDims() const { + return to_vector(outDims); + } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { return outDims.contains(outDim); } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDims.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + int32_t getTotalInDimSizeLog2() const; + int32_t getTotalInDimSize() const { return 1 << getTotalInDimSizeLog2(); } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s) (if the layout is surjective). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + int32_t getTotalOutDimSizeLog2() const; + int32_t getTotalOutDimSize() const { return 1 << getTotalOutDimSizeLog2(); } + + // Finds the number of consecutive input elements in the first input dimension + // that map to consecutive output elements in the first output dimension. + // + // Mathematically, finds the maximum value V such that for any a, b, c, and + // for all v in [0,V), + // + // L(a*V + v, b, c, ...) = L(a*V, b, c, ...) + (v, 0, ..., 0) + // + // Note that's +, not ⊕, in the RHS. (Equivalently, we could use binary-or + // instead of +. In other words, we require that L(a*V, b, c, ...) have no + // bits that overlap with v.) + // + // For example, if L maps (register, lane) to (dim1, dim0), then this tells + // you how many consecutive registers map to consecutive elements of dim1. + // + // This only works across the first (i.e. the most-minor) dimension of in/out. + // If you want it to work across more dimensions, flatten the layout. + // + // TODO(jlebar): Replace with divideLeft. + int32_t getNumConsecutiveInOut() const; + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + [[nodiscard]] LinearLayout reshapeIns( + ArrayRef> newInDims) + const; + + // Reshapes to a single input dim (named whatever our first in-dim is named). + [[nodiscard]] LinearLayout flattenIns() const { + if (getNumInDims() == 0) { + return reshapeIns({}); + } + return reshapeIns({{*getInDimNames().begin(), getTotalInDimSize()}}); + } + + [[nodiscard]] LinearLayout + reshapeOuts(ArrayRef> + newOutDims) const; + + // Reshapes to a single out dim (named whatever our first out-dim is named). + [[nodiscard]] LinearLayout flattenOuts() const { + if (getNumOutDims() == 0) { + return reshapeOuts({}); + } + return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}}); + } + + // Resizes the dimension to one that is smallre or equal to the given size. + // These operations are similar to `sublayout` but at a dimension level. + [[nodiscard]] LinearLayout resizeInDim(StringAttr inDim, + int32_t newSize) const; + [[nodiscard]] LinearLayout resizeOutDim(StringAttr outDim, + int32_t newSize) const; + + [[nodiscard]] LinearLayout renameInDim(StringAttr oldDim, + StringAttr newDim) const { + auto bases = getBases(); + auto it = bases.find(oldDim); + assert(it != bases.end()); + auto value = std::move(it->second); + bases.erase(it); + bases.insert({newDim, std::move(value)}); + return LinearLayout(std::move(bases), getOutDims(), + /*requireSurjective=*/isSurjective()); + } + + // Concatenates two layouts by their in (resp. out) dimensions. The layouts + // must have the same output (resp. input) dimensions and sizes and different + // input (resp. output) dimensions. The input dimensions of this layout are + // placed before those of 'other'. This can be thought of as the opposite of + // `sublayout`, which slices a layout from a larger one. + [[nodiscard]] LinearLayout concatIns(const LinearLayout &other) const; + [[nodiscard]] LinearLayout concatOuts(const LinearLayout &other) const; + + // Remove all the bases that equal to 0 for the given input dimension. + [[nodiscard]] LinearLayout unsqueezeIns(StringAttr dim) const; + + // Computes the direct sum of two layouts. + // https://en.wikipedia.org/wiki/Direct_sum#Direct_sum_of_matrices + // + // Roughly speaking, the first layout acts on the first part of the input + // dimensions, and the second layout acts on the second part. + // In other words, it's the generalisation of concatenation of the inputs + // to linear maps. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. + // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // The output matrix is [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // The output matrix is [[1, 0, 0], [0, 1, 0]] + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // The output matrix is [[0, 1, 0], [0, 0, 1]] + + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // The output dims are ("o1", "o2") in that order. + // + // If the input (or output) dims of the layouts are not the same, we take + // the supremum of the two ordered lists with the inclusion, respecting the + // order. If multiple suprema exist, we bias towards the first list. + // e.g. sup([a, b], [a, c]) = [a, b, c], sup([a, b], [b, c]) = [a, b, c] + // sup([a, b], [b, a]) = error! Supremum does not exist. + // + // Notice that this operation is not commutative, but it is associative. + // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + // + // Postcondition: If both inner and outer are surjective, the result is + // surjective. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Compute a C such that A = B * C if it exists. + // In other words, C = B^{-1} * A. + // For divideRight, we compute A = C * B, that is, C = A * B^{-1}. + // Note that such a C exists iff (every pair of input/output dim of) A is + // of the form + // [[B, 0], + // [0, C]] + // as a matrix, whenever those dimensions are present in B. + // + // C will always have the same input/output dimensions as A. + // When there are dimensions of size 1 there is some ambiguity in the + // division, as in `operator*` we treat missing dimensions as dimensions + // of size 1 whenever it makes sense to do so. The rule that C has the + // same dimensions as A ensures that C is well-defined. + friend std::optional divideLeft(const LinearLayout &A, + const LinearLayout &B); + friend std::optional divideRight(const LinearLayout &A, + const LinearLayout &B); + + // Returns true if this layout acts trivially (as the identity) on the given + // dimensions. This means that it's the identity on those dimensions, and it + // does not map other dimensions onto those or these onto other dimensions. + bool isTrivialOver(ArrayRef dimNames) const; + + // For an endomorphism on dimNames (linear map that maps dimNames to dimNames) + // checks whether it is the identity map on these dimensions (i.e + // LinearLayouts::isTrivialOver) and if so, returns the sublayout of the + // remaining dimensions. + // nb. The isTrivialOver condition is more restrictive than the usual + // "leaves the subspace invariant" condition in maths. + // We can always relax it if we know how to take advantage of a conversion + // layout being block-diagonal in the future. + std::optional quotient(ArrayRef dimNames) const; + + // Gets a layout with only these in/out dimensions. + // + // In other words, gets a layout where the in-dims not mentioned in inDimNames + // are set to 0, and the out-dims not mentioned in outDimNames are omitted. + // + // The output-dim sizes are unchanged. The order of the in/out dims in the + // returned layout matches the order of the original layout, not the order of + // the arguments. + LinearLayout sublayout(ArrayRef inDimNames, + ArrayRef outDimNames) const; + + // Is the sublayout restricted to inDimNames + outDimNames all zeros? + bool sublayoutIsZero(ArrayRef inDimNames, + ArrayRef outDimNames) const; + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: + // - The output dimensions of this layout equal the input dimensions of + // outer (order doesn't matter). + // - For each output dim d of this layout, this->getOutDimSize(d) <= + // outer.getInDimSize(d). + // + // Postcondition: The result is surjective iff `this` and `outer` are + // surjective and this->getOutDimSize(d) == outer.getInDimSize(d) for each of + // this->getOutDimNames(). + // + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // Inverts or pseudo-inverts `outer` and composes it with `this`. + // + // Formally, if C = A.invertAndCompose(B), then for all x, C(x) = y implies + // A(x) = B(y), or in other words A(x) = B(C(x)). If B is invertible, then + // C(x) = B^-1(A(x)), which is how this function gets its name. + // + // For example, suppose you have the following two LLs. + // + // - R is an LL representing registers, mapping (lane, warp) to a 2D index. + // - S is an LL representing shared memory, mapping offset to a 2D index. + // + // Suppose you want to store tensor values from registers into shared memory. + // That is, given a (lane, warp), you want to know the corresponding shared + // memory offset to store into. + // + // This is equivalent to converting a (lane, warp) into a 2D index (i.e. + // applying R), then converting a 2D index into a shmem offset (i.e. applying + // the inverse of S). R.invertAndCompose(S) computes this transformation. + // + // Notice the following requirements in order for this to work. + // + // - R and S must have the same output dimension names (different order is + // allowed). + // - S must be surjective, i.e. there must be some offset for each output + // dimension of S. This way when we compose S^-1 with R, every possible + // 2D index that we might get from R has some shmem offset. + // - The codomain of S must be at least as large as the codomain of R. + // Otherwise, R could map some tensor index that is not stored in S. + // + // One requirement we *don't* have is that S is injective; we allow two shmem + // offsets to hold the same 2D index. If S is not injective, + // the algorithm chooses the smallest offset for a given (lane, warp). + [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const; + + // Get the layout that is the inverse of this layout. + [[nodiscard]] LinearLayout invert() const; + // Compute and return a psueodinverse of this layout. This is a layout such + // that `B = A.psuedoinvert()` implies that `A(B(x)) = I`. If `A` is + // invertible, then this returns `A^-1`. + [[nodiscard]] LinearLayout pseudoinvert() const; + + // For each in-dim, returns a bitmask of the "free variables" in the layout + // function. + // + // These are the bits in the input that can be changed without changing the + // output. If all of the free variables are 0, then the layout is injective + // (i.e. every input bit affects the output). + llvm::MapVector getFreeVariableMasks() const; + + // Take the current linear layout and remove all zero bases for the provided + // dimension and return the resulting layout. This is useful for deriving a + // layout that returns just the unique output values when varying a given + // input dimension that has broadcasting. + [[nodiscard]] LinearLayout removeZeroBasesAlongDim(StringAttr stripDim) const; + + std::string toString() const; + + friend bool operator==(const LinearLayout &lhs, const LinearLayout &rhs); + friend bool operator!=(const LinearLayout &lhs, const LinearLayout &rhs) { + return !(lhs == rhs); + } + bool equalIgnoringOutDimSizes(const LinearLayout &other) const; + friend size_t hash_value(const LinearLayout &layout); + +private: + // Factory function that gracefully fails rather than asserts if the layout is + // not well-formed. + static std::optional + tryCreate(BasesT bases, ArrayRef> outDims, + bool requireSurjective); + + // Constructor that does not check invariants. Used by tryCreate. + struct NoCheckInvariants {}; + LinearLayout(BasesT bases, ArrayRef> outDims, + NoCheckInvariants); + + [[nodiscard]] std::optional + checkInvariants(bool requireSurjective); +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +// Defines a map acting on the columns (i.e. bases) a given input dimension of a +// layout as per: +// action[i] -> i. +// This action can be: +// - Applied to a layout to get a new layout with the same input dimensions +// but with the bases permuted (and perhaps some of them dropped). +// - Applied to a range of Values to apply the same transformation to them +// +// E.g. if action = [2, 0, 1] and basesDim = [1, 2, 4] +// - action.apply(layout) returns a LL with basesDim = [4, 1, 2] +// - action.apply(range) with range.size() == 8, returns a range permuted as +// [x[0], x[4], x[1], x[5], x[2], x[6], x[3], x[7]] +class ColumnAction { +private: + SmallVector action; + StringAttr inDim; + size_t inSizeLog2; + bool m_isIdentity = true; + +public: + ColumnAction() = default; + ColumnAction(ArrayRef action, StringAttr inDim, size_t inSizeLog2) + : action(action), inDim(inDim), inSizeLog2(inSizeLog2) { + auto it = llvm::max_element(action); + // Assert in the constructor... ugh + assert(it == action.end() || *it < inSizeLog2); + // In many cases the action will be the identity, so we save that as an + // early return + m_isIdentity = action.size() == inSizeLog2 && + llvm::equal(action, llvm::seq(action.size())); + } + + // Act on the columns of a layout + // Examples: + // - if action = [2, 0, 1] and layout.getBases()[inDim] = [[1], [2], [4]] + // - action.apply(layout) returns a LL with basesDim = [[4], [1], [2]] + // - if action = [2, 0] and layout.getBases()[inDim] = [[1], [4], [2]] + // - action.apply(layout) returns a LL with bases[inDim] = [[2], [1]] + LinearLayout apply(const LinearLayout &layout) const; + + // Act on a range of values (representing registers) + // e.g. if action = [2, 0, 1] and inSizeLog2 = 3 and inDim.str() = "register" + // - action.apply(range) with range.size() == 8, returns + // [x[0], x[4], x[1], x[5], x[2], x[6], x[3], x[7]] + SmallVector apply(ValueRange values) const; + + // Inverse of the action + ColumnAction inverse() const; + + // Given two permutations self, other seen as functions, returns + // ret(x) = other(self(x)) + ColumnAction leftCompose(const ColumnAction &other) const; + + static ColumnAction identity(StringAttr inDim, size_t inSizeLog2) { + return ColumnAction(llvm::to_vector(llvm::seq(inSizeLog2)), inDim, + inSizeLog2); + } + + // Returns true if the action is the identity + bool isIdentity() const { return m_isIdentity; } + + std::string toString() const; +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const ColumnAction &action) { + os << action.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const ColumnAction &action) { + os << action.toString(); + return os; +} + +std::unique_ptr getMatrix(const LinearLayout &layout); + +} // namespace mlir::triton + +#endif // TRITON_TOOLS_LINEARLAYOUT_H diff --git a/third_party/mthreads/include/triton/Tools/PluginUtils.h b/third_party/mthreads/include/triton/Tools/PluginUtils.h new file mode 100644 index 0000000000..5878af01bb --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/PluginUtils.h @@ -0,0 +1,100 @@ +#ifndef TRITON_PLUGIN_UTILS_H +#define TRITON_PLUGIN_UTILS_H + +#include "mlir/Pass/PassManager.h" +#include "mlir/Tools/Plugins/DialectPlugin.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Error.h" +#include + +extern "C" { +enum TritonPluginResult { + TP_SUCCESS = 0, + TP_GENERIC_FAILURE = 1, +}; +}; +#define TRITON_PLUGIN_API \ + extern "C" __attribute__((visibility("default"))) TritonPluginResult +#define TRITON_PLUGIN_API_TYPE(_TYPE) \ + extern "C" __attribute__((visibility("default"))) _TYPE + +struct TritonPlugin { + TritonPlugin() = delete; + TritonPlugin(std::string filename) : filename(filename) {} + +public: + llvm::Error checkLibraryValid(const std::string &error) const; + static constexpr char ENUMERATE_PASSES[] = "tritonEnumeratePluginPasses"; + static constexpr char ENUMERATE_DIALECTS[] = "tritonEnumeratePluginDialects"; + static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo"; + static constexpr char ADD_PASS[] = "tritonAddPluginPass"; + static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass"; + +private: + using EnumeratePyBindHandlesType = + std::function; + using EnumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, + const char **); + + using AddPassType = + std::function; + using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, + const char *); + + using RegisterPassType = std::function; + using RegisterPassCType = TritonPluginResult (*)(const char *); + + using DialectPluginInfoType = + std::function<::mlir::DialectPluginLibraryInfo(const char *)>; + using DialectPluginInfoCType = + ::mlir::DialectPluginLibraryInfo (*)(const char *); + + llvm::Expected getAddressOfSymbol(const std::string &symbol) const; + + template + llvm::Expected getAPI(const std::string &symbol) const { + llvm::Expected getDetailsFn = getAddressOfSymbol(symbol); + if (auto Err = getDetailsFn.takeError()) { + return Err; + } + auto func = reinterpret_cast(*getDetailsFn); + return func; + } + + llvm::Expected checkAPIResult(TritonPluginResult result, + const char *handle) const; + llvm::Expected + enumeratePyBindHandles(EnumeratePyBindHandlesType &enumeratePyBindHandles, + std::vector &passNames); + +public: + std::runtime_error err2exp(llvm::Error Err); + + llvm::Error loadPlugin(); + + llvm::Expected + getPassHandles(std::vector &handles); + + llvm::Expected + getDialectHandles(std::vector &handles); + + llvm::Expected addPass(mlir::PassManager *pm, + const char *passHandle); + + llvm::Expected registerPass(const char *passHandle); + + llvm::Expected<::mlir::DialectPluginLibraryInfo> + getDialectPluginInfo(const char *dialectName); + +private: + std::string filename = ""; + mutable llvm::sys::DynamicLibrary library; + EnumeratePyBindHandlesType enumeratePassesAPI; + EnumeratePyBindHandlesType enumerateDialectsAPI; + AddPassType addPassAPI; + RegisterPassType registerPassAPI; + DialectPluginInfoType dialectPluginInfoAPI; + bool isLoaded = false; +}; + +#endif // TRITON_PLUGIN_UTILS_H diff --git a/third_party/mthreads/include/triton/Tools/StrUtil.h b/third_party/mthreads/include/triton/Tools/StrUtil.h new file mode 100644 index 0000000000..8b59f7d2b3 --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. +// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. +// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/include/triton/Tools/Sys/GetEnv.hpp b/third_party/mthreads/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 0000000000..7ccbfee03a --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,124 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "AMDGCN_USE_BUFFER_ATOMICS", + "AMDGCN_USE_BUFFER_OPS", + "DISABLE_LLVM_OPT", + "DISABLE_SQMMA", + "DISABLE_WMMA", + "DISABLE_MMA_V3", + "DISABLE_MMA_V5", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "LLVM_PASS_PLUGIN_PATH", + "LLVM_EXTRACT_DI_LOCAL_VARIABLES", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_DUMP_PATH", + "MLIR_ENABLE_TIMING", + "MLIR_DISABLE_MULTITHREADING", + "TRITON_DEFAULT_FP_FUSION", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DUMP_MIR", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_HIP_USE_ASYNC_COPY", + "TRITON_HIP_USE_BLOCK_PINGPONG", + "TRITON_HIP_USE_IN_THREAD_TRANSPOSE", + "TRITON_LLVM_DEBUG_ONLY", + "TRITON_ENABLE_ASAN", + "TRITON_OVERRIDE_ARCH", + "TRITON_MUSA_TOOLCHAIN_PATH", + "TRITON_MUSA_LLC_PATH", + "TRITON_MUSA_LLD_PATH", + "TRITON_MUSA_LLC_ASM_PATH", + "TRITON_MUSA_LLC_OPTIONS", + "TRITON_MUSA_ENABLE_LLC_OPT", + "TRITON_MUSA_ENABLE_FP8_BURST2", + "TRITON_MUSA_ENABLE_LLVM_COMPAT", + "TRITON_MUSA_DUMP_LLIR", + "TRITON_MUSA_DUMP_MUASM", + "TRITON_MUSA_REPLACE_LLIR", + "TRITON_MUSA_REPLACE_MUBIN", + "TRITON_MUSA_LIBDEVICE_PATH", + "USE_IR_LOC", + "NVPTX_ENABLE_DUMP", + "ALLOW_LHS_TMEM_LAYOUT_CONVERSION", + "TRITON_F32_DEFAULT", + "TRITON_PREFER_TMEM_16x256_LAYOUT", + "TRITON_ENABLE_EXPERIMENTAL_CONSAN", + "TRITON_PASS_PLUGIN_PATH", + "TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT", + "TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY", + "TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY", + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + // clang-format off + "TRITON_REPRODUCER_PATH", + "TRITON_ENABLE_PYTHON_STACKTRACE", + // clang-format on +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +static std::mutex getenv_mutex; + +inline std::string getStrEnv(const std::string &env) { + std::lock_guard lock(getenv_mutex); + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + std::lock_guard lock(getenv_mutex); + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/language/musa/__init__.py b/third_party/mthreads/language/musa/__init__.py new file mode 100644 index 0000000000..988a2710c6 --- /dev/null +++ b/third_party/mthreads/language/musa/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice, utils + +__all__ = ["libdevice", "utils"] diff --git a/third_party/mthreads/language/musa/libdevice.py b/third_party/mthreads/language/musa/libdevice.py new file mode 100644 index 0000000000..a10e528ba6 --- /dev/null +++ b/third_party/mthreads/language/musa/libdevice.py @@ -0,0 +1,1742 @@ +"""MUSA libdevice mappings. + +This module provides Triton externs bound to MUSA libdevice/intrinsics, +mirroring the CUDA libdevice API surface with MUSA-specific symbols. +""" + +from enum import Enum +from triton.language import core + + +class RoundingMode(Enum): + rn = 0 # rte + rz = 1 # rtz + rd = 2 # rtn + ru = 3 # rtp + reserve0 = 4 + reserve1 = 5 + + +@core.extern +def clz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def popc(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def byte_perm(arg0, arg1, arg2, _semantic=None): + raise NotImplementedError + + +@core.extern +def mulhi(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_mulhi", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_umulhi", core.dtype("uint32")), + ( + core.dtype("int64"), + core.dtype("int64"), + ): ("__mt_mul64hi", core.dtype("int64")), + ( + core.dtype("uint64"), + core.dtype("uint64"), + ): ("__mt_umul64hi", core.dtype("uint64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def mul24(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_mul24", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_umul24", core.dtype("uint32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def brev(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def sad(arg0, arg1, arg2, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + arg2, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + core.dtype("uint32"), + ): ("__mt_sad", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_usad", core.dtype("uint32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def abs(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_abs_i32", core.dtype("int32")), + (core.dtype("fp32"), ): ("__mt_fabs_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_fabs_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def floor(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_floor_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rcp64h(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def rsqrt(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("llvm.musa.rsq", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ceil(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def trunc(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__mt_trunc_f64", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__mt_trunc_f32", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def exp2(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_exp2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_exp2_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def saturatef(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fma_rn(arg0, arg1, arg2, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + arg2, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fmaf_rn_f32", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _semantic=None): + raise NotImplementedError + + +@core.extern +def fma_rd(arg0, arg1, arg2, _semantic=None): + raise NotImplementedError + + +@core.extern +def fma_ru(arg0, arg1, arg2, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_dividef(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fast_fdivide_f32", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def div_rn(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_div_rte_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rte_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def div_rz(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_div_rtz_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rtz_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def div_rd(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_div_rtn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rtn_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def div_ru(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rtp_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rcp_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def rcp_rz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def rcp_rd(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def rcp_ru(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def sqrt_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def sqrt_rz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def sqrt_rd(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def sqrt_ru(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def sqrt(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def add_rn(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rte_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def add_rz(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rtz_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def add_rd(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rtn_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def add_ru(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rtp_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def mul_rn(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rte_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def mul_rz(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rtz_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def mul_rd(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rtn_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def mul_ru(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rtp_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2float_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2float_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2float_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2float_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2int_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2int_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2int_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2int_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2uint_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2uint_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2uint_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2uint_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def int2double_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_i32_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def uint2double_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint32"), ): ("__mt_ui32_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2int_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def float2int_rz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def float2int_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_i32_rd", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2int_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_i32_ru", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2uint_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def float2uint_rz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def float2uint_rd(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def float2uint_ru(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def int2float_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def int2float_rz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def int2float_rd(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def int2float_ru(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def uint2float_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def uint2float_rz(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def uint2float_rd(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def uint2float_ru(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def hiloint2double(arg0, arg1, _semantic=None): + raise NotImplementedError + + +@core.extern +def double2loint(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2hiint(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def float2ll_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ll_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ll_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ll_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ull_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_rn", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ull_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_rz", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ull_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_rd", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float2ull_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_ru", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ll_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ll_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ll_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ll_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ull_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ull_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ull_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double2ull_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2float_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def ll2float_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_ll_to_f32_rz", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2float_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_ll_to_f32_rd", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2float_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_ll_to_f32_ru", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2float_rn(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def ull2float_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint64"), ): ("__mt_ull_to_f32_rz", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2float_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint64"), ): ("__mt_ull_to_f32_rd", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2float_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint64"), ): ("__mt_ull_to_f32_ru", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2double_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rn], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2double_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rz], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2double_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rd], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ll2double_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.ru], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2double_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rn], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2double_rz(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rz], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2double_rd(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rd], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ull2double_ru(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.ru], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def int_as_float(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_int_as_float", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float_as_int(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_float_as_int", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def uint_as_float(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint32"), ): ("__mt_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def float_as_uint(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_float_as_uint", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def longlong_as_double(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def double_as_longlong(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp64"), ): ("__mt_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fast_sinf(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_cosf(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_log2f(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_logf(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_expf(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_tanf(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_exp10f(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_log10f(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def fast_powf(arg0, arg1, _semantic=None): + raise NotImplementedError + + +@core.extern +def hadd(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_hadd", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_uhadd", core.dtype("uint32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rhadd(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_rhadd", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_urhadd", core.dtype("uint32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sub_rn(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fsub_rn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rte_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sub_rz(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rtz_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sub_rd(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rtn_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sub_ru(arg0, arg1, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rtp_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rsqrt_rn(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_rsqrt_rn_f32", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ffs(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_ffs_i32", core.dtype("int32")), + (core.dtype("int64"), ): ("__mt_ffsll_i64", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rint(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def llrint(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def nearbyint(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def isnan(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__mt_isnan_f64", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def signbit(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_signbit_f32", core.dtype("int1")), + (core.dtype("fp64"), ): ("__mt_signbit_f64", core.dtype("int1")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def copysign(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def finitef(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_isfinite_f32", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def isinf(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__mt_isinf_f64", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def nextafter(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_nextafter_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sin(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def cos(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sinpi(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_sinpi_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_sinpi_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def cospi(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp64"), ): ("__mt_cospi_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def tan(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_tan_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def log2(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def exp(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def exp10(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp64"), ): ("__mt_exp10_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def cosh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def sinh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def tanh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_tanh_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def atan2(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def atan(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def asin(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def acos(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def log(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_log_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def log10(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_log10_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def log1p(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def acosh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def asinh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def atanh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def expm1(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def hypot(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rhypot(arg0, arg1, _semantic=None): + raise NotImplementedError + + +@core.extern +def norm3d(arg0, arg1, arg2, _semantic=None): + raise NotImplementedError + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _semantic=None): + raise NotImplementedError + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _semantic=None): + raise NotImplementedError + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _semantic=None): + raise NotImplementedError + + +@core.extern +def cbrt(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_cbrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_cbrt_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def rcbrt(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def j0(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def j1(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def y0(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def y1(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def yn(arg0, arg1, _semantic=None): + raise NotImplementedError + + +@core.extern +def jn(arg0, arg1, _semantic=None): + raise NotImplementedError + + +@core.extern +def cyl_bessel_i0(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def cyl_bessel_i1(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def erf(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erf_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def erfinv(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def erfc(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def erfcx(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def erfcinv(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_erfcinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erfcinv_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def normcdfinv(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def normcdf(arg0, _semantic=None): + raise NotImplementedError + + +@core.extern +def lgamma(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ldexp(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__mt_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__mt_ldexp_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def scalbn(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("int32"), + ): ("__mt_scalbn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("int32"), + ): ("__mt_scalbn_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fmod(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fmod_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def remainder(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_remainder_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_remainder_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fma(arg0, arg1, arg2, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def pow(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("int32"), + ): ("__mt_pown_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("int32"), + ): ("__mt_pown_f64", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_pow_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_pow_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def tgamma(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_tgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_tgamma_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def round(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_round_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_round_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def llround(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_llround_f32", core.dtype("int64")), + (core.dtype("fp64"), ): ("__mt_llround_f64", core.dtype("int64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fdim(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def ilogb(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def logb(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_logb_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_logb_f64", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def isfinited(arg0, _semantic=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp64"), ): ("__mt_isfinite_f64", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fast_gelu(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_tt_gelu_f32", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def fast_tanh(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_tt_tanh_f32", core.dtype("fp32")), + }, is_pure=True, _semantic=_semantic) diff --git a/third_party/mthreads/language/musa/utils.py b/third_party/mthreads/language/musa/utils.py new file mode 100644 index 0000000000..d685467955 --- /dev/null +++ b/third_party/mthreads/language/musa/utils.py @@ -0,0 +1,86 @@ +"""MUSA language utilities. + +Minimal helpers that are backend-agnostic but commonly used by kernels. +""" + +from triton.language import core + +_FP32_QNAN_BITS = 0x7FC00000 + + +@core.builtin +def num_threads(_semantic=None): + return core.constexpr(_semantic.builder.options.num_warps * 32) + + +@core.builtin +def num_warps(_semantic=None): + return core.constexpr(_semantic.builder.options.num_warps) + + +@core.builtin +def _upcast_fp8_to_fp32(arg, exponent_bits, mantissa_bits, exponent_bias, nan_on_negzero, _semantic=None): + raw = arg.to(core.uint8, bitcast=True, _semantic=_semantic) + raw32 = raw.to(core.uint32, _semantic=_semantic) + + mantissa_mask = (1 << mantissa_bits) - 1 + exponent_mask = (1 << exponent_bits) - 1 + sign_shift = exponent_bits + mantissa_bits + + mantissa = raw32.__and__(mantissa_mask, _semantic=_semantic) + exponent = raw32.__rshift__(mantissa_bits, _semantic=_semantic) + exponent = exponent.__and__(exponent_mask, _semantic=_semantic) + sign = raw32.__rshift__(sign_shift, _semantic=_semantic) + + value_bits = sign.__lshift__(31, _semantic=_semantic) + value_bits = value_bits.__or__(exponent.__lshift__(23, _semantic=_semantic), _semantic=_semantic) + value_bits = value_bits.__or__(mantissa.__lshift__(23 - mantissa_bits, _semantic=_semantic), _semantic=_semantic) + value = value_bits.to(core.float32, bitcast=True, _semantic=_semantic) + value = core.mul(value, 2.0**(127 - exponent_bias), _semantic=_semantic) + + if nan_on_negzero: + nan_bits = core.full(raw.shape, _FP32_QNAN_BITS, core.uint32, _semantic=_semantic) + qnan = nan_bits.to(core.float32, bitcast=True, _semantic=_semantic) + is_negzero = raw.__eq__(0x80, _semantic=_semantic) + value = core.where(is_negzero, qnan, value, _semantic=_semantic) + + return value + + +def _upcast_fp32_to_dst(upcast, dst_ty, _semantic): + dst_scalar = dst_ty.scalar + if dst_scalar.is_fp32(): + return upcast + return upcast.to(dst_scalar, _semantic=_semantic) + + +@core.builtin +def convert_custom_float8(arg, dst_ty, fp_downcast_rounding=None, _semantic=None): + src_ty = arg.type.scalar + dst_scalar = dst_ty.scalar + + if dst_scalar.is_fp8e4b15() or dst_scalar.is_fp8e4b8() or dst_scalar.is_fp8e5b16(): + raise ValueError(f"conversion to {dst_scalar} is not supported in this architecture") + + if src_ty.is_fp8e4b15(): + if not (dst_scalar.is_fp16() or dst_scalar.is_fp32()): + raise ValueError(f"conversion from {src_ty} to {dst_scalar} is not supported in this architecture") + upcast = _upcast_fp8_to_fp32(arg, exponent_bits=4, mantissa_bits=3, exponent_bias=15, nan_on_negzero=False, + _semantic=_semantic) + return _upcast_fp32_to_dst(upcast, dst_ty, _semantic) + + if src_ty.is_fp8e4b8(): + if not (dst_scalar.is_fp16() or dst_scalar.is_bf16() or dst_scalar.is_fp32()): + raise ValueError(f"conversion from {src_ty} to {dst_scalar} is not supported in this architecture") + upcast = _upcast_fp8_to_fp32(arg, exponent_bits=4, mantissa_bits=3, exponent_bias=8, nan_on_negzero=True, + _semantic=_semantic) + return _upcast_fp32_to_dst(upcast, dst_ty, _semantic) + + if src_ty.is_fp8e5b16(): + if not (dst_scalar.is_fp16() or dst_scalar.is_fp32()): + raise ValueError(f"conversion from {src_ty} to {dst_scalar} is not supported in this architecture") + upcast = _upcast_fp8_to_fp32(arg, exponent_bits=5, mantissa_bits=2, exponent_bias=16, nan_on_negzero=True, + _semantic=_semantic) + return _upcast_fp32_to_dst(upcast, dst_ty, _semantic) + + raise ValueError(f"unsupported custom fp8 conversion from {src_ty} to {dst_scalar}") diff --git a/third_party/mthreads/lib/Analysis/Alias.cpp b/third_party/mthreads/lib/Analysis/Alias.cpp new file mode 100644 index 0000000000..e2d0c90499 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Alias.cpp @@ -0,0 +1,72 @@ +#include "triton/Analysis/Alias.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + if (lhs == rhs) + return lhs; + AliasInfo ret; + for (auto value : lhs.allocs) { + ret.insert(value); + } + for (auto value : rhs.allocs) { + ret.insert(value); + } + return ret; +} + +LogicalResult SharedMemoryAliasAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + AliasInfo aliasInfo; + bool pessimistic = true; + auto result = op->getResult(0); + // skip ops that return memdesc in a different memory space. + if (auto memdescTy = dyn_cast(result.getType())) { + if (!isa_and_nonnull( + memdescTy.getMemorySpace())) + return success(); + } + + // Only LocalAllocOp creates a new buffer. + if (isa(op)) { + aliasInfo.insert(result); + pessimistic = false; + } else if (op->hasTrait()) { + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; + } else if (isa(op)) { + aliasInfo = AliasInfo(); + pessimistic = false; + } else { + assert(!isa(result.getType()) && + "unknown operation creating memory descriptor"); + } + + if (pessimistic) { + setAllToEntryStates(results); + return success(); + } + // Join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(aliasInfo)); + + return success(); +} + +AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { + // TODO: implement + return AliasResult::MayAlias; +} + +ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op, + Value location) { + // TODO: implement + return ModRefResult::getModAndRef(); +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Analysis/Allocation.cpp b/third_party/mthreads/lib/Analysis/Allocation.cpp new file mode 100644 index 0000000000..84f13f216e --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Allocation.cpp @@ -0,0 +1,662 @@ +#include "triton/Analysis/Allocation.h" + +#include +#include + +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "allocation-shared-memory" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Shared Memory Allocation Analysis +//===----------------------------------------------------------------------===// +namespace triton { + +unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto *ctx = srcTy.getContext(); + auto srcLayout = gpu::toLinearLayout(srcTy); + auto dstLayout = gpu::toLinearLayout(dstTy); + srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout); + dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout); + auto bitwidth = getBitwidth(srcTy); + auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth); + auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps")); + return smem.getTotalOutDimSize() / reps; +} + +// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values +// because Triton's block-based programming model ensures that +// all threads sharing the same partition of the tensor see the same values, +// even for threads that do not participate in the atomic operation +static SmallVector getRepShapeForAtomic(Value result) { + SmallVector smemShape; + if (!result.use_empty()) { + if (auto tensorTy = dyn_cast(result.getType())) { + auto freeVariableMasks = + gpu::toLinearLayout(tensorTy).getFreeVariableMasks(); + if (llvm::any_of(freeVariableMasks, [](auto variableMask) { + return variableMask.second != 0; + })) { + // The tensor has broadcasted dimensions + smemShape = convertType(gpu::getShapePerCTA(tensorTy)); + } + } else { + // If the result is a scalar, we need to allocate a single element. + smemShape.push_back(1); + } + } + return smemShape; +} + +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { + if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + return helper.getScratchSizeInBytes(); + } + if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + return helper.getScratchSizeInBytes(); + } + if (auto gatherOp = dyn_cast(op)) { + GatherLoweringHelper helper(gatherOp); + return helper.getScratchSizeInBytes(); + } + if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + return std::max(dstTy.getNumElements(), threadsPerWarp) * + getBitwidth(dstTy) / 8; + } + if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + if (!cvtNeedsSharedMemory(srcTy, dstTy)) + return 0; + // The generic pass uses swizzling + auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy); + return elems * getBitwidth(srcTy) / 8; + } + if (isa(op)) { + auto value = op->getOperand(0); + auto smemShape = getRepShapeForAtomic(op->getResult(0)); + auto elems = getNumScratchElements(smemShape); + if (elems == 0) + return 0; + auto elemTy = getElementTypeOrSelf(getPointeeType(value.getType())); + return elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + } + if (isa(op)) { + constexpr int32_t kTMASize = 128; + return kTMASize; + } + return 0; +} + +class AllocationAnalysis { +public: + AllocationAnalysis(Operation *operation, + Allocation::FuncAllocMapT *funcAllocMap, + Allocation *allocation, + AllocationAnalysisScratchSizeFn scratchSizeGetter) + : operation(operation), funcAllocMap(funcAllocMap), + allocation(allocation), scratchSizeGetter(scratchSizeGetter) { + run(); + } + +private: + using BufferT = Allocation::BufferT; + + /// Value -> Liveness Range + /// Use MapVector to ensure determinism. + using BufferRangeMapT = llvm::MapVector>; + /// Nodes -> Nodes + using GraphT = DenseMap>; + + void run() { + getValuesAndSizes(); + resolveLiveness(); + computeOffsets(); + } + + /// Initializes explicitly defined shared memory values for a given operation. + void getExplicitValueSize(Operation *op) { + auto alloc = dyn_cast(op); + if (!alloc || !alloc.isSharedMemoryAlloc()) + return; + auto allocType = alloc.getType(); + int64_t numElems = 0; + if (auto paddedEnc = + dyn_cast(allocType.getEncoding())) { + SmallVector unpaddedShape = gpu::getShapePerCTA(allocType); + numElems = paddedEnc.getPaddedSize(unpaddedShape); + } else { + auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType); + numElems = product(shapePerCTA); + } + int64_t bytes = + numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8; + + auto alignment = alloc.getAlignmentOrDefault(); + allocation->addBuffer(alloc, bytes, + alignment); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes, + unsigned alignment) { + if (bytes > 0) + allocation->addBuffer(op, bytes, alignment); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes) { + if (bytes > 0) + allocation->addBuffer(op, bytes); + } + + /// Initializes temporary shared memory for a given operation. + void getScratchValueSize(Operation *op) { + constexpr size_t scratchAlignment = 128; + if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto funcOp = dyn_cast(callable); + auto *funcAlloc = &(*funcAllocMap)[funcOp]; + auto bytes = funcAlloc->getSharedMemorySize(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + return; + } + if (auto ws = dyn_cast(op)) { + // `ttg.warp_specialize` needs memory to pass its explicit captures. Pack + // the captures like a struct. + auto [captureSize, captureAlign] = ws.getCaptureSizeAlign(); + maybeAddScratchBuffer(op, captureSize, + captureAlign); + return; + } + if (auto func = dyn_cast(op)) { + unsigned numWarpIndices = 0; + // Warp specialization communicates states over shared memory to each + // warp. Add space for an i8 for each warpgroup warp. + func.walk([&](gpu::WarpSpecializeOp op) { + numWarpIndices = std::max(numWarpIndices, op.getTotalPartitionWarps()); + }); + maybeAddScratchBuffer(op, numWarpIndices); + return; + } + unsigned bytes = scratchSizeGetter(op); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { + dataflow::Lattice *latticeElement = + analysis.getLatticeElement(value); + if (latticeElement) { + AliasInfo &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto alloc : info.getAllocs()) { + allocation->addAlias(value, alloc); + } + } + } + } + + /// Extract all shared memory values and their sizes + void getValuesAndSizes() { + // Get the alloc values + operation->walk([&](Operation *op) { + getExplicitValueSize(op); + getScratchValueSize(op); + }); + // Get the alias values + std::unique_ptr solver = createDataFlowSolver(); + SharedMemoryAliasAnalysis *aliasAnalysis = + solver->load(); + if (failed(solver->initializeAndRun(operation))) { + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + getValueAlias(operand, *aliasAnalysis); + } + for (auto value : op->getResults()) { + getValueAlias(value, *aliasAnalysis); + } + }); + } + + /// Computes the liveness range of the allocated value. + /// Each buffer is allocated only once. + void resolveExplicitBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto valueBufferIter : allocation->valueBuffer) { + auto value = valueBufferIter.first; + auto *buffer = valueBufferIter.second; + bufferRange[buffer] = getLiveness(value); + LLVM_DEBUG({ + llvm::dbgs() << "-- buffer " << buffer->id << "; value: "; + value.dump(); + }); + } + } + + /// Extends the liveness range by unionizing the liveness range of the aliased + /// values because each allocated buffer could be an alias of others, if block + /// arguments are involved. + void resolveAliasBufferLiveness( + function_ref(Value value)> getLiveness) { + for (const auto &[value, buffers] : allocation->aliasBuffer) { + auto range = getLiveness(value); + for (auto *buffer : buffers) { + auto minId = range.start(); + auto maxId = range.end(); + if (bufferRange.count(buffer)) { + // Extend the allocated buffer's range + minId = std::min(minId, bufferRange[buffer].start()); + maxId = std::max(maxId, bufferRange[buffer].end()); + } + bufferRange[buffer] = Interval(minId, maxId); + } + } + } + + /// Computes the liveness range of scratched buffers. + /// Some operations may have a temporary buffer that is not explicitly + /// allocated, but is used to store intermediate results. + void resolveScratchBufferLiveness( + const DenseMap &operationId) { + // Analyze liveness of scratch buffers and virtual buffers. + auto processScratchMemory = [&](const auto &container) { + for (auto [op, buffer] : container) { + // Buffers owned by the function are assumed live for the whole + // function. This memory is used for warp specialization codegen. + // FIXME: Spooky-action-at-a-distance. Find a better way to model this. + if (op == operation) { + bufferRange.insert( + {buffer, Interval(size_t(), std::numeric_limits::max())}); + continue; + } + + // Any scratch memory's live range is the current operation's live + // range. + bufferRange.insert( + {buffer, Interval(operationId.at(op), operationId.at(op) + 1)}); + LLVM_DEBUG({ + llvm::dbgs() << "-- buffer " << buffer->id << "; value: "; + op->dump(); + }); + } + }; + processScratchMemory(allocation->opScratch); + processScratchMemory(allocation->opVirtual); + } + + /// Resolves liveness of all values involved under the root operation. + void resolveLiveness() { + // Assign an ID to each operation using post-order traversal. + // To achieve the correct liveness range, the parent operation's ID + // should be greater than each of its child operation's ID . + // Example: + // ... + // %5 = triton.convert_layout %4 + // %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) { + // %2 = triton.convert_layout %5 + // ... + // scf.yield %arg0 + // } + // For example, %5 is defined in the parent region and used in + // the child region, and is not passed as a block argument. + // %6 should should have an ID greater than its child operations, + // otherwise %5 liveness range ends before the child operation's liveness + // range ends. + DenseMap operationId; + operation->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + + // Analyze liveness of explicit buffers + Liveness liveness(operation); + auto getValueLivenessRange = [&](Value value) { + auto liveOperations = liveness.resolveLiveness(value); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + llvm::for_each(liveOperations, [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + + // SQMMA local_alloc buffers are consumed together by dot instructions + // inside the loop body; keep them alive through the enclosing for-op so + // A/B tiles are not assigned overlapping shared-memory regions. + if (auto localAlloc = + dyn_cast_or_null(value.getDefiningOp())) { + bool isSqmma = localAlloc->hasAttr("sqmma.op_idx") || + localAlloc->hasAttr("sqmma.opIdx"); + if (isSqmma) { + if (auto forOp = localAlloc->getParentOfType()) + maxId = std::max(maxId, operationId[forOp.getOperation()] + 1); + } + } + return Interval(minId, maxId); + }; + + resolveExplicitBufferLiveness(getValueLivenessRange); + resolveAliasBufferLiveness(getValueLivenessRange); + resolveScratchBufferLiveness(operationId); + } + + void dumpBuffers() const { + LDBG("Dump bufferRange: id size offset ---------"); + for (auto bufferIter : bufferRange) { + llvm::dbgs() << "-- " << bufferIter.first->id << " " + << bufferIter.first->size << " " << bufferIter.first->offset; + llvm::dbgs() << " interval " << bufferIter.second.start() << " " + << bufferIter.second.end() << "\n"; + } + } + + void dumpAllocationSize() const { + LDBG("Dump shared memory allocation size -----------"); + auto liveBuffers = allocation->getLiveBuffers(); + auto analyzedSize = 0; + for (auto [op, bufferIds] : liveBuffers) { + auto size = 0; + for (auto bufferId : bufferIds) { + auto bufferSize = allocation->getAllocatedSize(bufferId); + size += bufferSize; + } + analyzedSize = std::max(analyzedSize, size); + } + llvm::dbgs() << "Allocated: " << allocation->sharedMemorySize + << ", analyzed: " << analyzedSize << "\n"; + } + + void dumpInterferenceGraph(const GraphT &interference) const { + LDBG("\n"); + LDBG("Dump interference graph: \n"); + for (auto edges : interference) { + llvm::dbgs() << "-- from " << edges.first->id << " to "; + for (auto node : edges.second) { + llvm::dbgs() << node->id << "; "; + } + llvm::dbgs() << "\n"; + } + } + + /// Computes the shared memory offsets for all related values. + /// Paper: Algorithms for Compile-Time Memory Optimization + /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) + void computeOffsets() { + SmallVector buffers; + for (auto bufferIter : bufferRange) { + buffers.emplace_back(bufferIter.first); + } + + // Sort buffers by size in descending order to reduce the fragmentation + // on big buffers caused by smaller buffers. Big buffers have a higher + // chance to overlap with multiple other buffers, and allocating them first + // (by calculateStarts) ensures a higher chance that they will occupy a + // standalone smem slot. + llvm::stable_sort( + buffers, [&](BufferT *A, BufferT *B) { return A->size > B->size; }); + + calculateStarts(buffers); + + // NOTE: The original paper doesn't consider interference between + // the bumped ranges. Buffers that previously do not interfere with + // could interfere after offset bumping if their liveness ranges overlap. + // Therefore, we rerun the interference graph algorithm after bumping so + // that we regroup the buffers and color them again. Since we always + // increase the buffer offset and keep reducing conflicts, we will + // eventually reach a fixed point. + GraphT interference; + buildInterferenceGraph(buffers, interference); + do { + allocate(buffers, interference); + buildInterferenceGraph(buffers, interference); + } while (!interference.empty()); + + LLVM_DEBUG(dumpAllocationSize()); + } + + /// Computes the initial shared memory offsets. + void calculateStarts(const SmallVector &buffers) { + // v = values in shared memory + // t = triplet of (size, start, end) + // shared memory space + // - + // | *******t4 + // | /|\ v2 inserts t4, t5, and t6 + // | | + // | ******t5 ************t6 + // | ^^^^^v2^^^^^^ + // | | *********************t2 + // | \|/ v2 erases t1 + // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 + // |---------------------------------------------| liveness range + // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... + // If the available triple's range is less than a given buffer range, + // we won't know if there has been an overlap without using graph coloring. + // Start -> Liveness Range + using TripleMapT = std::multimap>; + TripleMapT tripleMap; + tripleMap.insert(std::make_pair(0, Interval())); + SmallVector xBuffers = buffers; + while (!xBuffers.empty()) { + auto tripleIt = tripleMap.begin(); + auto offset = tripleIt->first; + auto range = tripleIt->second; + tripleMap.erase(tripleIt); + auto bufferIt = + std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { + auto xRange = bufferRange[buffer]; + bool res = xRange.intersects(range); + for (const auto &val : tripleMap) + res = res && + !val.second.intersects(xRange); // only one buffer intersect + return res; + }); + if (bufferIt != xBuffers.end()) { + auto buffer = *bufferIt; + auto xSize = buffer->size; + auto xRange = bufferRange.lookup(buffer); + // TODO(Keren): A buffer's size shouldn't be determined here, have to + // clean it up + size_t alignOffset = buffer->setOffsetAligned(offset); + tripleMap.insert({alignOffset + xSize, + Interval{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); + // We could either insert (range.start, xRange.start) or (range.start, + // xRange.end), both are correct and determine the potential buffer + // offset, and the graph coloring algorithm will solve the interference, + // if any + if (range.start() < xRange.start()) + tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); + if (xRange.end() < range.end()) + tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); + xBuffers.erase(bufferIt); + } + } + LLVM_DEBUG(dumpBuffers()); + } + + /// Builds a graph of all shared memory values. Edges are created between + /// shared memory values that are overlapping. + void buildInterferenceGraph(const SmallVector &buffers, + GraphT &interference) { + auto isSqmmaBuffer = [](BufferT *buffer) -> bool { + auto alloc = dyn_cast_or_null(buffer->owner); + if (!alloc || !alloc.isSharedMemoryAlloc()) + return false; + bool isSqmma = + alloc->hasAttr("sqmma.op_idx") || alloc->hasAttr("sqmma.opIdx"); + if (!isSqmma) + return false; + return true; + }; + auto getParentForOp = [](BufferT *buffer) -> scf::ForOp { + return buffer->owner ? buffer->owner->getParentOfType() + : scf::ForOp(); + }; + + // Reset interference graph + interference.clear(); + for (auto x : buffers) { + for (auto y : buffers) { + if (x == y) + continue; + auto xStart = x->offset; + auto yStart = y->offset; + auto xSize = x->size; + auto ySize = y->size; + Interval xSizeRange = {xStart, xStart + xSize}; + Interval ySizeRange = {yStart, yStart + ySize}; + auto xOpRange = bufferRange.lookup(x); + auto yOpRange = bufferRange.lookup(y); + + // Buffers interfere if their allocation offsets overlap and they are + // live at the same time. + if (xOpRange.intersects(yOpRange) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + + // Buffers also interfere if their allocation offsets overlap and they + // exist within regions that may execute simultaneously with respect to + // each other. + auto wsx = x->owner->getParentWithTrait(); + auto wsy = y->owner->getParentWithTrait(); + if (wsx && wsy && wsx == wsy && + x->owner->getParentRegion() != y->owner->getParentRegion() && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + + // SQMMA local_alloc buffers in the same loop iteration are consumed + // together by dot operations and must not alias. Keep this protection + // scoped to a single loop; sequential loop/remainder regions can safely + // reuse shared memory through normal liveness. + if (isSqmmaBuffer(x) && isSqmmaBuffer(y) && + getParentForOp(x) == getParentForOp(y) && getParentForOp(x) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + } + } + + LLVM_DEBUG(dumpInterferenceGraph(interference)); + } + + /// Finalizes shared memory offsets considering interference. + void allocate(const SmallVector &buffers, + const GraphT &interference) { + // Reset shared memory size + allocation->sharedMemorySize = 0; + // First-fit graph coloring + // Neighbors are nodes that interfere with each other. + // We color a node by finding the index of the first available + // non-neighboring node or the first neighboring node without any color. + // Nodes with the same color do not interfere with each other. + DenseMap colors; + for (auto value : buffers) { + colors[value] = (value == buffers[0]) ? 0 : -1; + } + SmallVector available(buffers.size()); + for (auto x : buffers) { + std::fill(available.begin(), available.end(), true); + for (auto y : interference.lookup(x)) { + int color = colors[y]; + if (color >= 0) { + available[color] = false; + } + } + auto it = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), it); + LLVM_DEBUG({ + llvm::dbgs() << "-- color " << x->id << " " << colors[x] << "\n"; + }); + } + // Finalize allocation + // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) + // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) + // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) + // TODO(Keren): We are wasting memory here. + // Nodes with color2 can actually start with 24. + for (auto x : buffers) { + size_t newOffset = 0; + for (auto y : interference.lookup(x)) { + newOffset = std::max(newOffset, y->offset + y->size); + } + if (colors.lookup(x) != 0) + x->setOffsetAligned(newOffset); + allocation->sharedMemorySize = + std::max(allocation->sharedMemorySize, x->offset + x->size); + } + LLVM_DEBUG(dumpBuffers()); + } + +private: + Operation *operation; + Allocation::FuncAllocMapT *funcAllocMap; + Allocation *allocation; + BufferRangeMapT bufferRange; + AllocationAnalysisScratchSizeFn scratchSizeGetter; +}; + +} // namespace triton + +void Allocation::run( + FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this, + scratchSizeGetter); +} + +std::map> +Allocation::getLiveBuffers() { + std::map> liveBuffers; + + Operation *rootOperation = getOperation(); + Liveness liveness(rootOperation); + auto analyzeOperation = [&](Operation *op) -> void { + auto scratchBuffer = getBufferId(op); + if (scratchBuffer != InvalidBufferId) + liveBuffers[op].push_back(scratchBuffer); + for (auto result : op->getOpResults()) { + auto bufferId = getBufferId(result); + if (bufferId == Allocation::InvalidBufferId) + continue; + auto liveOperations = liveness.resolveLiveness(result); + for (auto depOp : liveOperations) + liveBuffers[depOp].push_back(bufferId); + } + }; + rootOperation->walk(analyzeOperation); + return liveBuffers; +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Analysis/AxisInfo.cpp b/third_party/mthreads/lib/Analysis/AxisInfo.cpp new file mode 100644 index 0000000000..6f9437cf67 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/AxisInfo.cpp @@ -0,0 +1,1422 @@ +#include "triton/Analysis/AxisInfo.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include + +#define DEBUG_TYPE "axis-info" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { +namespace { + +constexpr int64_t kMaxDivisor = highestPowOf2Divisor(0); + +template int64_t gcd(int64_t a, int64_t b, Args... args) { + if (a == 0) + return b; + if (b == 0) + return a; + if constexpr (sizeof...(args) == 0) + return std::gcd(a, b); + else + return gcd(std::gcd(a, b), args...); +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + if (lhs > kMaxDivisor / rhs) + return kMaxDivisor; + return lhs * rhs; +} + +int64_t getDivisibilityFromContiguity(const AxisInfo &lhs, const AxisInfo &rhs, + int d) { + // For example if we have the following two arrays using the selectOp: + // lhs: [[0, 1], [4, 5]] + // rhs: [[16, 17, 18, 19]] + // The resulting contiguity will be 2, while the divisibility will be 2 + // because 18 is not divisible by 4. + if (lhs.getContiguity(d) == rhs.getContiguity(d) || + lhs.getContiguity(d) == kMaxDivisor || + rhs.getContiguity(d) == kMaxDivisor) { + // Contiguity not changed or one of them is unresolved. + // If unresolved, we can first perform a loose bound gcd since the unknown + // contiguity will be resolved in the end. + return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)); + } else { + // Contiguity changed, we cannot use only divisibility. + return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d), + lhs.getContiguity(d), rhs.getContiguity(d)); + } +} + +// Base class for all operations +template class AxisInfoVisitorImpl : public AxisInfoVisitor { +public: + using AxisInfoVisitor::AxisInfoVisitor; + + AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(isa(op.getType()) || + rank == 1 && "Expected ranked tensor or scalar"); + assert(operands.size() == 2 && "Expected two operands"); + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + if (constantValue.has_value()) { + auto resTy = dyn_cast(op.getType()); + AxisInfo::DimVectorT constancy = + resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + AxisInfo::DimVectorT contiguity(rank, 1); + AxisInfo::DimVectorT divisibility( + rank, highestPowOf2Divisor(constantValue.value())); + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (auto d = 0; d < rank; ++d) { + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); + } + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +protected: + virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) { + return {}; + } +}; + +class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join( + AxisInfo::getPessimisticValueState(lattice->getAnchor()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + AxisInfoAnalysis(DataFlowSolver &solver, + axisinfo::CallbackType callback = nullptr); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class UnrealizedConversionCastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl< + mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(mlir::UnrealizedConversionCastOp op, + ArrayRef *> operands) override { + auto tensorType = dyn_cast(op.getResultTypes()[0]); + if (tensorType && + tensorType.getRank() != operands[0]->getValue().getRank()) { + // Do not propagate AxisInfo with incorrect rank. This can cause a crash + // in future visitor applications. + return AxisInfo::getPessimisticValueState(op->getResult(0)); + } + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfo(/*contiguity=*/{end - start}, + /*divisibility=*/{highestPowOf2Divisor(start)}, + /*constancy=*/{1}); + } +}; + +class ConstantOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(arith::ConstantOp op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{highestPowOf2Divisor(value)}, + /*constancy=*/{1}, + /*knownConstantValue=*/{value}); + } + // TODO: generalize to dense attr + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + auto shapedTy = dyn_cast(splatAttr.getType()); + if (!shapedTy || !shapedTy.hasRank()) + return AxisInfo(); + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(shapedTy.getRank(), 1), + /*divisibility=*/ + AxisInfo::DimVectorT(shapedTy.getRank(), highestPowOf2Divisor(value)), + /*constancy=*/ + AxisInfo::DimVectorT(shapedTy.getShape().begin(), + shapedTy.getShape().end()), + /*knownConstantValue=*/{value}); + } + return AxisInfo(); + } +}; + +class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(ub::PoisonOp op, + ArrayRef *> operands) override { + unsigned rank = 1; + if (auto shape = dyn_cast(op.getType())) + rank = shape.getRank(); + + // Poison values are never accessed, thus assume optimistic values. + return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor), + AxisInfo::DimVectorT(rank, kMaxDivisor), + AxisInfo::DimVectorT(rank, kMaxDivisor)); + } +}; + +template +class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (isa(op)) { + // Case 1: If contiguity(lhs) > 1 and contiguity(rhs) > 1, + // x_t - y_t = (base_x + t) - (base_y + t) = base_x - base_y for any + // 0 <= t < min(contig_x, contig_y), so contiguity is 1. + // Case 2: If contiguity(lhs) > 1 and contiguity(rhs) == 1, + // x_t - y = (base_x + t) - base_y = base_x - base_y + t for any + // 0 <= t < contig_x, + // the contiguity depends on the constancy of rhs. + // Case 3: If contiguity(lhs) == 1 and contiguity(rhs) > 1, + // x - y_t = base_x - (base_y + t) = base_x - base_y - t for any + // 0 <= t < contig_y. The result is decreasing within the contiguous + // block, so contiguity is 1. + // Case 4: If contiguity(lhs) == 1 and contiguity(rhs) == 1, + // x - y = base_x - base_y, so contiguity is 1. + return gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)); + } + // For AddIOp and AddPtrOp + // Case 1: If contiguity(lhs) > 1 and contiguity(rhs) > 1, + // x_t + y_t = (base_x + t) + (base_y + t) = base_x + base_y + 2t for any + // 0 <= t < min(contig_x, contig_y), + // so contiguity is 1. + // Case 2: If contiguity(lhs) > 1 and contiguity(rhs) == 1, + // x_t + y = (base_x + t) + base_y = base_x + base_y + t for any + // 0 <= t < contig_x, so contiguity depends on constancy of rhs. + // Case 3: If contiguity(lhs) == 1 and contiguity(rhs) > 1, + // It's symmetric to case B. + // Case 4: If contiguity(lhs) == 1 and contiguity(rhs) == 1, + // It's trivial that contiguity is 1 + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), + gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + int64_t elemSize = 1; + auto lhsDivisibility = lhs.getDivisibility(dim); + auto rhsDivisibility = rhs.getDivisibility(dim); + if constexpr (std::is_same_v) { + // %ptr = addptr %lhs, %rhs + // is equivalent to + // %0 = mul %rhs, %elemSize + // %ptr = add %lhs, %0 + // The result will still be contiguous in terms of elements but not bytes + // For example: + // addptr [16] : !ptr, [0, 1, 2, 3] : i32 -> !ptr + // returns: + // [16, 20, 24, 28] : !ptr + // with element locations: + // [4, 5, 6, 7] + // It is "strided contiguous" with a divisibility of 16 bytes + elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize); + } + if (lhs.getContiguity(dim) > 1 && rhs.getContiguity(dim) > 1) { + // If both operands are contiguous, the in-group offsets are: + // Let lhs_t = base_lhs + t and rhs_t = base_rhs + t for any + // 0 <= t < min(contig_lhs, contig_rhs). + // For addition: + // lhs_t + rhs_t = base_lhs + base_rhs + 2t + // For subtraction: + // lhs_t - rhs_t = base_lhs - base_rhs + if constexpr (std::is_same_v) { + if (lhs.getContiguity(dim) == rhs.getContiguity(dim)) + return gcd(lhsDivisibility, rhsDivisibility); + } + if ((lhsDivisibility % 2 == 0 && rhsDivisibility % 2 == 0)) { + // Both even -> result divisible by 2. + return 2; + } else { + // At least one is odd -> the "lower bound" of divisibility is 1. + return 1; + } + } else { + // At least one operand is partially constant. + // Divisibility is defined on the *first element* of a contiguity + // group. When an operand has contiguity larger than the result + // contiguity, the "first element of a result group" can fall inside an + // operand's contiguity group, so we must clamp the operand divisibility + // accordingly (otherwise we can overestimate alignment). + if (lhs.getContiguity(dim) > 1 || rhs.getContiguity(dim) > 1) { + auto resContiguity = getContiguity(op, lhs, rhs, dim); + return gcd(lhsDivisibility, rhsDivisibility, + multiplyDivisor(resContiguity, elemSize)); + } else { + return gcd(lhsDivisibility, rhsDivisibility); + } + } + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() + + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() - + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + auto rhsValue = rhs.getConstantValue().value() * elemSize; + return {lhs.getConstantValue().value() + rhsValue}; + } + } + return {}; + } +}; + +class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + // lhs * 1 = lhs + auto lhsContiguity = + rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1 + ? lhs.getContiguity(dim) + : 1; + // 1 * rhs = rhs + auto rhsContiguity = + lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1 + ? rhs.getContiguity(dim) + : 1; + return std::max(lhsContiguity, rhsContiguity); + } + + int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && rhs.getConstantValue() != 1) { + // If the operand is contiguous, the divisibility of the + // sequence drops to 1. + // Example: [4, 5, 6, 7] (base 4 divisible by 4). + // Multiplying by 2 yields [8, 10, 12, 14] (GCD=2). + // Preserving divisibility=4 implies result align 8 (unsafe). + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && lhs.getConstantValue() != 1) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + auto lhsConst = lhs.getConstantValue(); + auto rhsConst = rhs.getConstantValue(); + if (lhsConst.has_value() && rhsConst.has_value()) + return {lhsConst.value() * rhsConst.value()}; + if ((lhsConst.has_value() && lhsConst.value() == 0) || + (rhsConst.has_value() && rhsConst.value() == 0)) + return 0; + return {}; + } +}; + +template +class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs / 1 = lhs + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? lhs.getContiguity(dim) + : 1; + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + auto constancy = BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + if (!resTy) + return constancy; + auto shape = resTy.getShape(); + // Case: lhs contiguous, rhs constant. + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), + // ..., (d_lhs * k + n) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal constancy is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + constancy = std::max(constancy, + gcd(lhs.getContiguity(dim), lhs.getDivisibility(dim), + rhs.getDivisibility(dim))); + } + return constancy; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Case 1: lhs is 0 + if (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0) + return lhs.getDivisibility(dim); + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // Case 3: lhs has contiguity of 1 in this dimension and rhs is a power of 2 + if (rhs.getConstantValue().has_value() && + llvm::isPowerOf2_64(std::abs(rhs.getConstantValue().value())) && + lhs.getContiguity(dim) == 1) { + int64_t absRhs = std::abs(rhs.getConstantValue().value()); + return std::max(1, lhs.getDivisibility(dim) / absRhs); + } + // otherwise: return 1 + return 1; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getContiguity(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + int64_t contiguity = 1; + // lhs contiguous, rhs constant + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + contiguity = gcd(lhs.getContiguity(dim), lhs.getDivisibility(dim), + rhs.getDivisibility(dim)); + } + return contiguity; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (rhs.getConstancy(dim) > 1) { + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); + } + // Otherwise we shouldn't assume any divisibility. + // For example: + // lhs: [2, 2, 4, 4], rhs: [0, 1, 2, 3] + // lhs % rhs = [0, 0, 0, 1] + return 1; + }; + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto constancy = BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return constancy; + // Case: lhs % 1 = 0 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return resTy.getDimSize(dim); + return constancy; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; + else if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return {0}; + return {}; + } +}; + +class SplatOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(opInfo.getDivisibility(0)); + constancy.push_back(retTy.getShape()[d]); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfo ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < ptrInfo.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(1); + constancy.push_back( + gcd(ptrInfo.getConstancy(d), + maskInfo.has_value() ? maskInfo->getConstancy(d) : 0)); + } + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class ExpandDimsOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } + contiguity.insert(contiguity.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); + constancy.insert(constancy.begin() + op.getAxis(), 1); + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class BroadcastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); + divisibility.push_back(opInfo.getDivisibility(d)); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +template +class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return AxisInfo(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constHint; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constHint = shape[d]; + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 1 + : 0; + } else { + // Case 1: lhs and rhs are both partial constants + constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + if ((gtPredicate(getPredicate(op)) || lePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(lhsInfo, shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: 4 4 4 4 + // rhs: 4 5 6 7 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs lt rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 1, 1, 1 + // lhs ge rhs: 1, 0, 0, 0 + // lhs gt rhs: 0, 0, 0, 0 + constHint = std::max(constHint, gcd(rhsInfo.getContiguity(d), + lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d))); + } else if ((ltPredicate(getPredicate(op)) || + gePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: 4 5 6 7 + // rhs: 4 4 4 4 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 0, 0, 0 + // lhs lt rhs: 0, 0, 0, 0 + // lhs gt rhs: 0, 1, 1, 1 + // lhs ge rhs: 1, 1, 1, 1 + constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d), + lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d))); + } + } + + constancy.push_back(constHint); + divisibility.push_back(1); + contiguity.push_back(1); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return (uint64_t)lhs < (uint64_t)rhs; + case arith::CmpIPredicate::ule: + return (uint64_t)lhs <= (uint64_t)rhs; + case arith::CmpIPredicate::ugt: + return (uint64_t)lhs > (uint64_t)rhs; + case arith::CmpIPredicate::uge: + return (uint64_t)lhs >= (uint64_t)rhs; + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto condConstancy = operands[0]->getValue().getConstancy(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + if (operands[0]->getValue().getConstantValue().has_value()) { + if (operands[0]->getValue().getConstantValue() == 0) { + contiguity = rhsInfo.getContiguity(); + divisibility = rhsInfo.getDivisibility(); + constancy = rhsInfo.getConstancy(); + constantValue = rhsInfo.getConstantValue(); + } else { + contiguity = lhsInfo.getContiguity(); + divisibility = lhsInfo.getDivisibility(); + constancy = lhsInfo.getConstancy(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + // The condition can be either a tensor or i1. + // If i1 is used as the condition, the entire tensor of either + // lhs or rhs is selected. + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + constancy.push_back( + gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + getDivisibilityFromContiguity(lhsInfo, rhsInfo, d)); + contiguity.push_back( + gcd(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } else { + constancy.push_back(gcd(lhsInfo.getConstancy(d), + rhsInfo.getConstancy(d), condConstancy[d])); + contiguity.push_back(gcd(lhsInfo.getContiguity(d), + rhsInfo.getContiguity(d), condConstancy[d])); + divisibility.push_back( + getDivisibilityFromContiguity(lhsInfo, rhsInfo, d)); + } + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + + if (constantValue.has_value()) { + auto resTy = dyn_cast(op.getType()); + assert(resTy || rank == 1); + constancy = + resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + } + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } +}; + +template +class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() & + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() | + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() ^ + rhs.getConstantValue().value()}; + } + } + return {}; + } +}; + +class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto shift = rhs.getConstantValue().value_or(0); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, 1ll << shift); + } + + std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (!rhs.getConstantValue().has_value()) + return 1; + auto shift = rhs.getConstantValue().value(); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (int64_t(1) << shift)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + std::optional constantValue; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } + auto resTy = dyn_cast(op.getType()); + assert(resTy || rank == 1); + AxisInfo::DimVectorT constancy = + resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + AxisInfo::DimVectorT divisibility( + rank, highestPowOf2Divisor(constantValue.value())); + return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/divisibility, + /*knownConstancy=*/constancy, + /*constantValue=*/constantValue); + } else { + AxisInfo::DimVectorT contiguity, divisibility, constancy; + for (auto d = 0; d < rank; ++d) { + constancy.push_back( + gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + getDivisibilityFromContiguity(lhsInfo, rhsInfo, d)); + contiguity.push_back( + gcd(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); + } + } +}; + +class TransOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::TransOp op, + ArrayRef *> operands) override { + AxisInfo srcInfo = operands[0]->getValue(); + auto order = op.getOrder(); + auto rank = srcInfo.getRank(); + + // Apply the transpose permutation to all axis info properties + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < rank; ++d) { + int srcDim = order[d]; + contiguity.push_back(srcInfo.getContiguity(srcDim)); + divisibility.push_back(srcInfo.getDivisibility(srcDim)); + constancy.push_back(srcInfo.getConstancy(srcDim)); + } + + return AxisInfo(contiguity, divisibility, constancy, + srcInfo.getConstantValue()); + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver, + axisinfo::CallbackType callback) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append(); + visitors.append, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoVisitor>(); + visitors.append, + RemOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); + visitors.append>(); + visitors.append, + ShROpAxisInfoVisitor>(); + visitors.append, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + + if (callback) + callback(visitors); +} + +LogicalResult AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + // If any operands are not yet ready, skip this operation for now. + for (auto op : operands) + if (op->getValue().getRank() == 0) + return success(); + AxisInfo curr = visitors.apply(op, operands); + if (curr.getRank() == 0) { + setAllToEntryStates(results); + return success(); + } + // override with hint + auto newContiguity = curr.getContiguity(); + auto newDivisibility = curr.getDivisibility(); + auto newConstancy = curr.getConstancy(); + AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.contiguity"), + &newContiguity); + AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"), + &newDivisibility); + AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.constancy"), + &newConstancy); + curr = AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(curr)); + return success(); +} + +void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { + ProgramPoint *programPoint = getProgramPointAfter(op); + auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound()); + auto *stepLattice = getLatticeElementFor(programPoint, op.getStep()); + // If lb or step is not yet ready, skip this operation for now. + if (lbLattice->getValue().getRank() == 0 || + stepLattice->getValue().getRank() == 0) { + return; + } + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + AxisInfo::DimVectorT knownConstancy(1, 1); + knownDivisibility[0] = gcd(lbLattice->getValue().getDivisibility(0), + stepLattice->getValue().getDivisibility(0)); + auto inductionVar = + AxisInfo(knownContiguity, knownDivisibility, knownConstancy); + (void)argLattices[0]->join(inductionVar); +} + +} // anonymous namespace + +void AxisInfo::initPessimisticStateFromFunc(int argNumber, + FunctionOpInterface funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy) { + // list of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + AxisInfo::initDimVectorFromHint(attr, vec); + } +} + +void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(1, int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } +} + +/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { + auto rank = 1; + if (TensorType ty = dyn_cast(value.getType())) + rank = ty.getRank(); + if (triton::PointerType ty = dyn_cast(value.getType())) + if (TensorType elemTy = dyn_cast(ty.getPointeeType())) + rank = elemTy.getRank(); + + DimVectorT knownContiguity(rank, 1); + DimVectorT knownDivisibility(rank, 1); + DimVectorT knownConstancy(rank, 1); + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) { + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + } + } else if (Operation *op = value.getDefiningOp()) { + // Other operations are conservatively initialized with the lowest possible + // divisibility, contiguity, and constancy unless they have specified. + AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"), + &knownDivisibility); + AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.contiguity"), + &knownContiguity); + AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.constancy"), + &knownConstancy); + } + + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); +} + +/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks"); + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + for (auto d = 0; d < lhs.getRank(); ++d) { + contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); + divisibility.push_back(getDivisibilityFromContiguity(lhs, rhs, d)); + constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfo(contiguity, divisibility, constancy, constantValue); +} + +unsigned ModuleAxisInfoAnalysis::getContiguity(Value value) { + auto tensorTy = dyn_cast(value.getType()); + if (!tensorTy) + return 1; + auto elemTy = tensorTy.getElementType(); + // Get the pointee type if we have a tensor of ptrs to compute contiguity for + if (auto ptrTy = dyn_cast(elemTy)) { + elemTy = ptrTy.getPointeeType(); + } + return getContiguity(value, elemTy.getIntOrFloatBitWidth()); +} + +unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue, + unsigned elementBitWidth) { + // FIXME: This is not as good as it could be, as we don't need to restrict + // the analysis to one dimension. We should determine contiguity on the + // flattenOuts() layout + auto tensorTy = cast(offsetsValue.getType()); + auto linAttr = gpu::toLinearEncoding(tensorTy); + auto order = linAttr.getOrder(); + unsigned align = getAlignment(offsetsValue, elementBitWidth); + + auto uniqueContigPerThread = linAttr.getContigPerThread(); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + LDBG("getContiguity uniqueContigPerThread = " << contiguity); + contiguity = std::min(align, contiguity); + + return contiguity; +} + +unsigned ModuleAxisInfoAnalysis::getAlignment(Value value) { + auto tensorTy = dyn_cast(value.getType()); + if (!tensorTy) + return 1; + + auto elemTy = tensorTy.getElementType(); + // Get the pointee type if we have a tensor of ptrs to compute contiguity for + if (auto ptrTy = dyn_cast(elemTy)) { + elemTy = ptrTy.getPointeeType(); + } + return getAlignment(value, elemTy.getIntOrFloatBitWidth()); +} + +unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue, + unsigned elementBitWidth) { + auto tensorTy = cast(offsetsValue.getType()); + auto *axisInfo = getAxisInfo(offsetsValue); + if (!axisInfo) + return 1; + auto linAttr = gpu::toLinearEncoding(tensorTy); + auto order = linAttr.getOrder(); + + auto divisibility = axisInfo->getDivisibility(order[0]); + auto elemNumBytes = std::max(elementBitWidth / 8, 1); + auto elemTy = tensorTy.getElementType(); + auto maxMultiple = isa(elemTy) + ? std::max(divisibility / elemNumBytes, 1) + : divisibility; + + auto maxContig = axisInfo->getContiguity(order[0]); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getAlignment order[0] " << order[0] << " maxContig = " << maxContig + << " elemNumBits = " << elementBitWidth + << " maxMultiple = " << maxMultiple + << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = dyn_cast(mask.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) + return 1; + auto linAttr = gpu::toLinearEncoding(tensorTy); + auto maskOrder = linAttr.getOrder(); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp, + axisinfo::CallbackType callback) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(callback); + if (failed(solver->initializeAndRun(funcOp))) + return; + + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + // If we could not determine the AxisInfo for this value, assume the + // pessimistic state. + if (axisInfo.getRank() == 0) + axisInfo = AxisInfo::getPessimisticValueState(value); + auto &valInfo = (*axisInfoMap)[value]; + valInfo = AxisInfo::join(axisInfo, valInfo); + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = kMaxDivisor; + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + // Only scalar arguments are supported. Do not forward multi-dimensional + // AxisInfo to the callee. + if (axisInfo.getRank() != 1) + continue; + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Analysis/BufferRegion.cpp b/third_party/mthreads/lib/Analysis/BufferRegion.cpp new file mode 100644 index 0000000000..f205c287e0 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/BufferRegion.cpp @@ -0,0 +1,360 @@ +#include "triton/Analysis/BufferRegion.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" + +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +using namespace mlir; + +namespace { +// TODO: move to Utility.cpp/unify with TritonInstrument/Utility.cpp +uint64_t getAllocationOffset(ttg::LocalAllocOp op) { + auto offsetAttr = op->getAttr("allocation.offset"); + if (!offsetAttr) { + llvm::report_fatal_error( + "ConcurrencySanitizer should run after AllocateSharedMemory pass."); + } + return cast(offsetAttr).getInt(); +} + +uint64_t getAllocationOffset(ttng::TMEMAllocOp op) { + auto colOffsetAttr = op->getAttr("tensor_memory_col_offset"); + auto rowOffsetAttr = op->getAttr("tensor_memory_row_offset"); + if (!colOffsetAttr || !rowOffsetAttr) { + llvm::report_fatal_error( + "ConcurrencySanitizer should run after AllocateSharedMemory and " + "TensorMemoryAllocation pass."); + } + int colOffset = cast(colOffsetAttr).getInt(); + int rowOffset = cast(rowOffsetAttr).getInt(); + return colOffset | (rowOffset << 16); +} + +unsigned getMemDescSize(ttg::MemDescType ty) { + if (isa(ty.getMemorySpace())) { + return ttng::getTmemAllocSizes(ty).numCols; + } + assert(isa(ty.getMemorySpace()) && + "Unsupported memory space"); + unsigned elSize = ty.getElementType().getIntOrFloatBitWidth() / 8; + return product(ty.getShape()) * elSize; +} + +unsigned getAllocSize(ttg::LocalAllocOp op) { + return getMemDescSize(op.getType()); +} + +unsigned getAllocSize(ttng::TMEMAllocOp op) { + return getMemDescSize(op.getType()); +} + +unsigned getNumBuffers(ttg::MemDescIndexOp memdescIndexOp) { + ttg::MemDescType ty = + cast(memdescIndexOp.getSrc().getType()); + return ty.getShape()[0]; +} + +llvm::DenseSet getBarrierOperands(Operation *op) { + if (auto initBarrierOp = dyn_cast(op)) { + return {initBarrierOp.getOperand()}; + } + if (auto barrierExpectOp = dyn_cast(op)) { + return {barrierExpectOp.getAlloc()}; + } + if (auto invalBarrierOp = dyn_cast(op)) { + return {invalBarrierOp.getAlloc()}; + } + if (auto asyncOp = dyn_cast(op)) { + return {asyncOp.getBarrier()}; + } + if (auto gatherOp = dyn_cast(op)) { + return {gatherOp.getBarrier()}; + } + if (auto mmaV5Op = dyn_cast(op)) { + return llvm::DenseSet(mmaV5Op.getCompletionBarriers().begin(), + mmaV5Op.getCompletionBarriers().end()); + } + return llvm::DenseSet{}; +} + +bool isUsedAsBarrier(Value v) { + for (auto user : v.getUsers()) { + if (getBarrierOperands(user).contains(v)) { + return true; + } + } + return false; +} + +bool isUsedAsSharedMemory(Value v) { + auto type = dyn_cast(v.getType()); + return type && + isa_and_nonnull(type.getMemorySpace()); +} + +bool isUsedAsTensorMemory(Value v) { + auto type = dyn_cast(v.getType()); + return type && + isa_and_nonnull(type.getMemorySpace()); +} + +uint32_t getMemDescSubsliceByteOffset(ttg::MemDescSubsliceOp op) { + auto srcTy = op.getSrc().getType(); + auto offsets = op.getOffsets(); + if (offsets.empty()) + return 0; + + Attribute encoding = srcTy.getEncoding(); + mlir::triton::LinearLayout layout; + if (auto padded = dyn_cast(encoding)) { + layout = padded.getLinearComponent(); + } else { + layout = ttg::toLinearLayout(srcTy); + } + + MLIRContext *ctx = op->getContext(); + SmallVector dimNames = + mlir::triton::standardOutDimNames(ctx, srcTy.getRank()); + SmallVector> logicalOffsets; + logicalOffsets.reserve(offsets.size()); + for (auto &&[dimName, offset] : llvm::zip_equal(dimNames, offsets)) { + logicalOffsets.push_back({dimName, static_cast(offset)}); + } + + StringAttr offsetDim = StringAttr::get(ctx, "offset"); + layout = layout.sublayout({offsetDim}, dimNames); + mlir::triton::LinearLayout inverse = layout.invert(); + auto mapped = inverse.apply(logicalOffsets); + assert(mapped.size() == 1 && mapped[0].first == offsetDim && + "expected single offset dimension after inversion"); + uint64_t elementOffset = static_cast(mapped[0].second); + + uint64_t elementSizeBytes = + srcTy.getElementType().getIntOrFloatBitWidth() / 8; + assert(elementSizeBytes > 0 && "element size must be non-zero"); + uint64_t byteOffset = elementOffset * elementSizeBytes; + + if (auto padded = dyn_cast(encoding)) { + uint64_t padBytes = 0; + for (auto &&[interval, padding] : + llvm::zip_equal(padded.getIntervals(), padded.getPaddings())) { + if (interval == 0 || padding == 0) + continue; + uint64_t intervalScaled = + static_cast(interval) * elementSizeBytes; + uint64_t paddingScaled = + static_cast(padding) * elementSizeBytes; + assert(llvm::isPowerOf2_64(intervalScaled) && + llvm::isPowerOf2_64(paddingScaled) && + "interval and padding must be powers of two in bytes"); + unsigned intervalLog2 = llvm::Log2_64(intervalScaled); + unsigned paddingLog2 = llvm::Log2_64(paddingScaled); + padBytes += (byteOffset >> intervalLog2) << paddingLog2; + } + byteOffset += padBytes; + } + + assert(byteOffset <= std::numeric_limits::max() && + "memdesc_subslice offset exceeds 32-bit range"); + return static_cast(byteOffset); +} + +std::optional getRegionType(Value v) { + if (isUsedAsBarrier(v)) { + return triton::BufferRegionAnalysis::RegionType::BARRIER; + } + if (isUsedAsSharedMemory(v)) { + return triton::BufferRegionAnalysis::RegionType::SHARED_MEMORY; + } + if (isUsedAsTensorMemory(v)) { + return triton::BufferRegionAnalysis::RegionType::TENSOR_MEMORY; + } + return std::nullopt; +} + +} // namespace + +namespace mlir::triton { + +LogicalResult BufferRegionAnalysis::initialize(Operation *top) { + // Mark all warp-specialize partitions as live. + LogicalResult status = Base::initialize(top); + if (failed(status)) + return failure(); + + top->walk([&](ttg::WarpSpecializeOp wsOp) { + for (Region *region : wsOp.getPartitionRegions()) { + if (region->empty()) + continue; + Block &entry = region->front(); + auto *exec = + getOrCreate(getProgramPointBefore(&entry)); + propagateIfChanged(exec, exec->setToLive()); + } + }); + return success(); +} + +LogicalResult BufferRegionAnalysis::visitOperation( + Operation *op, + llvm::ArrayRef *> operands, + llvm::ArrayRef *> results) { + RegionInfo regionInfo; + if (auto wsOp = dyn_cast(op)) { + for (Region *region : wsOp.getPartitionRegions()) { + if (region->empty()) + continue; + + Block &entry = region->front(); + auto *exec = + getOrCreate(getProgramPointBefore(&entry)); + propagateIfChanged(exec, exec->setToLive()); + } + return success(); + } + if (auto localAllocOp = dyn_cast(op)) { + uint32_t offset = getAllocationOffset(localAllocOp); + uint32_t size = getAllocSize(localAllocOp); + regionInfo.regions.insert({offset, size}); + + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + if (auto tmemAllocOp = dyn_cast(op)) { + uint32_t offset = getAllocationOffset(tmemAllocOp); + uint32_t size = getAllocSize(tmemAllocOp); + regionInfo.regions.insert({offset, size}); + + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + if (auto memdescIndexOp = dyn_cast(op)) { + RegionInfo in = operands[0]->getValue(); + int numSubBuffers = getNumBuffers(memdescIndexOp); + for (auto ®ion : in.regions) { + for (int i = 0; i < numSubBuffers; i++) { + uint32_t subBufferSize = getMemDescSize(memdescIndexOp.getType()); + regionInfo.regions.insert( + {region.baseOffset + i * subBufferSize, subBufferSize}); + } + } + + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + if (auto memdescSubsliceOp = dyn_cast(op)) { + RegionInfo in = operands[0]->getValue(); + uint32_t subBufferSize = getMemDescSize(memdescSubsliceOp.getType()); + uint32_t relativeOffset = getMemDescSubsliceByteOffset(memdescSubsliceOp); + for (auto ®ion : in.regions) { + regionInfo.regions.insert( + {region.baseOffset + relativeOffset, subBufferSize}); + } + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + if (auto tmemSubsliceOp = dyn_cast(op)) { + RegionInfo in = operands[0]->getValue(); + uint32_t subBufferSize = getMemDescSize(tmemSubsliceOp.getType()); + uint32_t relativeOffset = tmemSubsliceOp.getN(); + for (auto ®ion : in.regions) { + regionInfo.regions.insert( + {region.baseOffset + relativeOffset, subBufferSize}); + } + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + // "Passthrough" ops that don't modify the buffer regions. + if (isa(op)) { + // Just propagate the regions from the operand. + RegionInfo in = operands[0]->getValue(); + for (auto ®ion : in.regions) { + regionInfo.regions.insert(region); + } + for (auto *r : results) { + propagateIfChanged(r, r->join(regionInfo)); + } + return success(); + } + verifyOpIsSupported(op); + return success(); +} + +void BufferRegionAnalysis::calculateUsedBufferRegions(Operation *op) { + op->walk([&](Operation *op) { + auto insertRegionForValue = [&](Value v) { + RegionInfo regionInfo = getLatticeElement(v)->getValue(); + std::optional regionType = getRegionType(v); + if (!regionType) { + return; + } + for (auto ®ion : regionInfo.regions) { + usedBufferRegions[*regionType].insert(region); + } + }; + if (BufferRegionAnalysis::isMemoryAccessOperation(op)) { + // Allocas define their buffers with return value. + if (isa(op)) { + insertRegionForValue(op->getResult(0)); + } + // All other operations access their operands. + for (auto operand : op->getOperands()) { + insertRegionForValue(operand); + } + } + }); +} + +bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) { + if (isa( + op)) { + return true; + } + // Allocations with operands write to the memory. + if (isa(op) && + op->getNumOperands() > 0) { + return true; + } + if (isa(op)) { + return true; + } + return false; +} + +void BufferRegionAnalysis::verifyOpIsSupported(Operation *op) { + bool hasMemoryOperands = llvm::any_of(op->getOperands(), [](Value v) { + return isUsedAsSharedMemory(v) || isUsedAsTensorMemory(v); + }); + if (!hasMemoryOperands) { + return; + } + if (isMemoryAccessOperation(op)) { + return; + } + op->emitError( + "Operation accessing memory unaccounted for in buffer region analysis"); + llvm::report_fatal_error( + "Operation accessing memory unaccounted for in buffer region analysis"); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Analysis/CMakeLists.txt b/third_party/mthreads/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..f7bd302ded --- /dev/null +++ b/third_party/mthreads/lib/Analysis/CMakeLists.txt @@ -0,0 +1,23 @@ +add_triton_library(TritonAnalysis + AxisInfo.cpp + Allocation.cpp + BufferRegion.cpp + Membar.cpp + Alias.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonGPUTableGen + TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen + TritonGPUOpInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR + GluonIR + TritonNvidiaGPUIR +) diff --git a/third_party/mthreads/lib/Analysis/Membar.cpp b/third_party/mthreads/lib/Analysis/Membar.cpp new file mode 100644 index 0000000000..5fe72f3da3 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Membar.cpp @@ -0,0 +1,394 @@ +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include + +namespace mlir { +namespace { + +bool shouldTrackMusaSquadDotOp(Operation *op) { + if (op->getName().getStringRef() != "ttmg.squad_dot") + return false; + if (op->getNumResults() == 0) + return false; + auto resultTy = dyn_cast(op->getResult(0).getType()); + if (!resultTy) + return false; + auto sqmmaEnc = + dyn_cast(resultTy.getEncoding()); + if (!sqmmaEnc) + return false; + unsigned totalWarps = 1; + for (unsigned w : sqmmaEnc.getWarpsPerCTA()) + totalWarps *= w; + return totalWarps > 4; +} + +} // namespace + +AllocationSlice::AllocationSlice(Value value, + Interval allocationInterval) + : allocationInterval(allocationInterval) { + auto accessTy = cast(value.getType()); + this->accessTy = accessTy; + + // Get the memdesc_subslice information if present. If no subslice is + // present the whole interval is accessed + if (auto subslice = value.getDefiningOp()) { + // We know there aren't subslices before the one because of subslice::fold + // Still need to check this for where a fold isn't possible (control flow) + // and when a subslice is carried in a loop + if (accessTy.getAllocShape() == subslice.getSrc().getType().getShape()) { + subsliceOffsets = SmallVector(subslice.getOffsets()); + } + } +} + +bool AllocationSlice::intersects(const AllocationSlice &other) const { + // Disjoint intervals don't overlap + if (!allocationInterval.intersects(other.allocationInterval)) + return false; + + // If access types are unknown, assume intersection + if (!accessTy || !other.accessTy) + return true; + + // If offsets are unknown, conservatively assume overlap + if (subsliceOffsets.empty() || other.subsliceOffsets.empty()) + return true; + + // If layouts differ, we assume intersection as we currently only work on + // logical elements + if (accessTy.getEncoding() != other.accessTy.getEncoding()) + return true; + + auto shapeA = SmallVector(accessTy.getShape()); + auto shapeB = SmallVector(other.accessTy.getShape()); + // Chek if all subslice region dimensions have some intersection + // [offsetA, offsetA + shape) and [offsetB, offsetB + other.shape) + // If any dimension doesn't intersect, we are looking at disjoint subslices + for (size_t i = 0; i < subsliceOffsets.size(); ++i) { + int64_t startA = subsliceOffsets[i]; + int64_t endA = startA + shapeA[i]; + int64_t startB = other.subsliceOffsets[i]; + int64_t endB = startB + shapeB[i]; + + // Is A completely before B? Is B completely before A? If so, disjoint + if (endA <= startB || endB <= startA) + return false; + } + + // All dimensions of subslices have some intersection + return true; +} + +void AllocationSlice::print(raw_ostream &os) const { + os << "interval=[" << allocationInterval.start() << "," + << allocationInterval.end() << ")"; + + os << " offsets=["; + if (!subsliceOffsets.empty()) { + llvm::interleaveComma(subsliceOffsets, os); + } else { + os << "unknown"; + } + os << "]"; + + os << " shape="; + if (accessTy) { + llvm::interleave(accessTy.getShape(), os, "x"); + os << " layout=" << accessTy.getEncoding(); + } else { + os << "? layout=unknown"; + } +} + +void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { + FunctionOpInterface funcOp = + dyn_cast(allocation->getOperation()); + OpBuilder builder(funcOp.getContext()); + resolve(funcOp, &funcBlockInfoMap, &builder); +} + +void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + // Initialize the blockList. Operations are organized into "virtual blocks", + // which represent segments of straight-line code analyzed by each iteration + // of the dataflow analysis. Virtual blocks abstract over both control flow + // represented by basic blocks and block successors (i.e. `BranchOpInterface`) + // and control flow represented by regions (i.e. `RegionBranchOpInterface`). + // + // A virtual block consists of a parent block and a starting iterator, where + // the virtual block starts on the operation *after* the starting iterator. A + // null iterator is used to represent the beginning of the block. The virtual + // block ends at any region branch operation or the basic block terminator. + // Thus, basic blocks are broken up into multiple virtual blocks at each + // region operation. + // + // Entry virtual blocks are represented by a null iterator. Populate the + // blockList with the entry virtual blocks in the function. Then, each + // iteration scans until a terminator or region branch operation is found. + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; + std::deque blockList; + // Start the analysis from the entry block of the function. + blockList.emplace_back(&funcOp.getBlocks().front(), Block::iterator()); + + // A fixed point algorithm + while (!blockList.empty()) { + VirtualBlock block = blockList.front(); + blockList.pop_front(); + // Make a copy of the inputblockInfo but not update + auto inputBlockInfo = inputBlockInfoMap[block]; + SmallVector successors; + Block::iterator startIt = + block.second.isValid() ? std::next(block.second) : block.first->begin(); + for (Operation &op : llvm::make_range(startIt, block.first->end())) { + // Update inputBlockInfo based on the current operation. Note that we do + // this before we process terminators and branch-like ops, because some of + // them (e.g. WarpSpecializePartitionsOp) may have synchronizing effects. + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); + if (op.hasTrait() || + isa(op)) { + visitTerminator(&op, successors); + break; + } + } + // Get the reference because we want to update if it changed + if (outputBlockInfoMap.count(block) && + inputBlockInfo == outputBlockInfoMap[block]) { + // If we have seen the block before and the inputBlockInfo is the same as + // the outputBlockInfo, we skip the successors + continue; + } + // Update the current block. The block transfer function is not monotonic, + // so overwrite the output state entirely. + outputBlockInfoMap[block] = inputBlockInfo; + // Update the successors + for (VirtualBlock successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[block]); + blockList.emplace_back(successor); + } + } + + // Update the final dangling buffers that haven't been synced + BlockInfo &funcBlockInfo = (*funcBlockInfoMap)[funcOp]; + funcOp.walk([&](triton::ReturnOp returnOp) { + // A basic block can be broken into several virtual blocks. Find all virtual + // blocks that belong to the basic block containing the return. + SmallVector> virtualBlocks; + for (auto &[block, blockInfo] : outputBlockInfoMap) { + if (block.first == returnOp->getBlock()) + virtualBlocks.emplace_back(block, blockInfo); + } + // The return is a terminator, so the virtual block that contains this + // return starts after all other ones. Find it by comparing the start + // iterators of the virtual blocks. + auto maxIt = llvm::max_element(virtualBlocks, [&](auto &lhs, auto &rhs) { + assert(lhs.first.first == rhs.first.first); + Block::iterator lhsIt = lhs.first.second, rhsIt = rhs.first.second; + return !lhsIt.isValid() || + (rhsIt.isValid() && lhsIt->isBeforeInBlock(&*rhsIt)); + }); + + funcBlockInfo.join(maxIt->second); + }); +} + +void MembarOrFenceAnalysis::visitTerminator( + Operation *op, SmallVector &successors) { + if (isa(op)) { + // Collect the block successors of the branch. + for (Block *successor : op->getSuccessors()) + successors.emplace_back(successor, Block::iterator()); + return; + } + + if (auto br = dyn_cast(op)) { + // The successors of an operation with regions can be queried via an + // interface. The operation branches to the entry blocks of its region + // successors. It can also branch to after itself. + SmallVector regions; + br.getSuccessorRegions(RegionBranchPoint::parent(), regions); + for (RegionSuccessor ®ion : regions) { + if (region.isParent()) { + successors.emplace_back(br->getBlock(), br->getIterator()); + } else { + Block &block = region.getSuccessor()->front(); + successors.emplace_back(&block, Block::iterator()); + } + } + return; + } + + // FIXME: `ReturnLike` adds `RegionBranchTerminatorOpInterface` for some + // reason. Check that the parent is actually a `RegionBranchOpInterface`. + auto br = dyn_cast(op); + if (br && isa(br->getParentOp())) { + // Check the successors of a region branch terminator. It can branch to + // another region of its parent operation or to after the parent op. + SmallVector operands(br->getNumOperands()); + SmallVector regions; + br.getSuccessorRegions(operands, regions); + for (RegionSuccessor ®ion : regions) { + if (region.isParent()) { + Operation *parent = br->getParentOp(); + successors.emplace_back(parent->getBlock(), parent->getIterator()); + } else { + Block &block = region.getSuccessor()->front(); + successors.emplace_back(&block, Block::iterator()); + } + } + return; + } + + // Otherwise, it could be a return op + if (op->hasTrait()) + return; + llvm_unreachable("Unknown terminator encountered in membar analysis"); +} + +void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { + OpBuilder::InsertionGuard g(*builder); + triton::gpu::BarrierOp::create(*builder, op->getLoc(), + triton::gpu::AddrSpace::Local); +} + +void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + auto containsLocalBarrier = [](Operation *op) { + if (isa(op)) + return true; + if (isa(op)) + return true; + if (auto barrier = dyn_cast(op)) + return barrier.hasLocal(); + return false; + }; + + if (containsLocalBarrier(op)) { + // If the current op is a local barrier, we sync previous reads and writes + blockInfo->sync(); + return; + } + + if (op->hasTrait() && + !containsLocalBarrier(op->getNextNode())) { + // If the current op is an async wait and the next op is not a barrier we + // insert a barrier op and sync + builder->setInsertionPointAfter(op); + insertBarrier(op, builder); + blockInfo->sync(); + return; + } + + BlockInfo curBlockInfo; + auto scratchBufferId = Allocation::InvalidBufferId; + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) + curBlockInfo = funcBlockInfoMap->lookup(callee); + } else { + scratchBufferId = allocation->getBufferId(op); + // Intra-function dependencies + if (auto memoryEffectOpInterface = dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId && + bufferId != scratchBufferId) { + auto interval = allocation->getAllocatedInterval(bufferId); + auto slice = AllocationSlice(value, interval); + + if (isa(effectInstance.getEffect())) + curBlockInfo.syncWriteSlices[slice].insert(op); + else if (isa(effectInstance.getEffect())) + curBlockInfo.syncReadSlices[slice].insert(op); + } + } + } + } + } + if (shouldTrackMusaSquadDotOp(op)) { + for (Value operand : op->getOperands()) { + for (auto bufferId : allocation->getBufferIds(operand)) { + if (bufferId == Allocation::InvalidBufferId || + bufferId == scratchBufferId) + continue; + auto interval = allocation->getAllocatedInterval(bufferId); + auto slice = AllocationSlice(operand, interval); + curBlockInfo.syncReadSlices[slice].insert(op); + } + } + } + // If this op is may be signalling other threads asynchronously, make sure + // all shared memory transactions are complete beforehand. + if (isa(op)) { + Interval allIntervals(0, std::numeric_limits::max()); + auto allMemorySlice = AllocationSlice(allIntervals); + curBlockInfo.syncWriteSlices[allMemorySlice].insert(op); + curBlockInfo.syncReadSlices[allMemorySlice].insert(op); + } + } + + // Scratch buffer operations consist of a series of shared memory operations + // starting from a shared memory write, followed by a series of shared memory + // read/write operations, and ending with a shared memory read, i.e., shared + // memory write -> ... -> shared memory read. + if (scratchBufferId != Allocation::InvalidBufferId) { + // Detect warp-synchronous convert-layout operations. These emit a + // warp-level barrier (warp.sync) rather than a CTA-wide barrier between + // the internal shared-memory write and read phases. For these ops, we must + // not globally clear pending dependencies. + bool isWarpSync = false; + if (auto cvt = dyn_cast(op)) { + auto srcTy = cast(cvt.getSrc().getType()); + auto dstTy = cast(cvt.getType()); + auto srcLayout = triton::gpu::toLinearLayout(srcTy); + auto dstLayout = triton::gpu::toLinearLayout(dstTy); + isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout); + } + + if (!curBlockInfo.syncReadSlices.empty() || + !curBlockInfo.syncWriteSlices.empty()) { + llvm::report_fatal_error( + "scratch buffer operations should not have any shared memory " + "dependencies"); + } + auto interval = allocation->getAllocatedInterval(scratchBufferId); + auto scratchSlice = AllocationSlice(interval); + curBlockInfo.syncWriteSlices[scratchSlice].insert(op); + auto insertCTABarrier = + blockInfo->isIntersected(curBlockInfo, filter, allocation); + if (insertCTABarrier) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + } + // Ops with a scratch buffer that don't use warp.sync internally sync + // read/write on shared memory + if (insertCTABarrier || !isWarpSync) + blockInfo->sync(); + curBlockInfo.syncReadSlices[scratchSlice].insert(op); + } else if (blockInfo->isIntersected(curBlockInfo, filter, allocation)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + blockInfo->sync(); + } + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + blockInfo->join(curBlockInfo); +} +} // namespace mlir diff --git a/third_party/mthreads/lib/Analysis/Utility.cpp b/third_party/mthreads/lib/Analysis/Utility.cpp new file mode 100644 index 0000000000..16ec0157f5 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Utility.cpp @@ -0,0 +1,1168 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallSet.h" + +namespace mlir { + +using namespace triton; +using namespace triton::gpu; + +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { + auto order = toLinearEncoding(srcTy).getOrder(); + auto it = std::find(order.begin(), order.end(), axis); + // delete the axis from order + order.erase(it); + // insert axis at the beginning of order + order.insert(order.begin(), axis); + return order; +} + +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. +unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto *ctx = srcEncoding.getContext(); + auto linearLayout = toLinearLayout(srcTy); + auto kLane = mlir::StringAttr::get(ctx, "lane"); + const auto &bases = linearLayout.getBases(); + const auto &lanes = bases.find(kLane)->second; + auto offset = 1; + for (const auto &lane : lanes) { + if (lane[axis] != 0) + break; + offset *= 2; + } + return offset; +} + +// Cases where distributed shared memory is not required in ConvertLayout: +// (1) numCTAs == 1 +// (2) numCTAs > 1 but srcCGALayout == dstCGALayout +// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented +// in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { + unsigned numCTAs = getNumCTAs(srcLayout); + assert(numCTAs == getNumCTAs(dstLayout) && + "Invalid layout conversion: the numbers of CTAs of src and dst " + "layouts are different"); + + // Case (1): Never use dsmem when numCTAs == 1 + if (numCTAs == 1) + return false; + + // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not + // implemented yet + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + llvm::report_fatal_error("Layout conversion to be implemented"); + } + + // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported + if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + return true; + } + + // The above two branches make sure that it is legal to call getCGALayout of + // srcLayout and dstLayout + + // Case (2): Do not use dsmem when srcCGALayout == dstCGALayout + auto srcCGALayout = getCGALayout(srcLayout); + auto dstCGALayout = getCGALayout(dstLayout); + if (srcCGALayout == dstCGALayout) + return false; + + // Dsmem access is required when srcCGALayout != dstCGALayout + return true; +} + +unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { + return getWarpsPerCTA(srcEncoding, srcShape)[axis]; +} + +unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { + return getThreadsPerWarp(srcEncoding, srcShape)[axis]; +} + +bool ReduceOpHelper::isWarpSynchronous() { + return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1; +} + +SmallVector ReduceOpHelper::getScratchRepShape() { + SmallVector smemShape; + // This case doesn't need inter-warp communication + if (isWarpSynchronous()) + return {0, 0}; + + smemShape = convertType(srcShape); + smemShape[axis] = getInterWarpSizeWithUniqueData(); + + return smemShape; +} + +unsigned ReduceOpHelper::getScratchSizeInBytes() { + auto smemShape = getScratchRepShape(); + auto elems = product(smemShape); + + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return bytesPerElem * elems; +} + +bool ReduceOpHelper::isReduceWithinCTA() { + // TODO: Support reduce across CTAS + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + return getCTASplitNum(srcEncoding)[axis] == 1; +} + +bool ReduceOpHelper::isAssociative() { + auto dtype = srcElementTypes[0]; + if (!type::isFloat(dtype)) + return true; + size_t reduce_size = srcShape[axis]; + if (reduce_size <= 2) + return true; + bool hasNoAssociativeOp = false; + op.walk([&](Operation *nestedOp) -> WalkResult { + if (isa(nestedOp)) { + // Only when the data type is float point and reduce size greater than 2, + // and has addf or mulf op, we though it's a non-associative reduce. + hasNoAssociativeOp = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return !hasNoAssociativeOp; +} + +ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + legacyEncoding = firstTy.getEncoding(); + // Remove broadcasting in the registers + // We also remove it in the lowering and re-add it when we pack the results + auto origLayout = triton::gpu::toLinearLayout(firstTy); + auto removeBroadcastRegs = actionRemoveBroadcastedRegs(origLayout); + origLayout = removeBroadcastRegs.apply(origLayout); + srcEncoding = triton::gpu::LinearEncodingAttr::get(op.getContext(), + std::move(origLayout)); + srcElementTypes = op.getElementTypes(); + // The codegen does not support different element/thread/warp order so + // we choose one a priori. We choose that of the blocked encoding. + // When we generalise this code to other layouts we'll probably need to + // get rid of all this logic and the *Stride auxiliary methods + // and replace them by transposes and reshapes on the LinearLayout + if (auto blockedEncoding = + dyn_cast(legacyEncoding)) { + order = llvm::to_vector(blockedEncoding.getOrder()); + } else { + order = srcEncoding.getOrder(); + } + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != legacyEncoding) { + op.emitError() << "encoding mismatch"; + } + } +} + +unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { + return getEncoding().getContigPerThread()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { + auto contigPerThread = getEncoding().getContigPerThread(); + contigPerThread[getAxis()] = 1; + return product(contigPerThread); +} + +Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); } + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return getEncoding().getThreadsPerWarp()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { + auto nThreads = product(getEncoding().getThreadsPerWarp()); + return nThreads / getAxisNumThreadsPerWarpWithUniqueData(); +} + +// Return the flat numbers of threads computing independent scan results. +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { + auto nWarps = product(getEncoding().getWarpsPerCTA()); + return (nWarps / getAxisNumWarpsWithUniqueData()) * + getNonAxisNumThreadsPerWarp(); +} + +unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { + return getEncoding().getWarpsPerCTA()[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumBlocks() { + auto contigPerThread = getEncoding().getContigPerThread(); + auto threadsPerWarp = getEncoding().getThreadsPerWarp(); + auto warpsPerCTA = getEncoding().getWarpsPerCTA(); + unsigned axis = getAxis(); + return ceil( + getShape()[axis], + (contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); +} + +unsigned ScanLoweringHelper::getNonAxisNumBlocks() { + auto contigPerThread = getEncoding().getContigPerThread(); + auto threadsPerWarp = getEncoding().getThreadsPerWarp(); + auto warpsPerCTA = getEncoding().getWarpsPerCTA(); + auto rank = contigPerThread.size(); + unsigned axis = getAxis(); + unsigned numBlocks = 1; + for (unsigned i = 0; i < rank; i++) { + if (i == axis) + continue; + numBlocks *= + ceil(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] * + warpsPerCTA[i])); + } + return numBlocks; +} + +bool ScanLoweringHelper::isSupported() { + // TODO: Support the following cases: + // 1. Scan on non-blocking encodings + if (!isa(legacyEncoding)) + return false; + return true; +} + +unsigned ScanLoweringHelper::getScratchSizeInElems() { + unsigned numWarps = product(getEncoding().getWarpsPerCTA()); + unsigned numNonAxisElementsPerWarp = + getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); + unsigned numElements = numWarps * numNonAxisElementsPerWarp * + getAxisNumBlocks() * getNonAxisNumBlocks(); + return numElements; +} + +unsigned ScanLoweringHelper::getScratchSizeInBytes() { + // Lowering will fail later if the layout is not supported. + if (!isSupported()) + return 0; + + unsigned axisNumWarps = getAxisNumWarpsWithUniqueData(); + if (axisNumWarps == 1) + return 0; + unsigned elementSizeInBytes = 0; + for (const auto &ty : srcElementTypes) { + elementSizeInBytes += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return elementSizeInBytes * getScratchSizeInElems(); +} + +static SmallVector +getTranspositionSelectors(SmallVector> &mixedTranspositions, + std::vector> ®Bases, + int bitwidth); + +DecomposedWarpConversion +getWarpLayoutConvertDecomposition(RankedTensorType srcTy, + RankedTensorType dstTy, int bitwidth) { + // Two layouts, ll_src and ll_dst, representing the same tensor can be + // viewed as surjections of GF(2) vector spaces: + // + // ll_src: H_src -> M and ll_dst: H_dst -> M, + // + // where each is represented by a 'subpermutation' matrix, i.e., a permutation + // matrix with zero columns possibly inserted. A layout conversion can be + // viewed as a map P': H_src -> H_dst which factors ll_src = ll_dst \circ P'. + // + // For a conversion not needing data movement between different warps, we + // choose the following representation, where P is a permutation matrix and + // K_1 and K_2 are (possibly trivial) spaces meant to ensure equally sized + // lane and register dimensions between layouts: + // P + // H_src -> H_src \oplus K_1 -------> H_dst \oplus K_2 -> H_dst. + // + // As a permutation, P can be viewed as a product of cycles permuting lane and + // register index bits. Any such permutation can be expressed as a composition + // + // P = P_mixed \circ P_lane \circ P_reg, + // + // where P_mixed is a product of disjoint transpositions (r_i l_j) between + // lane and register bits and where P_lane and P_reg are permutations purely + // involving lane bits and register bits, respectively. Such a representation + // is not unique, and we choose the factorization method which slices out + // subsequences of consecutive lane bits from cycles involving both bit types. + // Further explanation of this method is below. + // + // The decomposition is performed in three stages. First, we compute the + // permutation matrix `P` by using `invertAndCompose` to generate a skeleton + // and then fill in any zero columns. Second, we walk the cycles of `P` to + // factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and + // `pLane`. Finally, we determine any selectors needed for byte permute + // instructions in place of `selp` instructions when packing registers. + + // We remove any broadcasting in the register dimensions of the layouts before + // forming the permutation `P` as the components of the decomposition directly + // inform the number of emitted instructions, and leaving broadcasting in + // would unnecessarily inflate the count. + auto srcLayout = toLinearLayout(srcTy); + auto dstLayout = toLinearLayout(dstTy); + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout); + auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout); + srcLayout = removeBroadcastSrc.apply(srcLayout); + dstLayout = removeBroadcastDst.apply(dstLayout); + + // We want to describe the conversion from `srcLayout` to `dstLayout` as a + // permutation. Since this requires that each input dimension have the same + // size in each of the layouts, we first pad the lane and register dimensions + // with zero vectors if needed. + auto *ctx = srcTy.getContext(); + StringAttr kReg = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + + // Determine the target sizes of the register and lane dimensions for padding. + int nSrcRegBases = srcLayout.getInDimSizeLog2(kReg); + int nDstRegBases = dstLayout.getInDimSizeLog2(kReg); + int nSrcLaneBases = srcLayout.getInDimSizeLog2(kLane); + int nDstLaneBases = dstLayout.getInDimSizeLog2(kLane); + int nRegBases = std::max(nSrcRegBases, nDstRegBases); + int nLaneBases = std::max(nSrcLaneBases, nDstLaneBases); + // Restrict attention to the input dimensions which matter. + SmallVector inDimNames{kReg, kLane}; + auto outDimNames = llvm::to_vector(srcLayout.getOutDimNames()); + auto S = srcLayout.sublayout(inDimNames, outDimNames); + auto T = dstLayout.sublayout(inDimNames, outDimNames); + // Conditionally pad. + if (nSrcRegBases != nDstRegBases || nSrcLaneBases != nDstLaneBases) { + auto padWithZeros = [&](const LinearLayout &ll) { + auto newBases = ll.getBases(); + auto padDim = [&](StringAttr dim, int dimSize) { + auto &dimBases = newBases[dim]; + dimBases.reserve(dimSize); + for (int i = ll.getInDimSizeLog2(dim); i < dimSize; ++i) + dimBases.emplace_back(outDimNames.size(), 0); + }; + padDim(kReg, nRegBases); + padDim(kLane, nLaneBases); + // Surjectivity is not expected in general since we do not consider + // the 'warp' and 'block' dimensions of the original layouts. + return LinearLayout(std::move(newBases), ll.getOutDims(), + /*requireSurjective=*/false); + }; + S = padWithZeros(S); + T = padWithZeros(T); + } + + // We compute T^transpose \circ S, which serves as a skeleton for `P`, then + // fill in zero columns, prioritizing producing fixed points. As we only need + // the basis vectors of `P`, we never actually produce the LinearLayout. + auto pBases = S.invertAndCompose(T).getBases(); + + // Find the common and uncommon zeros of S and T + S = S.flattenOuts(); + T = T.flattenOuts(); + SmallVector> srcFreeZeros; + SmallVector> dstFreeZeros; + for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) { + for (int inIdx = 0; inIdx < S.getInDimSizeLog2(dim); ++inIdx) { + int sVal = S.getBasis(dim, inIdx)[0]; + int tVal = T.getBasis(dim, inIdx)[0]; + if (sVal == 0 && tVal == 0) { + pBases[dim][inIdx][dimIdx] = 1 << inIdx; + } else if (sVal == 0) { + srcFreeZeros.emplace_back(dimIdx, inIdx); + } else if (tVal == 0) { + dstFreeZeros.emplace_back(dimIdx, inIdx); + } + } + } + // Fill in non-fixed-point zero vectors + for (auto [srcZeroLoc, dstZeroLoc] : llvm::zip(srcFreeZeros, dstFreeZeros)) { + auto [srcDimIdx, srcIdx] = srcZeroLoc; + auto [dstDimIdx, dstIdx] = dstZeroLoc; + auto inDim = inDimNames[srcDimIdx]; + pBases[inDim][srcIdx][dstDimIdx] = 1 << dstIdx; + } + + // We walk the cycles of `P` to build the bases for `pReg` and `pLane` while + // factoring out mixed transpositions from cycles that include both register + // and lane basis vectors. `pReg` and `pLane` themselves only have one input + // and output dimension each. + LinearLayout::BasesT pRegBases, pLaneBases; + auto ®Bases = pRegBases[kReg]; + auto &laneBases = pLaneBases[kLane]; + regBases.resize(nRegBases, {0}); + laneBases.resize(nLaneBases, {0}); + SmallVector> mixedTranspositions; + + llvm::BitVector visited(nRegBases + nLaneBases, false); + auto flatIdx = [&](StringAttr dim, int32_t index) { + return (dim == kReg) ? index : nRegBases + index; + }; + + for (auto dim : inDimNames) { + int inDimSize = S.getInDimSizeLog2(dim); + for (int i = 0; i < inDimSize; ++i) { + if (visited.test(flatIdx(dim, i))) + continue; + + // Start a new cycle, tracking the entry basis vector and the 'current' + // one as we walk the cycle. + StringAttr entryDim = dim; + int32_t entryIdx = i; + StringAttr currDim = entryDim; + int32_t currIdx = entryIdx; + + // We slice out subsequences of consecutive lane basis vectors appearing + // in mixed cycles by factoring out transpositions (r_i l_j) as in + // + // (.. r_m l_j .. l_k r_i ..) = (r_i l_j) * (.. r_m r_i ..)(l_j .. l_k). + // + // The permutations are applied right-to-left, and the block `l_j .. l_k` + // indicates a contiguous subsequence of lane basis vectors. Note that the + // transposition does not commute with the other two cycles. + // + // The following variables are used to track the start and end points of + // such subsequences. + int32_t /*r_m*/ regStartIdx = -1; + int32_t /*l_j*/ laneStartIdx = -1; + int32_t /*l_k*/ laneEndIdx = -1; + int32_t /*r_i*/ regEndIdx = -1; + + do { + // Determine the next basis vector in the current cycle. + visited.set(flatIdx(currDim, currIdx)); + auto nextVec = pBases.lookup(currDim)[currIdx]; + StringAttr nextDim; + int32_t nextIdx; + for (auto [nextDimIdx, nextVal] : llvm::enumerate(nextVec)) { + if (nextVal != 0) { + nextDim = inDimNames[nextDimIdx]; + nextIdx = llvm::Log2_32(nextVal); + } + } + // Set a `pReg` or `pLane` vector, or mark an r->l or l->r transition. + if (currDim == kReg && nextDim == kReg) { + regBases[currIdx][0] = 1 << nextIdx; + } else if (currDim == kLane && nextDim == kLane) { + laneBases[currIdx][0] = 1 << nextIdx; + } else if (currDim == kReg && nextDim == kLane) { + regStartIdx = currIdx; + laneStartIdx = nextIdx; + } else { + regEndIdx = nextIdx; + laneEndIdx = currIdx; + } + // If a subsequence of the form (.. r_m l_j .. l_k r_i ..) has been + // found, perform the prescribed factorization. + if (regEndIdx >= 0) { + // Assign r_m to map to r_i as in (.. r_m r_i ..). + regBases[regStartIdx][0] = 1 << regEndIdx; + // Assign l_k to map to l_j as in (l_j .. l_k). + laneBases[laneEndIdx][0] = 1 << laneStartIdx; + // Record (r_i l_j) as a factor. + mixedTranspositions.emplace_back(regEndIdx, laneStartIdx); + // Reset the auxiliary variables. + regStartIdx = laneStartIdx = laneEndIdx = regEndIdx = -1; + } + + currDim = nextDim; + currIdx = nextIdx; + } while (flatIdx(currDim, currIdx) != flatIdx(entryDim, entryIdx)); + } + } + assert(visited.all() && "Cycle walk incomplete"); + + // Determine degree of packing and selectors. + int m = mixedTranspositions.size(); + int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4)); + int nPack = std::min(nPackPrelim, nRegBases - m); + auto processedTranspos = + getTranspositionSelectors(mixedTranspositions, regBases, nPack); + + auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}}, + /*requireSurjective=*/true); + auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}}, + /*requireSurjective=*/true); + return {std::move(pReg), std::move(pLane), std::move(processedTranspos), + nPack}; +} + +static SmallVector +getTranspositionSelectors(SmallVector> &mixedTranspositions, + std::vector> ®Bases, + int nPack) { + // When possible, we fuse permutations of 'low' register bits together + // with a mixed transposition, resulting in byte permute instructions instead + // of `select` instructions. After processing, no low register bits appear in + // the returned list of mixed transpositions. + + SmallVector ret; + ret.reserve(mixedTranspositions.size()); + if (nPack == 0) { + for (auto &t : mixedTranspositions) + ret.push_back(DecomposedWarpConversion::TranspositionInfo{t}); + return ret; + } + // Consider for example the cycle + // + // (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3) + // = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2) + // + // with `nPack` = 2 so that r0 and r1 are considered low bits. We want to + // factor out any low bits from `pReg` and to incorporate them into the data + // of the mixed transposition. After processing, the contribution to `pReg` + // is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with + // the effects of (r3 r0) and (r3 r1) encoded in the returned selectors. + // In general, low bits occurring immediately before l_j modify the selectors + // of the `prmt` before the shuffle, while low bits occurring immediately + // after l_k modify the selectors of the `prmt` after the shuffle. Unmodified + // selectors correspond to `select` instructions. + // Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is + // not used in another mixed transposition and conjugating out a low bit: + // + // (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1) + // = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1). + // + // Conjugation does not affect `pReg`. However, the set of fused mixed and + // low-bit transpositions is noncommutative in cases where there are no + // intervening high bits in between distinct sequences of lane bits as the + // paired low bit is used in modifying the selectors of both factors: + // + // (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0). + // + // The `*` is standard composition of permutations. The groupings correspond + // to different `TranspositionInfo` objects. For example, the permutation + // `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with + // pre- and post-shuffle selectors determined by the `r0` bit. + // Processing of mixed transpositions is performed by determining the `head` + // and `tail` of an excision of bits in cycles of `pReg` and building lists + // of low bits acting as selector modifiers. In the noncommutative cases, we + // opt to restrict the number of post-shuffle modifiers to one. + + auto permuteSelector = [nPack](uint16_t sel, int bitIdx) { + int lo = bitIdx + (2 - nPack); + uint16_t maskHi = 0x4444; + uint16_t maskLo = 0x1111 << lo; + uint16_t fixed = sel & ~maskHi & ~maskLo; + int shift = 2 - lo; + return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift); + }; + auto generateSelectors = [&](int head, int tail, auto &&lowBits) { + uint16_t topSel = 0x3210; + uint16_t botSel = 0x7654; + for (auto lowBit : lowBits) { + topSel = permuteSelector(topSel, lowBit); + botSel = permuteSelector(botSel, lowBit); + if (lowBit != head && lowBit != tail) + regBases[lowBit][0] = 1 << lowBit; + } + return std::pair{topSel, botSel}; + }; + + llvm::SmallSet pairedRegBits; + for (auto [rBit, lBit] : mixedTranspositions) + pairedRegBits.insert(rBit); + + // A low bit in a mixed transposition must be replaced by a high bit. The + // choice of high bit can affect instruction count. If the first high bit + // found when walking along `pReg` is unpaired, then that bit is the best + // choice. We reorder the transpositions to guarantee this during processing. + auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); }; + auto nextHighFree = [&](auto p) { + int curr = p.first; + do { + if (curr >= nPack) + return curr == p.first || !pairedRegBits.contains(curr); + curr = next(curr); + } while (curr != p.first); + return false; + }; + std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(), + nextHighFree); + // If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low + // bit to an open high bit, then the high bit should be used as the partner. + auto prev = [&](int b) { + int tail = b; + int curr = next(b); + while (curr != b) { + tail = curr; + curr = next(curr); + } + return tail; + }; + auto findPartner = [&](int lowBit, auto &preShufLoBits) { + if (nPack == 2) { + int otherLow = 1 - lowBit; + int b = next(otherLow); + if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) && + !pairedRegBits.contains(otherLow)) { + preShufLoBits.push_back(otherLow); + regBases[prev(otherLow)][0] = 1 << b; + pairedRegBits.insert(b); + return b; + } + } + int potentialPartner = nPack; + while (pairedRegBits.contains(potentialPartner)) + ++potentialPartner; + pairedRegBits.insert(potentialPartner); + return potentialPartner; + }; + + for (auto p : mixedTranspositions) { + int rBit = p.first; + int lBit = p.second; + SmallVector cycle; + int currBit = rBit; + do { + cycle.push_back(currBit); + currBit = next(currBit); + } while (currBit != rBit); + + // Find any low register bits adjacent to the excised lane bits which aren't + // used in other mixed transpositions. + auto isBoundary = [&](int bit) { + return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit); + }; + auto forwardEnd = llvm::find_if(cycle, isBoundary); + auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary); + SmallVector postShufLoBits(cycle.begin(), forwardEnd); + SmallVector preShufLoBits(cycle.rbegin(), backwardEnd); + int head; + int tail; + int partnerBit = -1; + + // Case work to determine what to conjugate out. + if (forwardEnd != cycle.end()) { + if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) { + // End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2) + // No conjugation needed. + head = partnerBit = *forwardEnd; + } else { + // End at different paired bit. E.g. (l0 r0 r1 l1 r2) + // Non-leading factor in a noncommutative case. + // Conjugate by first low bit in forward walk. + head = postShufLoBits.front(); + preShufLoBits.push_back(head); + postShufLoBits.resize(1); + pairedRegBits.erase(head); + } + tail = *backwardEnd; + if (tail < nPack && pairedRegBits.contains(tail)) { + // Non-terminal factor in a noncommutative case. + preShufLoBits.insert(preShufLoBits.begin(), tail); + } + } else { + if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) { + // Symmetric noncommutative case. E.g. (l0 r0 l1 r1) + preShufLoBits.erase(preShufLoBits.begin()); + postShufLoBits.pop_back(); + pairedRegBits.erase(postShufLoBits.front()); + head = rBit; + tail = next(rBit); + } else { + // Isolated low bits with single mixed transposition. E.g. (l0 r0 r1) + if (postShufLoBits.size() == 2) + postShufLoBits.pop_back(); + head = tail = preShufLoBits.front(); + } + } + + if (partnerBit < 0) + partnerBit = findPartner(head, preShufLoBits); + auto [topPostSel, botPostSel] = + generateSelectors(head, tail, llvm::reverse(postShufLoBits)); + auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits); + regBases[tail][0] = 1 << head; + + DecomposedWarpConversion::TranspositionInfo info; + info.transposition = {partnerBit, lBit}; + info.topPreSel = topPreSel; + info.botPreSel = botPreSel; + info.topPostSel = topPostSel; + info.botPostSel = botPostSel; + + // In noncommutative cases, post-shuffle selectors of non-leading terms come + // from a single low bit by design, so we can determine where to insert a + // non-terminal factor by examining processed selectors. + if (!preShufLoBits.empty()) { + uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410; + auto it = + llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; }); + ret.insert(it, info); + } else { + ret.push_back(info); + } + } + if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) { + // If (r0 r1) was originally in `P`, fold it into a mixed transposition. + auto &t = ret.back(); + t.topPostSel = 0x3120; + t.botPostSel = 0x7564; + } + return ret; +} + +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, + ArrayRef dstShape) { + SmallVector, SmallVector>> ret; + + if (srcShape.empty()) { + assert(dstShape.empty()); + return ret; + } + ret.push_back({}); + + int srcIdx = 0; + int dstIdx = 0; + int srcNElems = 1; + int dstNElems = 1; + while (srcIdx < srcShape.size() || dstIdx < dstShape.size()) { + if (srcNElems < dstNElems || // + (srcIdx < srcShape.size() && srcNElems == 1) || + (srcIdx < srcShape.size() && srcShape[srcIdx] == 1)) { + assert(srcIdx < srcShape.size()); + srcNElems *= srcShape[srcIdx]; + ret.back().first.push_back(srcIdx); + srcIdx++; + } else if (dstNElems < srcNElems || + (dstIdx < dstShape.size() && dstShape[dstIdx] == 1)) { + assert(dstIdx < dstShape.size()); + dstNElems *= dstShape[dstIdx]; + ret.back().second.push_back(dstIdx); + dstIdx++; + } else { + ret.push_back({}); + srcNElems = 1; + dstNElems = 1; + } + } + return ret; +} + +unsigned ScanLoweringHelper::getAxisElementStride() { + auto order = getOrder(); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getEncoding().getContigPerThread()[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisThreadStride() { + auto encoding = getEncoding(); + auto ll = encoding.getLinearLayout(); + auto kThread = StringAttr::get(encoding.getContext(), "lane"); + const auto &bases = ll.getBases().lookup(kThread); + unsigned axis = getAxis(); + for (unsigned i = 0; i < bases.size(); ++i) { + if (bases[i][axis] != 0) + return 1 << i; + } + return 1; +} + +unsigned ScanLoweringHelper::getAxisBlockStride() { + auto order = getOrder(); + unsigned stride = 1; + auto contigPerThread = getEncoding().getContigPerThread(); + auto threadsPerWarp = getEncoding().getThreadsPerWarp(); + auto warpsPerCTA = getEncoding().getWarpsPerCTA(); + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= ceil(getShape()[dim], contigPerThread[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); + } + llvm_unreachable("Axis not found in order"); +} + +GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) + : gatherOp(gatherOp) {} + +unsigned GatherLoweringHelper::getScratchSizeInBytes() { + // If the gather is warp-local, no scratch space is needed. + if (isWarpLocal()) + return 0; + + // Otherwise, performing the gather will require scratch space to communicate + // the source tensor across threads. For now, assume the whole source tensor + // is written back to shared memory. + RankedTensorType srcType = gatherOp.getSrc().getType(); + return product(srcType.getShape()) * + ceil(srcType.getElementTypeBitWidth(), 8); +} + +bool GatherLoweringHelper::isWarpLocal() { + // The gather is warp-local if for each column along the gather axis in the + // source and index tensors, all the elements are owned by the same warp. + RankedTensorType srcType = gatherOp.getSrc().getType(); + RankedTensorType idxType = gatherOp.getIndices().getType(); + LinearLayout srcLayout = toLinearLayout(srcType); + LinearLayout idxLayout = toLinearLayout(idxType); + + Builder b(gatherOp.getContext()); + StringAttr kBlock = b.getStringAttr("block"); + StringAttr kWarp = b.getStringAttr("warp"); + StringAttr kLane = b.getStringAttr("lane"); + StringAttr kGatherDim = + b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); + + // The tensor layouts must be distributed layouts, where the basis matrix is a + // subpermutation matrix (permutation matrix plus zeros for broadcasting). + // FIXME(jeff): Check this invariant somehow. + // + // We want to know if all elements of a column along the gather axis are + // mapped to the same set of warps, which means the gather can be performed + // entirely within the warp. We need to query + // + // srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) + // + // But due to broadcasting, the matrix might not be invertible. But since the + // matrix is a permutation matrix (checked below), we can instead query + // + // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim}) + // + // Which implies that changing the warp will not change the gather dimension. + // And since there is no swizzling, this applies to all warps. + if (!srcLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim) || + !idxLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim)) + return false; + + SmallVector otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != gatherOp.getAxis()) { + otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + } + } + + // If the gather axis `dimN` is invariant to the warp, but the `(block, warp)` + // mapping to all other dimensions must be the same for both layouts. If so, + // then the warp that owns a particular index element also owns all the source + // elements it could index into. + if (srcLayout.sublayout({kBlock, kWarp}, otherDims) != + idxLayout.sublayout({kBlock, kWarp}, otherDims)) + return false; + + // The two constraints above ensure that data-movement to perform the gather + // operation are contained within a warp. The subsequent constraints simplify + // codegen. + + // Require that for any given gather column, the threads mapped to the column + // in the index and source tensors are the same. This means we don't need to + // xor shuffle across threads before emitting index shuffles; we push warp + // shuffling to layout conversions. + return srcLayout.sublayout(kLane, otherDims) == + idxLayout.sublayout(kLane, otherDims); +} + +unsigned getNumScratchElements(ArrayRef shape) { + if (shape.empty()) + return 0; + return product(shape); +} + +bool supportMMA(triton::DotOp op, int version) { + // Refer to mma section for the data type supported by Volta and Hopper + // Tensor Core in + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + auto aElemTy = op.getA().getType().getElementType(); + auto bElemTy = op.getB().getType().getElementType(); + if (version == 5) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V5")) + return false; + RankedTensorType typeA = op.getA().getType(); + int k = typeA.getShape().back(); + auto retType = op.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + int numWarps = lookupNumWarps(op); + if (aElemTy.isInteger() || bElemTy.isInteger() || + retType.getElementType().isInteger()) + return false; + if (op.getType().getRank() != 2) + return false; + if (numWarps != 4 && numWarps != 8) { + // Currently only support numWarps 4 or 8 for TMEM load and store. + return false; + } + // If k size is smaller than the native mma size, we cannot use MMA. + if (k < 256 / aElemTy.getIntOrFloatBitWidth()) + return false; + if (!(retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 16 == 0)) + return false; + return true; + } + if (version == 3) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return false; + auto retType = op.getType(); + RankedTensorType typeA = op.getA().getType(); + int k = typeA.getShape().back(); + // If k size is smaller than the native mma size, we cannot use MMA. + if (k < 256 / aElemTy.getIntOrFloatBitWidth()) + return false; + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + int numWarps = lookupNumWarps(op); + // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul. + if (rank == 3) + return false; + if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 16 == 0 && + (llvm::isa(aElemTy) || + aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || + aElemTy.isF32()))) { + return false; + } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (llvm::isa(aElemTy)) && + cast(op.getType()).getElementType().isF32()) { + return false; + } + } + if (aElemTy.isF32() && bElemTy.isF32()) { + return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + } + return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); +} + +bool supportMMA(Value value, int version) { + // Tell whether a DotOp support MMA by the operand type(either $a or $b). + // We cannot get both the operand types(in TypeConverter), here we assume the + // types of both the operands are identical here. + assert((version == 1 || version == 2 || version == 3) && + "Unexpected MMA layout version found"); + auto elemTy = + cast(value.getType()).getElementType(); + // FP8 is not natively supported on all mma versions but it can always be + // promoted to fp16 therefore we can always support it. + bool isFP8 = llvm::isa(elemTy); + return isFP8 || elemTy.isF16() || elemTy.isBF16() || + ((elemTy.isF32() || elemTy.isF64()) && version >= 2) || + (elemTy.isInteger(8) && version >= 2); +} + +// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity +// under the common dimensions. The idea here is that if we have a +// transformation that's the identity on kBlock, we don't need to use +// distributed shared memory. If it's also the identity on kWarp, we can +// transfer via warp-shuffles, and if it's the identity on kLane just have to +// reorder the registers. +LinearLayout minimalCvtLayout(Type srcTy_, Type dstTy_) { + auto srcTy = cast(srcTy_); + auto dstTy = cast(dstTy_); + LinearLayout srcLayout = toLinearLayout(srcTy); + LinearLayout dstLayout = toLinearLayout(dstTy); + auto sDims = to_vector(srcLayout.getInDimNames()); + auto dDims = to_vector(dstLayout.getInDimNames()); + SmallVector dims; + for (int i = 0; i < std::min(sDims.size(), dDims.size()); ++i) { + auto srcDim = sDims[sDims.size() - i - 1]; + auto dstDim = dDims[dDims.size() - i - 1]; + if (srcDim != dstDim) { + break; + } + dims.push_back(srcDim); + } + + auto comp = dstLayout.invertAndCompose(srcLayout); + // We try to quotient by the slowers moving subspace first + for (auto dim : dims) { + auto quotient = comp.quotient(dim); + if (!quotient.has_value()) { + break; + } + comp = *quotient; + } + return comp; +} + +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = to_vector(layout.getOutDimNames()); + return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister}); +} + +bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + if (to_vector(layout.getOutDimNames()) == + SmallVector{kRegister, kLane}) { + auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32); + return (factors.mixedTranspositions.size() < 2); + } + return false; +} + +bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { + return !cvtReordersRegisters(srcTy, dstTy) && + !cvtNeedsWarpShuffle(srcTy, dstTy); +} + +namespace { + +/// A data structure similar to SetVector but maintains +/// a deque instead of a vector to allow for efficient +/// push_back and pop_front operations. +/// Using SetVector doesn't suffice our needs because +/// it only pushes and pops from the back. +/// For example, if we have a queue like this: +/// 0->4 1->2->3 +/// ^-------- +/// where 3 depends on 4, once we pop 3, we found +/// 4 is not ready, so we check 2 and push 3 back +/// to the queue. +struct DFSSubgraphState { + DFSSubgraphState() : set(), deque() {} + DenseSet set; + std::deque deque; + + bool push_back(Operation *op) { + if (set.insert(op).second) { + deque.push_back(op); + return true; + } + return false; + } + + Operation *pop_front() { + Operation *op = deque.front(); + deque.pop_front(); + set.erase(op); + return op; + } + + bool empty() { return deque.empty(); } +}; + +/// DFS post-order implementation that maintains a global count to work across +/// multiple invocations, to help implement topological sort on multi-root DAGs. +/// We traverse all operations but only record the ones that appear in +/// `toSort` for the final result. +struct DFSState { + DFSState(const SetVector &set) : toSort(set), seen() {} + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; + + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. + void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, + SmallVector &readyQueue) { + bool ready = true; + for (Value operand : op->getOperands()) { + auto def = operand.getDefiningOp(); + if (def && !seen.count(def)) { + subGraph.push_back(def); + ready = false; + } + } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } + if (ready) + readyQueue.push_back(op); + } +}; + +void dfsPostorder(Operation *root, DFSState *state) { + DFSSubgraphState subGraph; + subGraph.push_back(root); + SmallVector ops; + while (!subGraph.empty()) { + // Nodes in the ready queue are ready to be processed. + // Meaning that either their operands are all seen or they have null + // operands. + SmallVector readyQueue; + auto *current = subGraph.pop_front(); + state->addToReadyQueue(current, subGraph, readyQueue); + while (!readyQueue.empty()) { + Operation *current = readyQueue.pop_back_val(); + if (!state->seen.insert(current).second) + continue; + ops.push_back(current); + for (Value result : current->getResults()) { + for (Operation *op : result.getUsers()) + state->addToReadyQueue(op, subGraph, readyQueue); + } + for (Region ®ion : current->getRegions()) { + for (Operation &op : region.getOps()) + state->addToReadyQueue(&op, subGraph, readyQueue); + } + } + } + + for (Operation *op : llvm::reverse(ops)) { + if (state->toSort.count(op) > 0) + state->topologicalCounts.push_back(op); + } +} + +} // namespace + +std::unique_ptr createDataFlowSolver() { + auto solver = std::make_unique(); + solver->load(); + solver->load(); + return solver; +} + +bool isCvtWarpSync(const triton::LinearLayout &srcLayout, + const triton::LinearLayout &dstLayout) { + // We can use warp.sync when the warp dimension in the convert is trival + // and there is no broadcasting at a warp level (otherwise reads may be + // wrong) + auto *ctx = srcLayout.getInDimNames().begin()->getContext(); + auto comp = dstLayout.invertAndCompose(srcLayout); + auto kWarp = StringAttr::get(ctx, "warp"); + return comp.isTrivialOver(kWarp) && + srcLayout.getFreeVariableMasks()[kWarp] == 0 && + dstLayout.getFreeVariableMasks()[kWarp] == 0; +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/CMakeLists.txt b/third_party/mthreads/lib/CMakeLists.txt new file mode 100644 index 0000000000..c58b7fa0a3 --- /dev/null +++ b/third_party/mthreads/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) +add_subdirectory(Tools) diff --git a/third_party/mthreads/lib/Conversion/CMakeLists.txt b/third_party/mthreads/lib/Conversion/CMakeLists.txt new file mode 100644 index 0000000000..84aba4f3d2 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonInstrumentToLLVM) diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 0000000000..0448fbc73a --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,27 @@ +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_ALLOCATESHAREDMEMORY +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +namespace { +struct AllocateSharedMemory + : public mlir::triton::gpu::impl::AllocateSharedMemoryBase< + AllocateSharedMemory> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + ModuleAllocation allocation(mod); + + mlir::triton::gpu::attachAllocationSizeAndOffsetAttr(mod, allocation); + } +}; +} // namespace diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.cpp new file mode 100644 index 0000000000..24e90a2460 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.cpp @@ -0,0 +1,34 @@ +#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h" + +namespace mlir::triton::gpu { + +void attachAllocationSizeAndOffsetAttr(ModuleOp mod, + ModuleAllocation &allocation) { + MLIRContext *ctx = mod.getContext(); + + mod.walk([&](FunctionOpInterface funcOp) { + auto *funcAllocation = allocation.getFuncData(funcOp); + funcOp.walk([&](Operation *op) { + auto oBufferId = funcAllocation->getBufferId(op); + int offset = -1; + if (oBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(oBufferId); + else if (op->getNumResults() == 1) { + Value value = op->getResult(0); + auto vBufferId = funcAllocation->getBufferId(value); + if (vBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(vBufferId); + } + if (offset == -1) + return; + op->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + }); + return WalkResult::skip(); + }); + mod->setAttr("ttg.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp new file mode 100644 index 0000000000..1e718a6866 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp @@ -0,0 +1,217 @@ +#include "mlir/IR/BuiltinOps.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUALLOCATEWARPGROUPS +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Given a `ttg.warp_specialize` with a certain number of existing warps, pad it +// with extra warps until it has the same number of full warp groups as the +// largest partitioning. This ensures that all threads can be present to +// surrender registers. +static void padToMaxWarpGroups(WarpSpecializeOp op, int numExtraWarpGroups) { + int numExtraWarps = op.getTotalPartitionWarps(); + int warpsToAdd = numExtraWarpGroups * 4 - numExtraWarps; + assert(warpsToAdd >= 0); + + // Fill it with powers of 2. + SmallVector paddingPartitionSizes; + while (warpsToAdd > 0) { + int paddingSize = llvm::NextPowerOf2(warpsToAdd) / 2; + paddingPartitionSizes.push_back(paddingSize); + warpsToAdd -= paddingSize; + } + + auto partitions = cast( + op.getPartitionOpHolder().front().front()); + OperationState state(partitions.getLoc(), partitions.getOperationName(), + partitions.getOperands(), /*types=*/{}); + for (Region *region : partitions.getRegions()) + state.addRegion()->takeBody(*region); + + SmallVector partitionNumWarps(op.getPartitionNumWarps()); + for (int paddingSize : paddingPartitionSizes) { + partitionNumWarps.push_back(paddingSize); + + Block &body = state.addRegion()->emplaceBlock(); + for (Value capture : op.getPartitionOp().getExplicitCaptures()) + body.addArgument(capture.getType(), capture.getLoc()); + OpBuilder b(op.getContext()); + b.setInsertionPointToStart(&body); + WarpReturnOp::create(b, op.getLoc()); + } + op.setPartitionNumWarps(partitionNumWarps); + + // Set the requested registers to low for the padded partitions that do + // nothing. + if (auto reqRegs = op.getRequestedRegisters()) { + SmallVector newReqRegs(*reqRegs); + newReqRegs.append(paddingPartitionSizes.size(), 16); + op.setRequestedRegisters(newReqRegs); + } + + OpBuilder b(partitions); + b.create(state); + partitions.erase(); +} + +namespace { +struct AllocateWarpGroups + : public mlir::triton::gpu::impl::TritonGPUAllocateWarpGroupsBase< + AllocateWarpGroups> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First determine the maximum number of extra warps. + int maxExtraWarps = 0; + mod.walk([&](WarpSpecializeOp op) { + maxExtraWarps = std::max(maxExtraWarps, op.getTotalPartitionWarps()); + }); + + // Round this up to the nearest warpgroup (multiple of 4) and then pad each + // `ttg.warp_specialize` to the nearest warpgroup. + int numExtraWarpGroups = llvm::divideCeil(maxExtraWarps, 4); + mod.walk([&](WarpSpecializeOp op) { + padToMaxWarpGroups(op, numExtraWarpGroups); + }); + + int baseNumWarps = lookupNumWarps(mod); + + // Compute the total number of warps required at any given time. + mod.walk([&](WarpSpecializeOp op) { + ArrayRef arr = op.getPartitionNumWarps(); + + // Allocate the start IDs such that the largest warpgroups have lower + // starting warp IDs. + // FIXME: Handle aligning warp group IDs to 4 for TMEM. + SmallVector> idxAndSize; + for (auto [i, size] : llvm::enumerate(arr)) + idxAndSize.emplace_back(i, size); + llvm::sort(idxAndSize, + [&](auto lhs, auto rhs) { return lhs.second > rhs.second; }); + + SmallVector startIds(arr.size()); + int startId = baseNumWarps; + for (auto [i, size] : idxAndSize) { + startIds[i] = startId; + startId += size; + } + op.setWarpGroupStartIds(startIds); + }); + + Builder b(&getContext()); + mod->setAttr("ttg.total-num-warps", + b.getI32IntegerAttr(baseNumWarps + numExtraWarpGroups * 4)); + + bool needsRegisterOptimization = false; + mod.walk([&](WarpSpecializeOp op) { + if (op.getRequestedRegisters()) + needsRegisterOptimization = true; + }); + + if (!needsRegisterOptimization) + return; + + // Determine the maximum number of registers per thread. This may have + // been set by the user. + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int maxnreg; + if (auto maxnregAttr = + mod->getAttrOfType(AttrMaxRegistersName)) { + maxnreg = maxnregAttr.getInt(); + } else { + // Assume the user wants to use all 64K registers. + maxnreg = (64 * 1024) / (baseNumWarps + numExtraWarpGroups * 4) / + threadsPerWarp; + maxnreg = maxnreg / 8 * 8; + } + + struct WarpGroupInfo { + SmallVector partitions; + int maxRequestedRegs = 0; + unsigned numWarps = 0; + }; + struct WarpGroupPartition { + int startId; + Region *partition; + int32_t estRegs; + int numWarps; + }; + + // Compute register allocation for each warp specialize op. + mod.walk([&](WarpSpecializeOp op) { + ArrayRef arr = op.getPartitionNumWarps(); + auto startIds = *op.getWarpGroupStartIds(); + + // Require that an estimate has been set and that we have even warpgroups. + auto regsAttr = op.getRequestedRegisters(); + if (!regsAttr || op.getTotalPartitionWarps() % 4 != 0) + return; + + // Group the partitions into warpgroups. + SmallVector orderedPartitions; + for (auto [startId, partition, estRegs, numWarps] : + llvm::zip(startIds, op.getPartitionRegions(), *regsAttr, arr)) + orderedPartitions.push_back({startId, partition, estRegs, numWarps}); + llvm::sort(orderedPartitions, + [&](auto lhs, auto rhs) { return lhs.startId < rhs.startId; }); + + // Iterate over the partitions and assign them to warp groups. Determine + // the maximum number of requested registers per warp group. + SmallVector warpGroups; + for (auto [startId, partition, estRegs, numWarps] : orderedPartitions) { + if (startId % 4 == 0) { + warpGroups.push_back(WarpGroupInfo{}); + } + warpGroups.back().partitions.push_back(partition); + // Round up the nearest multiple of 8. + int estRegsCeil8 = llvm::divideCeil(estRegs, 8) * 8; + warpGroups.back().maxRequestedRegs = + std::max(warpGroups.back().maxRequestedRegs, estRegsCeil8); + warpGroups.back().numWarps += numWarps; + } + + // Compute the register deficit over the partition warp groups. + int registerBudget = maxnreg * baseNumWarps * threadsPerWarp; + for (const WarpGroupInfo &wg : warpGroups) { + assert(wg.numWarps % 4 == 0); + registerBudget += + (maxnreg - wg.maxRequestedRegs) * wg.numWarps * threadsPerWarp; + } + if (registerBudget <= 0) + return; + + // Determine the number of extra registers that we can distribute to the + // default warp group. + int leftover = registerBudget / (baseNumWarps * threadsPerWarp); + // Round down to the nearest multiple of 8. + leftover = leftover / 8 * 8; + if (leftover < 24) + return; // too few registers + + // Generate setmaxnreg in each partition according to its warp group. + SmallVector maxnregsPerPartition(1 + arr.size()); + for (const WarpGroupInfo &wg : warpGroups) { + for (Region *region : wg.partitions) { + maxnregsPerPartition[1 + region->getRegionNumber()] = + wg.maxRequestedRegs; + } + } + // Set the register usage for the default warp group. + maxnregsPerPartition.front() = leftover; + op.setActualRegisters(maxnregsPerPartition); + + // Set the initial max number of registers. This is needed for PTXAS to + // cooperate. + mod->setAttr(AttrMaxRegistersName, + Builder(op.getContext()).getI32IntegerAttr(maxnreg)); + }); + } +}; +} // namespace diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp new file mode 100644 index 0000000000..9c17af5aa6 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -0,0 +1,106 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; + +struct AssertOpConversion : public ConvertOpToLLVMPattern { + explicit AssertOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto elemTy = elems[0].getType(); + Value condition = b.int_val(elemTy.getIntOrFloatBitWidth(), 0); + for (auto elem : elems) { + if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { + condition = b.or_(condition, + b.icmp_eq(elem, LLVM::ConstantOp::create( + rewriter, loc, elemTy, + rewriter.getZeroAttr(elemTy)))); + } else { + assert(false && "Unsupported type for assert"); + return failure(); + } + } + llAssert(op, condition, adaptor.getMessage(), rewriter); + if (isa(op.getCondition().getType())) { + // Add a barrier to avoid a race condition in case an assert is followed + // by an op that may trap if the assert condition is true. Since the + // tensor in those two operations may have different layout we need to + // make sure all the threads are done executing the assert before going to + // the next op. + b.barrier(triton::gpu::AddrSpace::None); + } + rewriter.eraseOp(op); + return success(); + } + // op: the op at which the assert is inserted. Unlike printf, we need to + // know about the op to split the block. + void llAssert(Operation *op, Value condition, StringRef message, + ConversionPatternRewriter &rewriter) const { + + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + + StringRef file = "unknown"; + StringRef func = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + while (auto nameLoc = dyn_cast(loc)) + loc = nameLoc.getChildLoc(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + file = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + + // #block1 + // if (condition) { + // #block2 + // __assertfail(message); + // } + // #block3 + Block *prevBlock = op->getBlock(); + + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + // Split a block after the call. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(ifBlock); + LLVM::BrOp::create(rewriter, loc, thenBlock); + rewriter.setInsertionPointToEnd(prevBlock); + LLVM::CondBrOp::create(rewriter, loc, condition, ifBlock, thenBlock); + rewriter.setInsertionPointToStart(thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateAssertOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..2c1d48b5f9 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,41 @@ +add_triton_library(TritonGPUToLLVM + DotOpToLLVM/FMA.cpp + DotOpToLLVM/FMADotUtility.cpp + AllocateSharedMemory.cpp + AllocateSharedMemoryUtility.cpp + AllocateWarpGroups.cpp + AssertOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + FuncOpToLLVM.cpp + GatherOpToLLVM.cpp + GlobalScratchMemoryAllocation.cpp + HistogramOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + MemoryOpToLLVM.cpp + PrintOpToLLVM.cpp + ReduceOpToLLVM.cpp + ScanOpToLLVM.cpp + SPMDOpToLLVM.cpp + TypeConverter.cpp + Utility.cpp + ViewOpToLLVM.cpp + WarpSpecializeUtility.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUTransforms +) diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 0000000000..f33cb37cbf --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,165 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (funcOp->hasAttr("nvvm.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = LLVM::ReturnOp::create(rewriter, op.getLoc(), + adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + LLVM::UndefOp::create(rewriter, op.getLoc(), packedResultsTy); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = b.insert_val(packedResultsTy, packedResults, + it.value(), it.index()); + } + newOp = LLVM::ReturnOp::create(rewriter, op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + if (!caller->hasAttr("allocation.offset") || + !callOp->hasAttr("allocation.offset")) { + auto base = LLVM::getStackPointer(rewriter, caller); + promotedOperands.push_back(base); + } else { + auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp); + promotedOperands.push_back(base); + } + + auto opOffsetAttr = callOp->getAttrOfType( + "ttg.global_scratch_memory_offset"); + Value opOffsetVal; + if (opOffsetAttr) { + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + opOffsetVal = b.i32_val(opOffset); + } + + promotedOperands.push_back(LLVM::getGlobalScratchPtr( + loc, rewriter, targetInfo, caller, opOffsetVal)); + promotedOperands.push_back( + LLVM::getProfileScratchPtr(loc, rewriter, caller)); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = LLVM::CallOp::create(rewriter, callOp.getLoc(), + packedResult ? TypeRange(packedResult) + : TypeRange(), + promotedOperands, callOp->getAttrs()); + newCallOp.getProperties().setOpBundleSizes( + rewriter.getDenseI32ArrayAttr({})); + newCallOp.getProperties().setOperandSegmentSizes( + {static_cast(promotedOperands.size()), 0}); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(LLVM::ExtractValueOp::create( + rewriter, callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 0000000000..fa7535478f --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,603 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LayoutUtils.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton::gpu; +using TranspositionInfo = DecomposedWarpConversion::TranspositionInfo; + +constexpr int kPtrBitWidth = 64; +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase &targetInfo; + + explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + + const auto &shape = op.getType().getShape(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + LinearLayout srcLayout = toLinearLayout(srcTy); + LinearLayout dstLayout = toLinearLayout(dstTy); + + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + + assert(to_vector(conversion.getInDimNames()) == + to_vector(conversion.getOutDimNames())); + auto dims = conversion.getInDimNames(); + if (llvm::is_contained(dims, kBlock)) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, kWarp)) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. + transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter); + return success(); + } else if (llvm::is_contained(dims, kLane)) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // expensive enough we fall back to using shared memory + if (cvtNeedsWarpShuffle(srcTy, dstTy)) + return transferWithinWarp(op, adaptor, rewriter); + + transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter); + return success(); + } else if (llvm::is_contained(dims, kRegister)) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread(op, conversion, adaptor, rewriter); + } else { + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + } + + LogicalResult + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + StringAttr kRegister = str_attr("register"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(conversion.getInDimSize(kRegister)); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals[i] = inVals[srcIdx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + SmallVector transferWithinBlockSwizzlingImpl( + Location loc, ConversionPatternRewriter &rewriter, + const LinearLayout &srcLayout, const LinearLayout &dstLayout, + ArrayRef inVals, Type llvmElemTy, Value smemBase) const { + auto *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // We handle transformations recursively as they all need a preprocessing + // and a postprocessing step. + + // Handle pointer types as 64-bit integers + if (isa(llvmElemTy)) { + auto llvmElemTyPtr = i64_ty; + auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) { + return b.ptrtoint(llvmElemTyPtr, v).getResult(); + })); + auto outVals = + transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout, + newInVals, llvmElemTyPtr, smemBase); + for (auto &v : outVals) { + v = b.inttoptr(llvmElemTy, v); + } + return outVals; + } + + // Handle sub-byte elements like i1 + if (llvmElemTy.getIntOrFloatBitWidth() < 8) { + // Upcast to i8 + auto i8ElemTy = i8_ty; + auto newInVals = llvm::to_vector(llvm::map_range( + inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); })); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase); + for (auto &v : outVals) { + v = b.trunc(llvmElemTy, v); + } + return outVals; + } + + // Remove broadcasting in src + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtSrc = removeBroadcastSrc.apply(srcLayout); + auto newInVals = removeBroadcastSrc.apply(inVals); + return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout, + newInVals, llvmElemTy, smemBase); + } + + // Remove broadcasting in dst + auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout); + if (!removeBroadcastDst.isIdentity()) { + auto prmtDst = removeBroadcastDst.apply(dstLayout); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase); + return broadcastAs(outVals, dstLayout); + } + + // At this point we have a type that's at least 8-bit + // and we don't have broadcasting in the registers + auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto smem = optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth); + + // Extract reps from smem + auto kReg = str_attr("register"); + auto kReps = str_attr("reps"); + auto nReps = smem.getInDimSize(kReps); + auto reps = LinearLayout::identity1D(nReps, kReg, kReps); + + auto totalStoreCvt = srcLayout.invertAndCompose(smem); + auto totalLoadCvt = dstLayout.invertAndCompose(smem); + + // The permutation exists by construction of the reps dimension in + // optimalSwizzling + auto permStore = + regPermForDivide(totalStoreCvt, reps, /*left=*/false).value(); + totalStoreCvt = permStore.apply(totalStoreCvt); + auto permutedInVals = permStore.apply(inVals); + auto permLoad = + regPermForDivide(totalLoadCvt, reps, /*left=*/false).value(); + totalLoadCvt = permLoad.apply(totalLoadCvt); + + // Remove the reps and flatten into offset + auto storeCvt = *divideRight(totalStoreCvt, reps); + auto loadCvt = *divideRight(totalLoadCvt, reps); + auto kOffset = str_attr("offset"); + storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}}); + loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}}); + + auto tileSize = storeCvt.getInDimSize(kReg); + + assert(permutedInVals.size() == tileSize * nReps); + SmallVector outVals; + auto affineOffset = b.i32_val(0); + auto maskSpanAffineOffset = 0; + + bool isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout); + for (int i = 0; i < nReps; ++i) { + if (i > 0) { + if (isWarpSync) { + targetInfo.warpSync(loc, rewriter); + } else { + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + } + } + auto tileInVals = + ArrayRef(permutedInVals).slice(i * tileSize, tileSize); + // Store + lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase, + /*paddingShifts=*/{}, affineOffset, maskSpanAffineOffset, + rewriter, targetInfo); + if (isWarpSync) { + targetInfo.warpSync(loc, rewriter); + } else { + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + } + // Load + SmallVector tileOutVals = lowerLdStShared( + loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /*paddingShifts=*/{}, + affineOffset, maskSpanAffineOffset, rewriter, targetInfo); + llvm::append_range(outVals, tileOutVals); + } + + // Undo the permLoad used to divideRight + outVals = permLoad.inverse().apply(outVals); + return outVals; + } + + void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + // Remove the kBlock dimension from the layout as it's the identity in the + // cvt + auto srcLayout = toLinearLayout(srcTy); + auto dstLayout = toLinearLayout(dstTy); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + srcLayout = srcLayout.sublayout({kReg, kLane, kWarp}, + to_vector(srcLayout.getOutDimNames())); + dstLayout = dstLayout.sublayout({kReg, kLane, kWarp}, + to_vector(dstLayout.getOutDimNames())); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto inVals = unpackLLElements(loc, src, rewriter); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase); + + Value result = + packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + } + + // Use warp shuffles to implement a layout conversion where data only needs to + // be moved within warps. + LogicalResult transferWithinWarp(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + StringAttr kReg = str_attr("register"); + StringAttr kLane = str_attr("lane"); + auto elemTy = getTypeConverter()->convertType(srcTy.getElementType()); + int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + + auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth); + auto &[pReg, pLane, mixedTranspositions, nPack] = factors; + int m = mixedTranspositions.size(); + bool pLaneIsTrivial = squareSublayoutIsIdentity(pLane, kLane); + assert((m > 0 || !pLaneIsTrivial) && "Shuffles not needed for conversion"); + + // The desired layout conversion can be expressed as a permutation P of + // hardware index bits for the `kLane` and `kReg` dimensions. The `factors` + // of P describe a decomposition + // + // P = P_mixed \circ P_lane \circ P_reg, + // + // where P_reg and P_lane are permutations involving only register or only + // lane index bits and P_mixed is a product of disjoint transpositions of + // register index bits with lane index bits. Our goal is to implement P + // using predicated selects and warp-shuffles. We have two tools for this: + // - An out-of-place `Ship` method which implements one mixed transposition + // at a time using 1.5 * R selects/permutes and .5 * R shuffles each. + // - An in-place `Swap` method which can simultaneously implement P_lane + // and multiple mixed transpositions at a time using 2 * m * R selects/ + // permutes and either (1 - (1/2)^m) * R shuffles if `pLaneIsTrivial` and + // R shuffles otherwise. + // Here, R denotes the number of 32-bit registers in use after packing (or + // splitting, if applied to 64-bit types or pointers), and in the `Swap` + // method, `m` denotes the number of mixed transpositions passed in. + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // To avoid unnecessary data movement, we remove any broadcasting in the + // register dimension from the `inVals`. + auto srcLayout = toLinearLayout(srcTy); + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout); + inVals = removeBroadcastSrc.apply(inVals); + + // If the target layout has a larger register dimension than the source + // layout, then we broadcast along the register dimension to match size. The + // removal of broadcasting above and introduction here is expected by the + // `factors`. + int regDim = inVals.size(); + int pRegDim = pReg.getInDimSize(kReg); + if (pRegDim > regDim) { + SmallVector original(inVals.begin(), inVals.end()); + inVals.clear(); + inVals.reserve(pRegDim); + while (inVals.size() < pRegDim) + inVals.append(original.begin(), original.end()); + regDim = pRegDim; + } + + // Apply pReg. + SmallVector newInVals(regDim); + for (const auto &[i, v] : llvm::enumerate(inVals)) + newInVals[pReg.apply({{kReg, i}})[0].second] = v; + inVals = std::move(newInVals); + + // Pack registers if possible. + int elemsPerVec = 1 << nPack; + int bitsPerVecElem = 32 / elemsPerVec; + if (elemsPerVec > 1) { + SmallVector packedVals; + packedVals.reserve(regDim / elemsPerVec); + if (bitwidth == 8 && bitsPerVecElem == 16) { + // TODO: Can remove `if` part of `if-else` once ptxas bugfix lands. + for (int i = 0; i < regDim; i += elemsPerVec) { + Value x0 = b.zext(i32_ty, b.bitcast(inVals[i], int_ty(bitwidth))); + Value x1 = b.zext(i32_ty, b.bitcast(inVals[i + 1], int_ty(bitwidth))); + x1 = b.shl(x1, b.i32_val(16)); + packedVals.emplace_back(b.or_(x0, x1)); + } + } else { + if (bitwidth < bitsPerVecElem) { + for (Value &v : inVals) { + if (elemTy != int_ty(bitwidth)) + v = b.bitcast(v, int_ty(bitwidth)); + v = b.zext(int_ty(bitsPerVecElem), v); + } + } + for (int i = 0; i < regDim; i += elemsPerVec) { + auto slice = ArrayRef(inVals).slice(i, elemsPerVec); + Value v = packLLVector(loc, slice, rewriter); + v = b.bitcast(v, i32_ty); + packedVals.emplace_back(v); + } + } + inVals = std::move(packedVals); + } + + auto isShippable = [](const TranspositionInfo &t) { + // The `Ship` method cannot mix elements from different registers in the + // same lane, so we are restricted to cycles like (l0 r1), (l0 r2), and + // (l0 r0 r1) which do not use both high and low register bits. + return t.topPreSel == t.topPostSel || + (t.topPreSel == 0x5140 && t.topPostSel == 0x6240) || + (t.topPreSel == 0x6420 && t.topPostSel == 0x5410) || + (t.topPreSel == 0x3210 && t.topPostSel == 0x3120); + }; + + SmallVector outVals; + if (m == 1 && pLaneIsTrivial && isShippable(mixedTranspositions[0])) { + outVals = transferWithinWarpShipImpl(loc, rewriter, inVals, nPack, + mixedTranspositions[0]); + } else { + outVals = transferWithinWarpSwapImpl(loc, rewriter, inVals, nPack, pLane, + pLaneIsTrivial, mixedTranspositions); + } + + // Unpack registers if needed. + if (elemsPerVec > 1) { + SmallVector unpackedVals; + unpackedVals.reserve(regDim); + auto packedTy = + bitwidth < bitsPerVecElem ? int_ty(bitsPerVecElem) : elemTy; + auto vecTy = vec_ty(packedTy, elemsPerVec); + auto unpackVal = [&](Value v) { + v = b.bitcast(v, vecTy); + return unpackLLVector(loc, v, rewriter); + }; + for (auto v : outVals) { + auto unpacked = unpackVal(v); + unpackedVals.append(unpacked.begin(), unpacked.end()); + } + if (bitwidth < bitsPerVecElem) { + for (Value &v : unpackedVals) { + v = b.trunc(int_ty(bitwidth), v); + if (elemTy != int_ty(bitwidth)) + v = b.bitcast(v, elemTy); + } + } + outVals = std::move(unpackedVals); + } + + // If `dstLayout` has a smaller `kReg` dimension than `srcLayout` after + // broadcasting is removed, then drop the extra registers from `outVals`. + auto dstLayout = toLinearLayout(dstTy); + auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout); + auto strippedDstLayout = removeBroadcastDst.apply(dstLayout); + outVals.resize(strippedDstLayout.getInDimSize(kReg)); + + // Introduce broadcasting in registers if expected by `dstLayout`. + if (!removeBroadcastDst.isIdentity()) + outVals = broadcastAs(outVals, dstLayout); + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + SmallVector transferWithinWarpSwapImpl( + Location loc, ConversionPatternRewriter &rewriter, ArrayRef inVals, + int nPack, const LinearLayout &pLane, bool pLaneIsTrivial, + ArrayRef mixedTranspositions) const { + auto *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + StringAttr kReg = str_attr("register"); + StringAttr kLane = str_attr("lane"); + + SmallVector vals(inVals.begin(), inVals.end()); + int m = mixedTranspositions.size(); + int numRegs = inVals.size(); + // A single mixed transposition (r_i l_j) which swaps the i-th register + // index bit and the j-th lane index bit of an element applies a tiled 2x2 + // block transpose with block size (1 << i) by (1 << j) to the data. This + // can be realized as: + // + // [ A B ] selp [ A D ] shfl [ A D ] selp [ A C ] + // [ C D ] ---> [ C B ] ---> [ B C ] ---> [ B D ]. + // + // In linear-algebraic terms, this is the factorization over GF(2): + // + // 1. r_i ^= l_j (selp) selp shfl selp + // 2. l_j ^= r_i (shfl) [ 0 1 ] [ 1 1 ] [ 1 0 ] [ 1 1 ] + // 3. r_i ^= l_j (selp), [ 1 0 ] = [ 0 1 ] [ 1 1 ] [ 0 1 ], + // + // where we pass in bits as column vectors [r_i, l_j]. + // + // When the transpositions are all disjoint, we can group the three stages + // of each transposition together. The two combined `selp` stages each use + // `numRegs` selects per transposition, while the `shfl` stage only requires + // code emission when at least one of the `r_i` bits is on, resulting in + // `(1 - (1/2)^m) * numRegs` shuffles in total. If `pLane` is nontrivial, + // then we can conjugate its effects through the first two stages and fuse + // it with the second stage, resulting in `numRegs` shuffles instead. + Value laneId = getLaneId(rewriter, loc); + auto pLaneInv = pLane.invert(); + const auto &pLInvBases = pLaneInv.getBases().lookup(kLane); + + // Implement r_i ^= l_j using `numRegs` independent selects or permutes. + auto applySwap = [&](TranspositionInfo t, bool preShuf) { + int rIdx = t.transposition.first - nPack; + int origLIdx = t.transposition.second; + int lIdx = preShuf ? llvm::Log2_32(pLInvBases[origLIdx][0]) : origLIdx; + uint16_t topSel = preShuf ? t.topPreSel : t.topPostSel; + uint16_t botSel = preShuf ? t.botPreSel : t.botPostSel; + + SmallVector newVals(numRegs); + Value lBitVal = b.and_(laneId, b.i32_val(1 << lIdx)); + Value lBitOff = b.icmp_eq(lBitVal, b.i32_val(0)); + + int tileSize = 1 << (rIdx + 1); + int numTiles = numRegs / tileSize; + for (int tileIdx = 0; tileIdx < numTiles; ++tileIdx) { + int baseIdx = tileIdx * tileSize; + for (int i = 0; i < tileSize / 2; ++i) { + int r0 = baseIdx + i; + int r1 = r0 + (1 << rIdx); + Value v0 = vals[r0]; + Value v1 = vals[r1]; + if (topSel == 0x3210 && botSel == 0x7654) { + newVals[r0] = b.select(lBitOff, v0, v1); + newVals[r1] = b.select(lBitOff, v1, v0); + } else { + Value sel00 = b.i32_val(topSel); + Value sel01 = b.i32_val(preShuf ? botSel : (topSel ^ 0x4444)); + Value sel10 = b.i32_val(botSel); + Value sel11 = b.i32_val(preShuf ? topSel : (botSel ^ 0x4444)); + Value sel1 = b.select(lBitOff, sel00, sel01); + Value sel2 = b.select(lBitOff, sel10, sel11); + newVals[r0] = targetInfo.permute(rewriter, loc, v0, v1, sel1); + newVals[r1] = targetInfo.permute(rewriter, loc, v0, v1, sel2); + } + } + } + return newVals; + }; + + // Stage 1 (selp/prmt) + for (const auto &t : mixedTranspositions) + vals = applySwap(t, /*preShuf=*/true); + // Stage 2 (shfl) + Value laneIdPerm; + if (!pLaneIsTrivial) + laneIdPerm = triton::gpu::matrixVectorProd(b, pLaneInv, laneId); + for (int r = 0; r < numRegs; ++r) { + int mask = 0; + for (const auto &t : mixedTranspositions) { + int rIdx = t.transposition.first - nPack; + int lIdx = t.transposition.second; + if (r & (1 << rIdx)) { + mask |= pLInvBases[lIdx][0]; + } + } + if (pLaneIsTrivial) { + if (mask != 0) + vals[r] = targetInfo.shuffleXor(rewriter, loc, vals[r], mask); + } else { + Value srcIdx = b.xor_(laneIdPerm, b.i32_val(mask)); + vals[r] = targetInfo.shuffleIdx(rewriter, loc, vals[r], srcIdx); + } + } + // Stage 3 (selp/prmt) + for (const auto &t : mixedTranspositions) + vals = applySwap(t, /*preShuf=*/false); + return vals; + } + + SmallVector + transferWithinWarpShipImpl(Location loc, ConversionPatternRewriter &rewriter, + ArrayRef inVals, int nPack, + TranspositionInfo t) const { + // Implements the effects of a single mixed transposition as in + // `transferWithinWarpSwapImpl`, but uses auxiliary registers to hold the + // values to be shuffled, resulting in fewer emitted instructions. + int numRegs = inVals.size(); + int rIdx = t.transposition.first - nPack; + int lIdx = t.transposition.second; + int tileSize = 1 << (rIdx + 1); + int numTiles = numRegs / tileSize; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value laneId = getLaneId(rewriter, loc); + Value lBitVal = b.and_(laneId, b.i32_val(1 << lIdx)); + Value lBitOff = b.icmp_eq(lBitVal, b.i32_val(0)); + SmallVector outVals(numRegs); + + auto shipDiagSels = [](auto postSel) { + if (postSel == 0x3120) + return std::pair{0x7564, 0x7564}; + auto high = (postSel & 0x4444) >> 2; + auto sel10 = postSel ^ ((postSel & 0x1000) ? high << 1 : high); + return std::pair{sel10, sel10 ^ 0x4444}; + }; + + for (int tileIdx = 0; tileIdx < numTiles; ++tileIdx) { + int baseIdx = tileIdx * tileSize; + for (int i = 0; i < tileSize / 2; ++i) { + int r0 = baseIdx + i; + int r1 = r0 + (1 << rIdx); + Value v0 = inVals[r0]; + Value v1 = inVals[r1]; + if (t.topPreSel == 0x3210 && t.topPostSel == 0x3210) { + Value valToShip = b.select(lBitOff, v1, v0); + Value shippedVal = + targetInfo.shuffleXor(rewriter, loc, valToShip, (1 << lIdx)); + outVals[r0] = b.select(lBitOff, v0, shippedVal); + outVals[r1] = b.select(lBitOff, shippedVal, v1); + } else { + Value shipSel = + b.select(lBitOff, b.i32_val(t.botPreSel), b.i32_val(t.topPreSel)); + Value valToShip = targetInfo.permute(rewriter, loc, v0, v1, shipSel); + Value shippedVal = + targetInfo.shuffleXor(rewriter, loc, valToShip, (1 << lIdx)); + Value sel00 = b.i32_val(t.topPostSel); + Value sel01 = b.i32_val(shipDiagSels(t.topPostSel).second); + Value sel10 = b.i32_val(shipDiagSels(t.topPostSel).first); + Value sel11 = b.i32_val(t.botPostSel ^ 0x4444); + Value sel1 = b.select(lBitOff, sel00, sel01); + Value sel2 = b.select(lBitOff, sel10, sel11); + outVals[r0] = targetInfo.permute(rewriter, loc, v0, shippedVal, sel1); + outVals[r1] = targetInfo.permute(rewriter, loc, v1, shippedVal, sel2); + } + } + } + return outVals; + } +}; + +} // namespace + +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 0000000000..20504b069e --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,74 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace ::mlir::triton::gpu; + +namespace { +class GenericFMAVectorMultiplier : public FMAVectorMultiplier { + OpBuilder &builder; + Location loc; + +public: + GenericFMAVectorMultiplier(OpBuilder &builder, Location loc) + : builder(builder), loc(loc) {} + + Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) override { + auto K = a.size(); + assert(b.size() == K); + Value accum = c; + Type tgtTy = accum.getType(); + auto castToTargetType = [&](Value v) -> Value { + if (v.getType() == tgtTy) + return v; + if (isa(tgtTy) && isa(v.getType())) { + auto srcTy = cast(v.getType()); + auto dstTy = cast(tgtTy); + if (srcTy.getWidth() < dstTy.getWidth()) + return LLVM::FPExtOp::create(builder, loc, tgtTy, v); + return LLVM::FPTruncOp::create(builder, loc, tgtTy, v); + } + if (isa(tgtTy) && isa(v.getType())) { + auto srcTy = cast(v.getType()); + auto dstTy = cast(tgtTy); + if (srcTy.getWidth() < dstTy.getWidth()) + return LLVM::SExtOp::create(builder, loc, tgtTy, v); + return LLVM::TruncOp::create(builder, loc, tgtTy, v); + } + llvm_unreachable("unsupported type conversion in FMA dot lowering"); + }; + for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) { + Value aElem = castToTargetType(std::get<0>(*it)); + Value bElem = castToTargetType(std::get<1>(*it)); + + // to avoid: 'llvm.intr.fmuladd' op operand #0 must be floating point LLVM + // type or LLVM dialect-compatible vector of floating point LLVM type, but + // got 'i32' + llvm::TypeSwitch(tgtTy) + .Case([&](auto) { + accum = LLVM::FMulAddOp::create(builder, loc, aElem, bElem, accum); + }) + .Case([&](auto) { + accum = LLVM::AddOp::create( + builder, loc, LLVM::MulOp::create(builder, loc, aElem, bElem), + accum); + }); + } + return accum; + } +}; + +} // namespace + +LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + GenericFMAVectorMultiplier multiplier(rewriter, loc); + return parametricConvertFMADot(op, adaptor, typeConverter, rewriter, + multiplier); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp new file mode 100644 index 0000000000..fa2c814722 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -0,0 +1,170 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; + +namespace { + +/// OperandValueKey structure represents compile time part +/// of spatial coordinates of a value in a tensor. +/// +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be +/// defined as: +/// +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord) +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord) +/// k = kIdx +/// +/// Where: +/// CTABSize, CTANKSize: constants; +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components; +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components. +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } +}; + +} // namespace + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); + } +}; + +namespace { + +using ValueTableFMA = std::unordered_map; + +ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + res[key] = elems[idx]; + } + return res; +} + +} // namespace + +namespace mlir::triton::gpu { + +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto dTensorTy = cast(D.getType()); + + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + // TODO process A and B operand separately + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getContigPerThread(dTensorTy); + auto numElemsPerThread = product(sizePerThread); + SmallVector shapePerCTATile; + for (auto [reg, thread, warp] : + llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(), + dLayout.getWarpsPerCTA())) { + shapePerCTATile.push_back(reg * thread * warp); + } + shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile)); + sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread)); + + unsigned K = aShapePerCTA[2]; + + unsigned threadTileShape[3]; + unsigned repetitions[3]; + for (int i = 0; i < 3; ++i) { + repetitions[i] = + ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); + } + + auto has = getValueTableFromStructFMA( + llA, {sizePerThread[0], sizePerThread[1], K}, + {repetitions[0], repetitions[1], 1}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); + auto hbs = getValueTableFromStructFMA( + llB, {sizePerThread[0], K, sizePerThread[2]}, + {repetitions[0], 1, repetitions[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); + + SmallVector acc = cc; + + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearInRepIdx = + LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + LLVM::linearize(multiDimRepIdx, repetitions, repOrder); + unsigned linearAccumIdx = + linearInRepIdx + linearRepIdx * numElemsPerThread; + + SmallVector aOpVector; + SmallVector bOpVector; + + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + acc[linearAccumIdx] = multiplier.multiplyVectors( + aOpVector, bOpVector, acc[linearAccumIdx]); + } + + auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 0000000000..2119da6424 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,752 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) + return tensorType.getElementType(); + return type; +} + +int getNumElementsPerThreads(Type type, + const LLVMTypeConverter *typeConverter) { + int numElemsPerThread = 1; + if (auto tensorTy = dyn_cast(type)) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) + numElemsPerThread = structType.getBody().size(); + } + return numElemsPerThread; +} + +} // namespace mlir::triton::gpu + +namespace { +struct AddPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resultTy = op.getType(); + auto typeConverter = getTypeConverter(); + auto resultTensorTy = dyn_cast(resultTy); + if (resultTensorTy) { + unsigned elems = getTotalElemsPerThread(resultTy); + Type elemTy = typeConverter->convertType( + cast(resultTensorTy.getElementType()).getPointeeType()); + Type ptrTy = typeConverter->convertType(resultTensorTy.getElementType()); + auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = b.gep(ptrTy, elemTy, ptrs[i], offsets[i]); + } + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, view); + } else { + assert(isa(resultTy)); + auto resultPtrTy = typeConverter->convertType(resultTy); + auto resultElemTy = typeConverter->convertType( + cast(resultTy).getPointeeType()); + Value result = b.gep(resultPtrTy, resultElemTy, adaptor.getPtr(), + adaptor.getOffset()); + rewriter.replaceOp(op, result); + } + return success(); + } +}; + +struct CmpIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + return {LLVM::ICmpOp::create(rewriter, loc, elemTy, + ArithCmpIPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::ICmpPredicate + ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__) \ + case arith::CmpIPredicate::item__: \ + return LLVM::ICmpPredicate::item__ + + __PRED_ENUM(eq); + __PRED_ENUM(ne); + __PRED_ENUM(sgt); + __PRED_ENUM(sge); + __PRED_ENUM(slt); + __PRED_ENUM(sle); + __PRED_ENUM(ugt); + __PRED_ENUM(uge); + __PRED_ENUM(ult); + __PRED_ENUM(ule); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpIPredicate"); + } +}; + +struct CmpFOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + static SmallVector + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + MultipleOperandsRange operands, Location loc) { + return {LLVM::FCmpOp::create(rewriter, loc, elemTy, + ArithCmpFPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +struct MulhiUIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(MulhiUIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + Type resultElementTy = getElementTypeOrSelf(op.getResult().getType()); + assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64)); + + auto funcName = targetInfo.getMulhiFuncName(resultElementTy); + if (funcName.empty()) { + auto intTy = cast(resultElementTy); + unsigned bitWidth = intTy.getWidth(); + Type wideTy = IntegerType::get(rewriter.getContext(), bitWidth * 2); + Value lhsWide = + LLVM::ZExtOp::create(rewriter, loc, wideTy, operands[0][0]); + Value rhsWide = + LLVM::ZExtOp::create(rewriter, loc, wideTy, operands[0][1]); + Value prodWide = + LLVM::MulOp::create(rewriter, loc, wideTy, lhsWide, rhsWide); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value shift = b.int_val(bitWidth * 2, bitWidth); + Value highWide = + LLVM::LShrOp::create(rewriter, loc, wideTy, prodWide, shift); + return {LLVM::TruncOp::create(rewriter, loc, resultElementTy, highWide)}; + } + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct ExternElementwiseOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + StringRef funcName = op.getSymbol(); + if (funcName.empty()) + llvm::errs() << "ExternElementwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } +}; + +struct ElementwiseInlineAsmOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32 bits, pack in groups of 32 bits. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + Type elemTy = getElementType(op.getOperand(i)); + unsigned bitWidth = + elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; + unsigned numElementPerReg = std::max(32 / bitWidth, 1u); + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = + vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); + Value packed = b.undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = b.insert_element(packed, operands[j + k][i], b.i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector> + createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands, Location loc) const { + auto ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + if (operands.size() % op.getPackedElement() != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + + // Pack elems smaller than 32 bits into 32-bit registers. + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + + // Types returned by the LLVM asm op. If there's more than one, they'll be + // wrapped in a struct. + SmallVector asmRetTypes; + for (auto result : op.getResult()) { + auto ty = getTypeConverter()->convertType(getElementType(result)); + + // Pack return elements into 32-bits. + unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty); + unsigned numElemsPerReg = + std::min(std::max(32 / bitWidth, 1u), op.getPackedElement()); + assert(op.getPackedElement() % numElemsPerReg == 0); + if (numElemsPerReg > 1) { + ty = vec_ty(ty, numElemsPerReg); + } + for (unsigned i = 0; i < op.getPackedElement() / numElemsPerReg; i++) { + asmRetTypes.push_back(ty); + } + } + Type asmRetType = + asmRetTypes.size() > 1 ? struct_ty(asmRetTypes) : asmRetTypes[0]; + + Value asmResults = LLVM::InlineAsmOp::create( + rewriter, loc, asmRetType, + /*operands=*/packedOperands, + /*asm_string=*/op.getAsmString(), + /*constraints=*/op.getConstraints(), + /*has_side_effects=*/!op.getPure(), + /*is_align_stack=*/false, LLVM::TailCallKind::None, + /*asm_dialect=*/ + LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), + /*operand_attrs=*/ArrayAttr()) + ->getResult(0); + + // asmResults is a flat struct; pack its values into + // [return_value][op.getPackedElement()]. + SmallVector> ret(op->getNumResults()); + int structIdx = 0; + for (int i = 0; i < op->getNumResults(); i++) { + for (int j = 0; j < op.getPackedElement(); j++) { + Value val; + if (asmRetTypes.size() > 1) { + val = b.extract_val(asmResults, structIdx++); + } else { + val = asmResults; + } + if (auto vectorTy = dyn_cast(val.getType())) { + for (int k = 0; k < vectorTy.getNumElements(); k++) { + ret[i].push_back(b.extract_element(val, b.i32_val(k))); + } + j += vectorTy.getNumElements() - 1; + } else { + ret[i].push_back(val); + } + } + } + return ret; + } + + LogicalResult + matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // Layout is unpackedOperands[operand][elem]. + SmallVector> unpackedOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + unpackedOperands.push_back(subOperands); + } + + int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), + getTypeConverter()); + + // These are checked by the verifier, so we don't need to raise a nice + // error. + assert(all_of(unpackedOperands, [&](auto &operands) { + return operands.size() == numElemsPerThread; + })); + if (numElemsPerThread % op.getPackedElement() != 0) { + // Pad with the undef for each operand to have a multiple of + // op.getPackedElement() elements. + int numPaddedValue = + op.getPackedElement() - numElemsPerThread % op.getPackedElement(); + for (auto &operands : unpackedOperands) { + for (int i = 0; i < numPaddedValue; i++) { + operands.push_back(b.undef(operands[0].getType())); + } + } + } + + // Run the inline asm op on each block of elements. + // + // Layout is unpackedResults[result_idx][elem]. + // + // This loop always runs at least once, even when the asm has no input + // elements. + SmallVector> unpackedResults(op->getNumResults()); + for (unsigned i = 0; i < numElemsPerThread; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); + } + } + auto cur = createDestOps(op, adaptor, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + for (auto &results : unpackedResults) { + results.resize(numElemsPerThread); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); + } + + rewriter.replaceOp(op, outs); + return success(); + } +}; + +struct AbsIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {LLVM::AbsOp::create(rewriter, loc, elemTy, operands[0][0], + /*is_int_min_poison=*/false)}; + } +}; + +struct AbsFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (llvm::isa(elemTy)) { + // Mask out the sign bit + auto num_bits = + getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); + assert(num_bits <= 16); + auto mask = (1u << (num_bits - 1u)) - 1u; + auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); + auto maskConst = LLVM::ConstantOp::create(rewriter, loc, maskAttr); + return {b.and_(operands[0][0], maskConst)}; + } + + return {LLVM::FAbsOp::create(rewriter, loc, elemTy, operands[0][0])}; + } +}; + +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {LLVM::SelectOp::create(rewriter, loc, llvmOperands[1].getType(), + llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; +template +struct MinMaxFOpConversion + : ElementwiseOpConversionBase> { + using Base = ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static_assert(std::is_same::value || + std::is_same::value, + "OpTy must be arith::MinimumFOp or arith::MaximumFOp"); + + // Choose the destination op based on the OpTy. + using DestOpNanProp = + typename std::conditional::value, + LLVM::MinimumOp, LLVM::MaximumOp>::type; + using DestOpNoNanProp = + typename std::conditional::value, + LLVM::MinNumOp, LLVM::MaxNumOp>::type; + + explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + bool hwNanPropagationSupported, + PatternBenefit benefit = 1) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + hwNanPropagationSupported(hwNanPropagationSupported) {} + + SmallVector createDestOps(OpTy op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (hwNanPropagationSupported) { + return {DestOpNanProp::create(rewriter, loc, elemTy, operands[0][0], + operands[0][1])}; + } + // Handle workaround for NaN propagation, i.e. software emulation of NaN + // propagation. If any of the operands is NaN, return NaN. + auto lhs = operands[0][0]; + auto rhs = operands[0][1]; + auto lhsIsNan = + LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::une, lhs, lhs); + auto rhsIsNan = + LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::une, rhs, rhs); + auto isNan = LLVM::OrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); + auto nonNanRes = DestOpNoNanProp::create(rewriter, loc, elemTy, lhs, rhs); + + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + + // Select the result based on the isNan flag. + return {LLVM::SelectOp::create(rewriter, loc, isNan, nan, nonNanRes)}; + } + +private: + bool hwNanPropagationSupported; +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == PropagateNan::ALL) { + if (targetInfo.supportMaximumMinimum()) { + auto v = LLVM::MaximumOp::create(rewriter, loc, elemTy, operands[0][0], + operands[0][1]); + return {LLVM::MinimumOp::create(rewriter, loc, v, operands[0][2])}; + } + // On pre-80 compute capability, we need to handle NaN propagation + // manually. We need to check only the first operand for clamp. + auto lhs = operands[0][0]; + auto isNan = LLVM::FCmpOp::create(rewriter, loc, LLVM::FCmpPredicate::une, + lhs, lhs); + auto v = LLVM::MaxNumOp::create(rewriter, loc, elemTy, operands[0][0], + operands[0][1]); + auto nonNanRes = LLVM::MinNumOp::create(rewriter, loc, v, operands[0][2]); + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + // Select the result based on the isNan flag. + return {LLVM::SelectOp::create(rewriter, loc, isNan, nan, nonNanRes)}; + } + + // No NaN propagation. + assert(op.getPropagateNan() == PropagateNan::NONE); + auto v = LLVM::MaxNumOp::create(rewriter, loc, elemTy, operands[0][0], + operands[0][1]); + return {LLVM::MinNumOp::create(rewriter, loc, v, operands[0][2])}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct MapElementwiseOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + using Adaptor = typename Base::OpAdaptor; + + using Base::Base; + + LogicalResult matchAndRewrite(MapElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + + auto operands = adaptor.getOperands(); + const auto nOperands = operands.size(); + const auto nElems = + cast(operands[0].getType()).getBody().size(); + const auto nElemsPerPack = op.getPack(); + if (nElems % nElemsPerPack != 0) + return op->emitError() + << "pack size must be a divisor of the number of elements per " + "thread, but got pack = " + << nElemsPerPack << ", elements per thread = " << nElems << "\n"; + + const auto nPacks = nElems / nElemsPerPack; + auto nArgsUnpacked = nElemsPerPack * nOperands; + + SmallVector scalarOperands(nOperands * nElems); + for (auto iOp : llvm::seq(nOperands)) { + auto elems = unpackLLElements(loc, operands[iOp], rewriter); + assert(elems.size() == nElems); + for (auto iPack : llvm::seq(nPacks)) { + auto *packOperands = + &scalarOperands[iPack * nArgsUnpacked + iOp * nElemsPerPack]; + auto *packElems = &elems[iPack * nElemsPerPack]; + for (auto iElem : llvm::seq(nElemsPerPack)) { + packOperands[iElem] = packElems[iElem]; + } + } + } + + auto &scalarOp = op.getScalarOp(); + Region &parent = *rewriter.getBlock()->getParent(); + + auto nOutputs = op.getNumResults(); + SmallVector scalarOutputs(nOutputs * nElems); + for (auto iPack : llvm::seq(nPacks)) { + ArrayRef packedArgs(&scalarOperands[iPack * nArgsUnpacked], + nArgsUnpacked); + auto packResults = inlineRegion( + rewriter, scalarOp, packedArgs, loc); + assert(packResults.size() == nOutputs * nElemsPerPack); + for (auto iOut : llvm::seq(nOutputs)) { + auto *packOutputs = + &scalarOutputs[iOut * nElems + iPack * nElemsPerPack]; + for (auto iElem : llvm::seq(nElemsPerPack)) { + packOutputs[iElem] = packResults[iOut * nElemsPerPack + iElem]; + } + } + } + + SmallVector packedOutputs(nOutputs); + for (auto iOut : llvm::seq(nOutputs)) { + ArrayRef vals(&scalarOutputs[iOut * nElems], nElems); + packedOutputs[iOut] = + packLLElements(loc, typeConverter, vals, rewriter, op.getType(iOut)); + } + rewriter.replaceOp(op, packedOutputs); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateMinMaxFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported, + PatternBenefit benefit) { + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); +} + +void mlir::triton::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); +} + +void mlir::triton::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::CeilOp, math::CeilOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::RsqrtOp, math::RsqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 0000000000..33362bf8fd --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,149 @@ +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +// NOTE: [Additional Function Arguments] +// Triton patches additional arguments to the function signature to support +// (1) shared memory, (2) global scratch memory, and (3) profile scratch memory. +// To support use of shared memory and global scratch memory inside of a +// function, the caller allocates a single large block of the relevant memory +// and calls the function with these extra arguments at the end. +// Profile scratch memory is only used when the function is instrumented for +// profiling. +// +// For the kernel function itself, the shared memory base is a global symbol +// so no additional function argument is required but global scratch memory +// allocation is still passed in as the last argument. Though here the scratch +// memory is shared between all programs, so a linear offset based on the +// program id is required to get the local scratch base. + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM + // attributes. + static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { + const bool isKernel = triton::isKernel(llvmFuncOp); + for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { + const auto attrs = llvmFuncOp.getArgAttrDict(i); + if (!attrs) { + continue; + } + + for (const auto &attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + const auto i32_type = + mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); + assert(isKernel && + "tt.nv_tma_desc is not supported for device functions"); + + // See + // https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + mlir::BlockArgument arg = llvmFuncOp.getArgument(i); + const auto byteType = + mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get( + llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getByValAttrName(), + mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, NVVM::NVVMDialect::getGridConstantAttrName(), + mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(), + mlir::IntegerAttr::get(i32_type, 64)); + } + } + } + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); + + FailureOr maybeNewFuncOp = + mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter, + *getTypeConverter()); + if (failed(maybeNewFuncOp)) { + return failure(); + } + + LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp; + handleArgPtrDatatype(funcOp, newFuncOp); + + auto ctx = funcOp->getContext(); + + if (triton::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(), + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + + // Determine the actual number of required warps. + int numWarps = triton::gpu::lookupNumWarps(funcOp); + if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType( + "ttg.total-num-warps")) + numWarps = totalNumWarps.getInt(); + + int numCTAs = 1; + if (auto module = funcOp->getParentOfType()) { + if (auto moduleAttr = + module->getAttrOfType(triton::gpu::AttrNumCTAsName)) + numCTAs = moduleAttr.getInt(); + } + + // Set `nvvm.maxnreg` if it was specified on the module. + if (Attribute maxnregAttr = + funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName)) + newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr); + + // Do we want to do this for nCTAs == 1 whenever sm >= 90? + if (numCTAs > 1) { + // Request a specific number of CTAs per cluster in the generated PTX. + newFuncOp->setAttr(NVVM::NVVMDialect::getClusterDimAttrName(), + rewriter.getDenseI32ArrayAttr(numCTAs)); + } + + // Set an attribute for reqntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(), + rewriter.getDenseI32ArrayAttr(32 * numWarps)); + + rewriter.eraseOp(funcOp); + rewriter.eraseOp(amendedFuncOp); + + // Add attributes for by-value TMA descriptor args (nvidia) + handleByvalTmaDescArgs(newFuncOp); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp new file mode 100644 index 0000000000..7c87c6056b --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -0,0 +1,349 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +class GatherOpConversion : public ConvertOpToLLVMPattern { +public: + GatherOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +private: + // Codegen the gather by storing the source tensor into shared memory and then + // gathering directly from shared memory. + void emitGatherInShared(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + // Codegen a warp-local gather by shuffling elements across the warp and + // selecting from them. + void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + const TargetInfoBase &targetInfo; +}; + +LogicalResult +GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + GatherLoweringHelper helper(op); + // Specialize the lowering based on the source layout. Given that the cost of + // a warp shuffle is approximately half the cost of a roundtrip to shared + // memory with zero bank conflicts, we will need a more precise heuristic to + // choose between the two codegen paths and rely on the middle end to pick the + // right layout. + if (helper.isWarpLocal()) { + emitWarpLocalGather(op, adaptor, rewriter); + } else { + emitGatherInShared(op, adaptor, rewriter); + } + return success(); +} + +static Value convertIndexToI32(Location loc, Value index, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + index = b.trunc(i32_ty, index); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + index = b.zext(i32_ty, index); + } + return index; +} + +void GatherOpConversion::emitGatherInShared( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType srcType = op.getSrc().getType(); + + // Compute the src subtensor shape owned by this CTA. + SmallVector srcShapePerCTA = + convertType(triton::gpu::getShapePerCTA(srcType)); + + // Grab the src values in this thread. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // Emit the indices of the src values owned by this thread. + SmallVector> srcIndices = + emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(), + op.getSrc().getType(), /*withCTAOffset=*/true); + + // Store the src values owned by the thread into their respective location in + // the scratch memory. + assert(srcValues.size() == srcIndices.size()); + + // Get the base pointer to the scratch memory. + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + // For each src element owned by the thread, index into the scratch memory and + // then store it. + Type elemType = getTypeConverter()->convertType(srcType.getElementType()); + for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) { + // Convert the index at each dim into a single offset given the shape of the + // tensor. + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + // Emit the offset into the shared memory and then store the value. + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + b.store(value, ptr); + } + + // Synchronize the whole CTA. + b.barrier(triton::gpu::AddrSpace::Local); + + // Grab the index values owned by this thread. + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + // Apply the layout of the destination tensor to obtain the indices of the + // column to gather along, then for each column, replace the index along the + // gather axis with the appropriate index value. + // + // I = LL(pid) + // idx = indices[I] + // I_gather = [I[d] if d != axis else idx for d in range(len(I))] + // out[I] = src[I_gather] + RankedTensorType dstType = op.getType(); + SmallVector> dstIndices = + emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, + /*withCTAOffset=*/true); + + unsigned axis = op.getAxis(); + SmallVector results(dstIndices.size()); + for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { + indices[axis] = convertIndexToI32(loc, idx, rewriter); + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + results[i] = b.load(elemType, ptr); + } + + Value packed = + packLLElements(loc, getTypeConverter(), results, rewriter, dstType); + rewriter.replaceOp(op, packed); +} + +// High-level description of the algorithm: +// +// `isWarpLocal` checks that it is possible to compute each output element +// without data movement across warps. +// +// If the gather dim is `dimN`, then this means +// +// ll^-1(dimN)[(block, warp)] == 0 +// +// for both source and index tensors: moving along the gather axis does not +// change the warp. Broadcasted layouts are not supported, so we know the +// layouts are permutation matrices. +// +// We can check this with `ll((block, warp))[dimN] == 0`. +// +// Let `gatherCol` be a tuple of all dimensions except the gather dimension. +// We also check that the gather columns line up the same way with respect to +// the warp between the source and index tensors with +// +// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol] +// +// This means that for all index columns, the corresponding column in the source +// tensor is owned by the same warp. +// +// We also check +// +// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol] +// +// This boils down to the fact that the algorithm essentially emits a series of +// index shuffles for each index value owned by each thread, and then a pile of +// selects to pick the right value. We need to figure out given an index value +// in a particular column, what are the source register values it could read +// from and who owns them. +// +// If this relationship did not hold, then the possible source registers for +// each index value varies with the thread, meaning the value operand provided +// to each shuffle index instruction would depend on the thread ID. This isn't a +// big deal. It just means would have to emit a pile of selects before each +// shuffle as well, to pick the right source register value. But we choose not +// to handle this. +// +// The codegen algorithm emits code: +// - Given the thread ID and a particular index tensor register, figure out +// which gather column it belongs to using a layout. +// - Using the index value itself as the value for `dimN`, use another layout to +// figure out which lane in the warp owns the desired value and which register +// in that lane it is. +// - For the gather column, figure out the source registers in that column, and +// for each of them, emit an index shuffle with the same computed lane ID. +// - Use the register component to select the right value from the shuffle +// results. +void GatherOpConversion::emitWarpLocalGather( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Layout dimension names. + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); + SmallVector allDims, otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); + if (dim != op.getAxis()) { + otherDims.push_back(allDims.back()); + } + } + + // Compute the src and idx layouts. + LinearLayout srcLayout = toLinearLayout(srcType); + LinearLayout idxLayout = toLinearLayout(idxType); + + // Let `ll_src` be the source layout and `ll_idx` be the index layout. + // Let `src_col` be a tuple of dimensions except the gather dimension, + // representing a specific column in the source tensor. Likewise for + // `idx_col`. Let `src_idx` be the index into gather dimension in the source + // tensor. + // + // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is + // the thread that contains the required element and `src_reg` is the register + // within that thread. + // + // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == + // ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // index tensor) the thread will need to read from the same column in the + // source tensor. + // + // Thus, we can obtain + // + // (src_lane, src_reg) = (ll_src^-1)( + // ll_idx(black, warp, lane, idx_reg)[otherDims], + // idxValues[idx_reg] + // )[{"lane", "register"}] + // + // And the mapping will be the correct for each thread. + // + // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for + // each `idx_reg` (the number of index shuffles is quadratic!) and + // `llvm.select` using `src_reg` to get the right one. `K` is the number of + // elements per column owned by a thread. + + // Invert the source layout. It doesn't matter whether it is fully invertible + // with respect to anything except the register input dimension, since we know + // those don't vary in ways that matter for codegen. + LinearLayout invSrcLayout = srcLayout.pseudoinvert(); + + // Sanity check: the warp must be invariant to the index because otherwise the + // gather would need to read across warps! + assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kWarp, kBlock}) && + "expected a warp-local gather"); + invSrcLayout = invSrcLayout.sublayout(allDims, {kRegister, kLane}); + + LinearLayout idxColLayout = + idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = targetInfo.getClusterCTAId(rewriter, loc); + + unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); + assert(srcRegsPerThread == srcValues.size()); + + // Given a index value, we need to know which sources register values it could + // index into. This is invariant to anything other than the register, which we + // checked already. Compute the full reverse map from + // + // idx_reg -> gather_column -> (src_reg0, src_reg1, ...) + // + LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister}); + // Remove zero bases in the gather dimension to make the function injective + // (for a given column) over the same codomain. + invertSrcRegMap = invertSrcRegMap.removeZeroBasesAlongDim(kGatherDim); + // We are left with only non-zero bases in the gather dimension, which means + // the number of registers per column is the size of the "gather dimension". + unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim); + // Get a map from idx_reg to the column it indexes into. + LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims); + // Now given `idx_reg`, we can compute the column it belongs to in both src + // and index tensors, then partially apply `invertSrcRegMap` with this to + // obtain a function that outputs the corresponding registers in the src + // tensor in the same column. + + // L(column, i) = L(column, 0) xor L(0, i) + LinearLayout invertSrcRegMapColPart = + invertSrcRegMap.sublayout(otherDims, {kRegister}); + LinearLayout invertSrcRegMapRest = + invertSrcRegMap.sublayout({kGatherDim}, {kRegister}); + + SmallVector results; + for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { + SmallVector> column = + applyLinearLayout(loc, rewriter, idxColLayout, + {{kRegister, b.i32_val(idxReg)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + assert(column.size() == otherDims.size()); + + // Combine the computed column with the data-dependent gather index. + column.insert(column.begin() + op.getAxis(), + {kGatherDim, convertIndexToI32(loc, idxVal, rewriter)}); + SmallVector> srcLaneAndReg = + applyLinearLayout(loc, rewriter, invSrcLayout, column); + + auto [srcRegName, srcReg] = srcLaneAndReg.front(); + auto [srcLaneName, srcLane] = srcLaneAndReg.back(); + assert(srcLaneName == kLane && srcRegName == kRegister); + + assert(!srcValues.empty() && "can't gather from an empty tensor"); + + // Figure out which src registers we need to index shuffle from. This is + // invariant to anything else. + SmallVector> normalizedColumn = + idxRegToCol.apply({{kRegister, idxReg}}); + int32_t srcBase = + invertSrcRegMapColPart.apply(normalizedColumn).front().second; + + Value result = b.undef(srcValues.front().getType()); + for (unsigned i = 0; i != numRegsPerColumn; ++i) { + int32_t rest = + invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; + int32_t srcRegIdx = srcBase ^ rest; + + Value value = + targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); + result = b.select(b.icmp_eq(b.i32_val(srcRegIdx), srcReg), value, result); + } + + results.push_back(result); + } + + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results, + rewriter, op.getType())); +} + +} // namespace + +void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.insert(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp new file mode 100644 index 0000000000..1abb9b7281 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp @@ -0,0 +1,105 @@ +#include "mlir/Analysis/Liveness.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUGLOBALSCRATCHALLOCATIONPASS +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +static int32_t roundUp(int32_t val, int32_t step) { + auto t = val + step - 1; + return t - (t % step); +} + +static void allocateGMem(Operation *parentOp, + llvm::SetVector &callStack) { + // Recursively visit any dependency functions + parentOp->walk([&](triton::CallOp call) { + auto callable = call.resolveCallable(); + if (!callable->hasAttr("ttg.global_scratch_memory_size")) { + auto inserted = callStack.insert(parentOp); + assert(inserted && "call cycle detected"); + allocateGMem(callable, callStack); + callStack.remove(parentOp); + } + }); + + MLIRContext *ctx = parentOp->getContext(); + OpBuilder builder(ctx); + int32_t offset = 0; + uint32_t largestAlignment = 1; + + // Dumb allocation that ignores liveness and makes no attempt to minimize + // padding + // TODO: Use a real algorithm + parentOp->walk([&](Operation *op) { + uint32_t nbytes = 0; + uint32_t align = 0; + if (auto alloc = dyn_cast(op)) { + if (alloc.getBackend() != "default") + return; + nbytes = alloc.getNbytes(); + align = alloc.getAlignment(); + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto nbytes_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_size"); + auto align_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(nbytes_attr); + assert(align_attr); + + nbytes = nbytes_attr.getValue().getZExtValue(); + align = align_attr.getValue().getZExtValue(); + } + if (nbytes > 0) { + offset = roundUp(offset, align); + op->setAttr("ttg.global_scratch_memory_offset", + builder.getI32IntegerAttr(offset)); + offset += nbytes; + largestAlignment = std::max(largestAlignment, align); + } + }); + int32_t totalMemorySize = roundUp(offset, largestAlignment); + parentOp->setAttr("ttg.global_scratch_memory_size", + builder.getI32IntegerAttr(totalMemorySize)); + parentOp->setAttr("ttg.global_scratch_memory_alignment", + builder.getI32IntegerAttr(largestAlignment)); +} + +namespace { +class TritonGPUGlobalScratchAllocationPass + : public mlir::triton::gpu::impl::TritonGPUGlobalScratchAllocationPassBase< + TritonGPUGlobalScratchAllocationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + + bool seenKernel = false; + + SetVector callStack; + mod->walk([&](triton::FuncOp func) { + allocateGMem(func, callStack); + + if (func.getVisibility() == SymbolTable::Visibility::Public) { + assert(!seenKernel); + seenKernel = true; + auto size = + func->getAttrOfType("ttg.global_scratch_memory_size"); + auto align = func->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(size); + assert(align); + mod->setAttr("ttg.global_scratch_memory_size", size); + mod->setAttr("ttg.global_scratch_memory_alignment", align); + } + }); + assert(seenKernel); + } +}; +} // namespace diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp new file mode 100644 index 0000000000..571308c860 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -0,0 +1,225 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Compute a histogram within a warp. This uses an algorithm by @apgoucher +// that does the following: +// Create a ballot for each bit of the bin index (there +// are only log2(num_bins) of these) and then apply bitwise operations to get +// the indicator functions for the bins owned by this particular thread, and +// only popcount those. +static SmallVector computeWarpLevelHistogram( + Location loc, RankedTensorType srcType, SmallVector &srcValues, + SmallVector &maskValues, int numBins, int numThreadPerWarp, + Value threadId, ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(numBins % numThreadPerWarp == 0 && + "numBins must be divisible by numThreadPerWarp"); + Value zero = b.i32_val(0); + int numBits = llvm::Log2_64(numBins); + int numBitsLaneId = llvm::Log2_64(numThreadPerWarp); + unsigned numElementsPerThreads = getTotalElemsPerThread(srcType); + // The histogram is distributed across threads, each thread owns `numBins / + // numThreadPerWarp` bins. + SmallVector warpLevelHistogram(numBins / numThreadPerWarp, zero); + for (int i = 0; i < numElementsPerThreads; ++i) { + Value value = srcValues[i]; + SmallVector ballotBits; + for (int j = 0; j < numBits; ++j) { + Value bitSet = b.and_(value, b.i32_val(1 << j)); + Value cmp = b.icmp_ne(bitSet, zero); + Value bit = + targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); + ballotBits.push_back(bit); + } + uint64_t fullMaskValue = + numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; + Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue); + Value mask = fullMask; + for (int i = 0; i < numBitsLaneId; i++) { + Value updateMask = + b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero), + b.int_val(numThreadPerWarp, 0), fullMask); + mask = b.and_( + mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + } + // save a ballot bit to capture the input mask + Value inputMaskBit = fullMask; + if (maskValues.size() > 0) { + inputMaskBit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), + maskValues[i]); + } + // mask out the values for which input mask is invalid + mask = b.and_(mask, inputMaskBit); + // at this point, 'mask' tells you which elements are in a bin owned by this + // thread. + for (int k = 0; k < warpLevelHistogram.size(); k++) { + Value binMask = mask; + for (int j = 0; j < numBits - numBitsLaneId; j++) { + Value updateMask = + b.int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = b.and_(binMask, b.xor_(ballotBits[j], updateMask)); + } + // at this point, 'bin_mask' tells you which elements are in the kth bin + // owned by this thread. + Value bitCount = LLVM::CtPopOp::create(rewriter, loc, + int_ty(numThreadPerWarp), binMask); + if (numThreadPerWarp > 32) + bitCount = b.trunc(i32_ty, bitCount); + warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount); + } + } + return warpLevelHistogram; +} + +static void atomicAdd(Value ptr, Value val, Location loc, + ConversionPatternRewriter &rewriter) { + LLVM::AtomicRMWOp::create(rewriter, loc, LLVM::AtomicBinOp::add, ptr, val, + LLVM::AtomicOrdering::monotonic); +} + +static SmallVector computeCrossWarpHistogram( + Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType, + Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, + int numBins, int numThreadPerWarp, const SmallVector &indices, + Value threadId, int numWarps) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector histogramValues; + Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1)); + // Initialize the shared memory with zeros. + int64_t numElementPerThread = + ceil(numBins, numThreadPerWarp * numWarps); + for (int i = 0; i < numElementPerThread; ++i) { + Value offset = + b.add(threadId, b.i32_val((i * numWarps * numThreadPerWarp))); + offset = b.urem(offset, b.i32_val(numBins)); + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + b.store(b.i32_val(0), sharedMemPtr); + } + b.barrier(triton::gpu::AddrSpace::Local); + Block *afterAtomics = nullptr; + // Apply atomic add to update the histogram in shared memory. + for (int i = 0; i < warpLevelHistogram.size(); ++i) { + Value warpLevelHistogramValue = warpLevelHistogram[i]; + Value offset = b.add(b.mul(laneId, b.i32_val(warpLevelHistogram.size())), + b.i32_val(i)); + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); + } + if (afterAtomics) { + LLVM::BrOp::create(rewriter, loc, afterAtomics); + rewriter.setInsertionPointToStart(afterAtomics); + } + b.barrier(triton::gpu::AddrSpace::Local); + // load the histogram to register with the right layout. + for (Value index : indices) { + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = b.load(i32_ty, sharedMemPtr); + histogramValues.push_back(val); + } + return histogramValues; +} + +namespace { +struct HistogramOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit HistogramOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + SmallVector srcValues = unpackLLElements(loc, input, rewriter); + + Value llMask = adaptor.getMask(); + SmallVector maskValues; + if (llMask) + maskValues = unpackLLElements(loc, llMask, rewriter); + + int numBins = op.getType().getDimSize(0); + auto mod = op->getParentOfType(); + int numThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + assert(numThreadsPerWarp == 32 || + numThreadsPerWarp == 64 && + "Only supports 32 or 64 threads per warp"); + int numWarps = triton::gpu::lookupNumWarps(op); + // Pad out the bins so that we have at least one bin per thread within a + // warp. + numBins = std::max(numBins, numThreadsPerWarp); + Value threadId = getThreadId(rewriter, loc); + auto srcType = op.getSrc().getType(); + // First compute a warp local histogram based on values owned by each warps. + SmallVector warpLevelHistogram = computeWarpLevelHistogram( + loc, srcType, srcValues, maskValues, numBins, numThreadsPerWarp, + threadId, rewriter, targetInfo); + + // Then use atomic to update the histogram in shared memory. + // TODO: we could skip this for cases with num_warps=1 as long as we can + // generate the right layout. Currently the warp level histogram generates + // data in the default blocked layout. + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto dstType = op.getType(); + Attribute dstEncoding = dstType.getEncoding(); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); + SmallVector innerDimIndices; + for (int i = 0; i < indices.size(); ++i) + innerDimIndices.push_back(indices[i][0]); + SmallVector histogramValue = computeCrossWarpHistogram( + loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins, + numThreadsPerWarp, innerDimIndices, threadId, numWarps); + + // Depending on the layout, some threads may have duplicate data. We can + // account for this by calculating a "replication factor" and dividing the + // results by it to avoid overcounting. + auto replicationFactor = numWarps * numThreadsPerWarp; + auto threadsPerWarp = getThreadsPerWarp(srcType); + auto warpsPerCTA = + getWarpsPerCTA(srcType.getEncoding(), srcType.getShape()); + replicationFactor /= std::accumulate( + threadsPerWarp.begin(), threadsPerWarp.end(), 1, std::multiplies<>()); + replicationFactor /= std::accumulate(warpsPerCTA.begin(), warpsPerCTA.end(), + 1, std::multiplies<>()); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (auto i = 0; i < histogramValue.size(); ++i) { + histogramValue[i] = + b.sdiv(histogramValue[i], b.i32_val(replicationFactor)); + } + + Value results = packLLElements(loc, typeConverter, histogramValue, rewriter, + op.getType()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateHistogramOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 0000000000..8060b44312 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,54 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +struct MakeRangeOpConversion + : public ConvertOpToLLVMPattern { + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = b.add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 0000000000..4c2b917e0b --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,463 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Helper for LocalGather/ScatterOpConversion. +// For gather: storeVals is empty, returns loaded values. +// For scatter: storeVals contains values to store, returns empty. +SmallVector lowerLocalScGt(Location loc, MLIRContext *ctx, + MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, + unsigned axis, ArrayRef storeVals, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + bool isScatter = !storeVals.empty(); + + // Get the shared memory layout (linear component for padded layouts) + auto sharedEnc = + cast(memDescTy.getEncoding()); + auto paddedEnc = dyn_cast(sharedEnc); + LinearLayout sharedLayout; + if (paddedEnc) { + sharedLayout = paddedEnc.getLinearComponent(); + } else { + sharedLayout = toLinearLayout(memDescTy); + } + LinearLayout invSharedLayout = sharedLayout.invert(); + + // Get layout dimension names for all dims + SmallVector allDims; + for (unsigned dim = 0, rank = memDescTy.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); + } + + auto kOffset = str_attr("offset"); + + // Get the subslice affine offset (non-zero for memdesc subslices) + Value affineOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + + SmallVector results; + if (!isScatter) { + results.resize(coords.size()); + } + + for (auto [i, idxVal] : llvm::enumerate(idxValues)) { + // Convert index to i32 if needed + Value idx = idxVal; + unsigned idxWidth = idx.getType().getIntOrFloatBitWidth(); + if (idxWidth > 32) { + idx = b.trunc(i32_ty, idx); + } else if (idxWidth < 32) { + idx = b.zext(i32_ty, idx); + } + + // Copy coordinates and replace the axis coordinate with the index value + SmallVector indices(coords[i]); + indices[axis] = idx; + + // Apply inverted shared layout to compute offset + SmallVector> inputs; + for (unsigned dim = 0; dim < indices.size(); ++dim) { + inputs.push_back({allDims[dim], indices[dim]}); + } + + auto outputs = applyLinearLayout(loc, rewriter, invSharedLayout, inputs); + + // Extract the offset value + Value offset = nullptr; + for (auto [name, value] : outputs) { + if (name == kOffset) { + offset = value; + break; + } + } + assert(offset && "expected offset output from inverted shared layout"); + + // For subslices, the physical offset is computed as: + // physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset) + // + // We use XOR for consistency with lowerLdSt. MemDescSubsliceOp::verify() + // enforces: + // 1. Subslice offsets must be multiples of the tile size + // 2. Subslice offsets must map to power-of-2 physical offsets + // + // These constraints ensure the bit ranges of L⁻¹(coords) and + // L⁻¹(subslice_offset) are disjoint, so XOR and addition are equivalent. + offset = b.xor_(offset, affineOffset); + + // Add padding offset for padded layouts (non-linear component) + Value ptr; + if (paddedEnc) { + // Convert offset to bytes for padding calculation + Value offsetBytes = b.mul(offset, b.i32_val(bitwidth / 8)); + auto shifts = getPaddedSharedShifts(paddedEnc, bitwidth, + /*offsetInBytes=*/true); + // GEP in bytes: base + offset*elemSize + padOffset + Value totalOffset = applyPadding(loc, rewriter, offsetBytes, shifts); + ptr = b.gep(smemObj.getBase().getType(), i8_ty, smemObj.getBase(), + totalOffset); + } else { + ptr = b.gep(smemObj.getBase().getType(), llvmElemTy, smemObj.getBase(), + offset); + } + + if (isScatter) { + b.store(storeVals[i], ptr); + } else { + results[i] = b.load(llvmElemTy, ptr); + } + } + + return results; +} + +LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal, + MemDescType memDescTy, SharedMemoryObject smemObj, + ArrayRef inVals, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto regTy = cast(regVal.getType()); + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kOffset = str_attr("offset"); + auto regLayout = toLinearLayout(regTy); + auto paddedEnc = + dyn_cast(memDescTy.getEncoding()); + LinearLayout cvt = LinearLayout::empty(); + if (paddedEnc) { + const auto &sharedLL = paddedEnc.getLinearComponent(); + cvt = regLayout.invertAndCompose(sharedLL); + } else { + auto sharedLayout = toLinearLayout(memDescTy); + cvt = regLayout.invertAndCompose(sharedLayout); + } + auto kBlock = str_attr("block"); + // NYI. We would need to emit a map.shared::cluster instruction. + if (!cvt.isTrivialOver({kBlock})) { + return failure(); + } + cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset}); + lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj, + rewriter, targetInfo); + + return success(); +} + +struct GlobalScratchAllocOpConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase *targetInfo; + + GlobalScratchAllocOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(&targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto opOffsetAttr = op->getAttrOfType( + "ttg.global_scratch_memory_offset"); + assert(opOffsetAttr); + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) { + return failure(); + } + Value ptr = LLVM::getGlobalScratchPtr(loc, rewriter, *targetInfo, funcOp, + b.i32_val(opOffset)); + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc()) + return failure(); + Location loc = op->getLoc(); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto memDescTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, memDescTy.getRank(), + loc, rewriter); + // If there is an initial tensor, store it into the shared memory. + if (op.getSrc()) { + auto *ctx = op.getContext(); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj, + inVals, typeConverter, rewriter, + targetInfo))) { + return failure(); + } + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::LocalDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto memDescVal = op.getSrc(); + auto regVal = op.getResult(); + auto memDescTy = cast(memDescVal.getType()); + auto regTy = cast(regVal.getType()); + auto typeConverter = getTypeConverter(); + + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + + auto sharedEnc = + cast(memDescTy.getEncoding()); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kOffset = str_attr("offset"); + auto regLayout = toLinearLayout(regTy); + auto paddedEnc = dyn_cast(sharedEnc); + LinearLayout cvt = LinearLayout::empty(); + if (paddedEnc) { + const auto &sharedLL = paddedEnc.getLinearComponent(); + cvt = regLayout.invertAndCompose(sharedLL); + } else { + auto sharedLayout = toLinearLayout(memDescTy); + cvt = regLayout.invertAndCompose(sharedLayout); + } + auto kBlock = str_attr("block"); + // NYI. We would need to emit a map.shared::cluster instruction. + if (!cvt.isTrivialOver({kBlock})) { + return failure(); + } + cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset}); + + auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy, + smemObj, rewriter, targetInfo, op); + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + Value regVal = op.getSrc(); + Value memDescVal = op.getDst(); + auto typeConverter = getTypeConverter(); + auto memDescTy = cast(memDescVal.getType()); + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(), + llvmElemTy, rewriter); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals, + typeConverter, rewriter, targetInfo))) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +class BarrierOpConversion + : public ConvertOpToLLVMPattern { +public: + BarrierOpConversion(const LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + using OpAdaptor = typename triton::gpu::BarrierOp::Adaptor; + + LogicalResult + matchAndRewrite(triton::gpu::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { +public: + LocalGatherOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(LocalGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto memDescTy = cast(op.getSrc().getType()); + auto regTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + SmallVector> dstIndices = + emitIndices(loc, rewriter, targetInfo, regTy.getEncoding(), regTy, + /*withCTAOffset=*/true); + + auto results = lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, + idxValues, dstIndices, op.getAxis(), + /*storeVals=*/{}, rewriter); + + Value result = packLLElements(loc, typeConverter, results, rewriter, regTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalScatterOpConversion + : public ConvertOpToLLVMPattern { +public: + LocalScatterOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(LocalScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto memDescTy = cast(op.getDst().getType()); + auto valuesTy = cast(op.getValues().getType()); + auto typeConverter = getTypeConverter(); + + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(), + llvmElemTy, rewriter); + + SmallVector values = + unpackLLElements(loc, adaptor.getValues(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + SmallVector> srcIndices = + emitIndices(loc, rewriter, targetInfo, valuesTy.getEncoding(), valuesTy, + /*withCTAOffset=*/true); + + lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, idxValues, + srcIndices, op.getAxis(), values, rewriter); + + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} + +void mlir::triton::populateBarrierOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 0000000000..e17b0e3ad3 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,243 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + std::array pid; + auto module = op->getParentOfType(); + for (auto axis : {ProgramIDDim::X, ProgramIDDim::Y, ProgramIDDim::Z}) + pid[(int)axis] = targetInfo.programId(rewriter, loc, module, axis); + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" + << op.getPrefix(); + llPrintf(formatStr, {pid[0], pid[1], pid[2]}, {}, rewriter); + rewriter.eraseOp(op); + return success(); + } + + assert(op.getNumOperands() == op.getIsSigned().size()); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + bool isSigned = op.getIsSigned()[i] > 0; + // Elements of the tensor that are resident in this GPU thread. + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + dyn_cast(op.getOperand(i).getType())) { + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(op.getPrefix(), /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, op.getHex(), rewriter, isSigned); + } + } + rewriter.eraseOp(op); + return success(); + } + + void printTensor(StringRef prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, bool hex, + ConversionPatternRewriter &rewriter, bool isSigned) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + int formatStrByteCount = 0; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*hex=*/false, + /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")" << prefixStr; + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + + os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + auto isSignedOperands = + llvm::SmallVector(printfOperands.size(), isSigned); + if (i == 0) { + formatStrValue = llPrintf(formatStr, printfOperands, isSignedOperands, + rewriter, &formatStrByteCount); + } else { + targetInfo.printf(rewriter, formatStrValue, formatStrByteCount, + printfOperands, isSignedOperands); + } + } + } + + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt, + bool isSigned = false) const { + Type type = value.getType(); + // If the `value` is a pointer, just return %p. + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + (isSigned ? "lli" : "llu"); + else + return prefix + (isSigned ? "i" : "u"); + } + assert(false && "not supported type"); + return ""; + } + + // Returns a Value for the format string, which you can reuse. Writes the byte + // count for the string to |formatStrByteCount| if not null. + Value llPrintf(StringRef msg, ValueRange args, ArrayRef isSigned, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, + isSigned); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 0000000000..bb741dbfa7 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,391 @@ +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DistributedEncodingTrait; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getThreadOrder; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +struct ReduceOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + ReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(helper.isReduceWithinCTA() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchRepShape(); + + SmallVector smemBases = + getSmemBases(op, product(smemShape), rewriter, targetInfo); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + void accumulate(Location loc, ConversionPatternRewriter &rewriter, + Region &combineOp, SmallVector &acc, ValueRange cur, + Value pred = {}) const { + auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); + if (acc.size() < results.size()) { + acc.resize(results.size()); + } + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + b.barrier(triton::gpu::AddrSpace::Local); + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumulate over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + auto *combineOp = &op.getCombineOp(); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]); + if (isFirst) + indices[key] = srcIndices[i]; + } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, unsigned interleave, + Value pred = {}) const { + auto success = targetInfo.warpReduce(rewriter, loc, acc, op, + numLaneToReduce, interleave); + if (success) + return; + + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); + } + accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); + } + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcLayout = + mlir::cast(helper.getSrcLayout()); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + unsigned axis = op.getAxis(); + auto smemShape = helper.getScratchRepShape(); + + // Lezcano: We should move all the shared memory logic to use LLs natively + auto srcShape = helper.getSrcShape(); + auto kLane = rewriter.getStringAttr("lane"); + auto [multiDimLaneId, isRepresentativeLane] = + delinearize(rewriter, loc, srcLayout, srcShape, kLane, laneId); + auto kWarp = rewriter.getStringAttr("warp"); + auto [multiDimWarpId, isRepresentativeWarp] = + delinearize(rewriter, loc, srcLayout, srcShape, kWarp, warpId); + + Value laneIdAxis = multiDimLaneId[axis]; + Value laneZero = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value write = + b.and_(b.and_(isRepresentativeLane, isRepresentativeWarp), laneZero); + + Value warpIdAxis = multiDimWarpId[axis]; + + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value writePtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + targetInfo.storeShared(rewriter, loc, writePtr, acc[i], write); + } + } + } + + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto smemShape = helper.getScratchRepShape(); + unsigned elems = product(smemShape); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto mod = op->getParentOfType(); + int numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numWarps = triton::gpu::lookupNumWarps(op); + int numThreads = numLanes * numWarps; + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = b.i32_val(numLanes); + Value laneId = b.urem(threadId, warpSize); + Value zero = b.i32_val(0); + + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = b.icmp_slt(threadId, b.i32_val(elems)); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value readPtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, + threadIsNeeded); + } + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */, + threadIsNeeded); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + writePtrs[i] = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + } + + Value laneIdModSizeInterWarps = b.urem(laneId, b.i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + b.icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = b.and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } + + if (round != elemsPerThread - 1) { + readOffset = b.add(readOffset, b.i32_val(numThreads)); + } + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); + auto resultShape = resultTy.getShape(); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + op.getAxis(), b.i32_val(0)); + for (size_t resultIdx = 0, resultDim = resultShape.size(); + resultIdx < resultDim; ++resultIdx) { + auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; + if (resultShape[resultIdx] > smemShape[smemIdx]) { + // When srcShape smaller than src sizePerThread, only srcShape + // elements is accumulated in smem. Modulo smemShape effectively + // replicates srcShape elements to src sizePerThread. + readIdx[smemIdx] = + b.urem(readIdx[smemIdx], b.i32_val(smemShape[smemIdx])); + } + } + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, smemOrder); + Value readPtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + resultVals[j] = b.load(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = b.load(elemTy, smemBases[i]); + } + } + rewriter.replaceOp(op, results); + } +}; +} // namespace + +void mlir::triton::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h new file mode 100644 index 0000000000..b132461761 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -0,0 +1,163 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H + +// TODO: refactor so that it doesn't fail if Allocation.h +// is included after utility.h (due to conflict in `store` macro +// and +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" + +// +#include "mlir/IR/TypeUtilities.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include +#include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { +class ReduceOp; +class ScanOp; + +inline SmallVector +inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock, + Block *insertionBlock, Block::iterator insertionPoint, + ValueRange combineArgs) { + auto returnOp = combineBlock.getTerminator(); + rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint, + combineArgs); + + auto results = SmallVector(returnOp->getOperands()); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +inline SmallVector applyCombineOp(Location loc, + ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur, Value pred = {}) { + // Allows for passing an uninitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + + // Create a new copy of the combine block, and try to speculatively inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + + rewriter.cloneRegionBefore(combineOp, parent, + std::next(currentBlock->getIterator())); + Block &newCombine = *currentBlock->getNextNode(); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + auto isRegionSpeculatable = + std::all_of(newCombine.begin(), newCombine.end(), + [](auto &op) { return isSpeculatable(&op); }); + + if (!pred || isRegionSpeculatable) { + // Fast path, region has no side effects so we can unconditionally execute + return inlineCombineBlock(rewriter, newCombine, currentBlock, + rewriter.getInsertionPoint(), combineArgs); + } + + // Slow case, create an if to only execute region when pred is true + // #currentBlock + // if (pred) { + // #newCombine + // results = combineOp(cur, acc) + // yield results + // } else { + // yield undef + // } + // #thenBlock + Block *thenBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + auto returnOp = newCombine.getTerminator(); + auto results = SmallVector(returnOp->getOperands()); + + rewriter.setInsertionPointToEnd(currentBlock); + SmallVector thenBlockArgs; + thenBlockArgs.reserve(results.size()); + for (auto result : results) { + auto ty = result.getType(); + auto undef = LLVM::UndefOp::create(rewriter, loc, ty); + thenBlockArgs.push_back(undef); + thenBlock->addArgument(ty, loc); + } + LLVM::CondBrOp::create(rewriter, loc, pred, &newCombine, combineArgs, + thenBlock, thenBlockArgs); + + // Split a block after the call. + rewriter.setInsertionPointToEnd(&newCombine); + rewriter.replaceOpWithNewOp(returnOp, results, thenBlock); + rewriter.setInsertionPointToStart(thenBlock); + return SmallVector(thenBlock->getArguments()); +} + +} // namespace mlir::triton + +template +class ConvertTritonGPUReduceScanToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + // Make sure the class is only instantiated with Reduce and Scan + static_assert(std::is_same_v || + std::is_same_v); + + using ConvertOpToLLVMPattern::getTypeConverter; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(SourceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + SmallVector getSmemBases(SourceOp op, unsigned elems, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + auto basePtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + indexToBase[indices[0]] = basePtr; + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + indexToBase[indices[i]] = + b.gep(basePtr.getType(), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], b.i32_val(elems)); + } + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemBases[i] = indexToBase[i]; + } + return smemBases; + } +}; + +#endif diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 0000000000..13b4f018f7 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,37 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId( + rewriter, op->getLoc(), op->getParentOfType(), op.getAxis()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp new file mode 100644 index 0000000000..8cfb811b86 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -0,0 +1,585 @@ +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::toLinearEncoding; + +// apply combine region to acc and cur and accumulate it into acc +static SmallVector accumulate(ScanLoweringHelper &helper, + ConversionPatternRewriter &rewriter, + ValueRange acc, ValueRange cur, + Value pred = {}) { + auto loc = helper.getLoc(); + auto &combineOp = helper.getCombineOp(); + return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); +} + +// Scan a contiguous elements within a thread and update `srcValues` in place. +static void +scanThreadContiguousElements(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper) { + // Depending on layout contiguous elements along axis dim may not be + // contiguous in srcValues. Keep track of what elements belong to the same + // chunk of contiguous elements. + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned numChunks = srcValues.size() / scanElementsPerThreads; + unsigned stride = helper.getAxisElementStride(); + SmallVector> accs(numChunks); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + // Change this into emitOffsetForLayout? + unsigned accIndex = (srcIndex % stride) + + ((srcIndex / stride) / scanElementsPerThreads) * stride; + + accs[accIndex] = + accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]); + srcValues[srcIndex] = accs[accIndex]; + } +} + +// Apply a scan across threads of the warp for the last element of each +// contiguous group of elements. +static void warpScan(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value laneIdAxis) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Reduce within warps. + SmallVector acc = srcValues[srcIndex]; + for (unsigned i = 1; i <= scanDim / 2; i <<= 1) { + SmallVector shfl(acc.size()); + for (unsigned j = 0; j < acc.size(); ++j) { + shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); + } + Value mask = b.icmp_sge(laneIdAxis, b.i32_val(i)); + SmallVector tempAcc = + accumulate(helper, rewriter, shfl, acc, mask); + for (unsigned j = 0; j < acc.size(); ++j) { + acc[j] = b.select(mask, tempAcc[j], acc[j]); + } + } + srcValues[srcIndex] = std::move(acc); + } +} + +// For each set of contiguous elements within a thread we store the partial +// reduction into shared memory. Each parallel scan and each warp will store its +// own partial reductions. The shared memory is organized as follow: +// ----------------------------------------------------------------- +// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +static void storeWarpAccumulator(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId, SmallVector smemBases, + SmallVector smemTypes, + Value parallelLaneId, Value isRepresentative, + const TargetInfoBase &targetInfo) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned chunkId = 0; + unsigned elementStride = helper.getAxisElementStride(); + + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + auto lastElement = srcValues[srcIndex]; + Value mask = b.icmp_eq(laneId, b.i32_val(scanDim - 1)); + mask = b.and_(mask, isRepresentative); + Value index = + b.add(parallelLaneId, b.mul(warpId, b.i32_val(numParallelLane))); + index = b.add(index, b.i32_val(chunkId * numParallelLane * axisNumWarps)); + for (unsigned i = 0; i < lastElement.size(); ++i) { + Value writePtr = + b.gep(smemBases[i].getType(), smemTypes[i], smemBases[i], index); + targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); + } + chunkId++; + } +} + +// Read the partial reductions from shared memory from each chunk of contiguous +// elements for each warp and parallel scan. Then combine the partial reduction +// with the right elements. Within a given contiguous element chunk we update +// all the elements by accumulating the value from the last element of the +// reduced value from the previous lane. +static void AddPartialReduce(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, + ArrayRef smemBases, + ArrayRef smemTypes, Value warpId, + Value laneIdAxis, Value parallelLaneId) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + Value maskNotFirstWarp = b.icmp_ne(warpId, b.i32_val(0)); + Value maskNotFirstLane = b.icmp_ne(laneIdAxis, b.i32_val(0)); + Value maskNotFirstThread = b.or_(maskNotFirstWarp, maskNotFirstLane); + struct Accumulator { + SmallVector acc; + SmallVector maskedAcc; + }; + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Accumulate the partial reduction from shared memory. Decide which + // accumulator to combine based on whether the elements belong to the same + // dimension along axis. + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Accumulator &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = + b.add(parallelLaneId, + b.i32_val(numParallelLane * (i + chunkId * axisNumWarps))); + SmallVector partialReduce(helper.getNumOperands()); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + auto elemTy = smemTypes[j]; + Value ptr = b.gep(smemBases[j].getType(), elemTy, smemBases[j], index); + partialReduce[j] = b.load(elemTy, ptr); + } + + if (accumulator.acc.size() == 0) { + accumulator.acc = partialReduce; + accumulator.maskedAcc = partialReduce; + continue; + } + Value mask = b.icmp_sge(warpId, b.i32_val(i + 1)); + accumulator.acc = + accumulate(helper, rewriter, accumulator.acc, partialReduce); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + accumulator.maskedAcc[j] = + b.select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); + } + } + + Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{}; + auto temp = accumulate(helper, rewriter, accumulator.maskedAcc, + srcValues[srcIndex], pred); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + auto val = srcValues[srcIndex]; + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + temp[i] = b.select(maskNotFirstWarp, temp[i], val[i]); + } + } + srcValues[srcIndex] = temp; + // Update the rest of the contiguous elements. + SmallVector lastElement(helper.getNumOperands()); + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); + lastElement[i] = + b.select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + laneValue[j] = b.select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); + } + } + srcValues[srcIndex - i * elementStride] = std::move(laneValue); + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + accumulator.maskedAcc = accumulator.acc; + chunkId++; + } +} + +static void AddPartialReduceOneWarp(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = b.icmp_eq(warpId, b.i32_val(0)); + Value maskFirstLane = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value maskFirstThread = b.and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector> accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + auto &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = + accumulate(helper, rewriter, accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + auto lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + lastElement[i] = targetInfo.shuffleUp( + rewriter, loc, srcValues[srcIndex][i], threadStride); + lastElement[i] = + b.select(maskFirstLane, accumulator[i], lastElement[i]); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator[i] = targetInfo.shuffleIdx( + rewriter, loc, srcValues[srcIndex][i], laneIdLast); + } + } else if (numScanBlocks > 1) { + accumulator = srcValues[srcIndex]; + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = accumulate(helper, rewriter, lastElement, laneValue); + if (axisBlockId == 0) { + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue[j] = b.select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], + laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = std::move(laneValue); + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + chunkId++; + } +} + +namespace { +struct ScanOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + using ConvertTritonGPUReduceScanToLLVMPattern< + triton::ScanOp>::ConvertTritonGPUReduceScanToLLVMPattern; + explicit ScanOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(emitFastScan(op, adaptor, rewriter, targetInfo))) + return success(); + return failure(); + } + +private: + const TargetInfoBase &targetInfo; + std::tuple, Value> + getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId) const; + std::tuple, Value> + getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value warpId) const; + std::tuple + getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; + LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const; +}; + +std::tuple, Value> +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + auto srcEncoding = helper.getEncoding(); + auto kWarp = rewriter.getStringAttr("lane"); + return delinearize(rewriter, loc, srcEncoding, helper.getShape(), kWarp, + laneId); +} + +std::tuple, Value> +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + auto srcEncoding = helper.getEncoding(); + auto kWarp = rewriter.getStringAttr("warp"); + return delinearize(rewriter, loc, srcEncoding, helper.getShape(), kWarp, + warpId); +} + +// Break up the threadId into lane and warp id along the scan dimension and +// compute a flat id for the parallel dimensions. +std::tuple +ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { + auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = srcEncoding.getThreadsPerWarp(); + auto warpsPerCTA = srcEncoding.getWarpsPerCTA(); + auto [multiDimLaneId, isRepresentativeLane] = + getMultiDimLaneId(rewriter, helper, laneId); + auto [multiDimWarpId, isRepresentativeWarp] = + getMultiDimWarpId(rewriter, helper, warpId); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + multiDimLaneId[axis] = b.i32_val(0); + threadsPerWarp[axis] = 1; + Value laneIdParallel = linearize(rewriter, loc, multiDimLaneId, + threadsPerWarp, helper.getOrder()); + multiDimWarpId[axis] = b.i32_val(0); + warpsPerCTA[axis] = 1; + Value warpIdParallel = + linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, helper.getOrder()); + Value flatIdParallel = b.add( + laneIdParallel, + b.mul(warpIdParallel, b.i32_val(helper.getNonAxisNumThreadsPerWarp()))); + auto isRepresentative = b.and_(isRepresentativeLane, isRepresentativeWarp); + return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel, + isRepresentative); +} + +SmallVector> +unpackInputs(Location loc, triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, unsigned nElems, + const ColumnAction &removeBroadcastRegs) { + auto operands = adaptor.getOperands(); + SmallVector> srcValues(nElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + if (!removeBroadcastRegs.isIdentity()) { + values = removeBroadcastRegs.apply(values); + } + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; +} + +// Flip the srcValues. Both reverses the chunks and reverses the lanes. +// Lane reversal is done with a butterfly shuffle flip (divide and flip). +SmallVector> +flipSrcValues(Location loc, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + SmallVector> srcValues, int iWarpSize) { + SmallVector> values(srcValues.size()); + for (int i = 0; i < srcValues.size(); ++i) { + int revIndex = srcValues.size() - i - 1; + for (unsigned j = 0; j < op.getNumOperands(); ++j) { + for (unsigned k = iWarpSize / 2; k >= 1; k = k / 2) { + srcValues[revIndex][j] = + targetInfo.shuffleXor(rewriter, loc, srcValues[revIndex][j], k); + } + values[i].push_back(srcValues[revIndex][j]); + } + } + return values; +} + +// Lowering using warp shuffle operations to do warp level scan. +LogicalResult +ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + ScanLoweringHelper helper(op); + auto origLayout = triton::gpu::toLinearLayout( + cast(op.getOperands()[0].getType())); + auto removeBroadcastRegs = actionRemoveBroadcastedRegs(origLayout); + auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (!helper.isSupported()) + return op.emitError("TODO: unsupported scan layout"); + + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = b.i32_val(iWarpSize); + Value warpId = b.udiv(threadId, warpSize); + Value laneId = b.urem(threadId, warpSize); + + auto [laneIdAxis, warpIdAxis, flatIdParallel, isRepresentative] = + getDelinearizedIds(rewriter, helper, laneId, warpId); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned nElems = + helper.getEncoding().getTotalElemsPerThread(helper.getShape()); + auto srcValues = + unpackInputs(loc, op, adaptor, rewriter, nElems, removeBroadcastRegs); + + // For the reverse option we apply flip(scan(flip()) in + // order to avoid having a separate code path in the reverse direction. + // We do this by 1) reversing chunks, 2) reversing lanes, 3) reversing + // warp ids and then undoing this below. + // (Note: Tried pretty hard to get shflDownSync to work but I ended up + // having to add a lot of the complex cross warp code (if rev switch + // first/last etc). Reverse first seems more maintainable.) + if (op.getReverse()) { + warpIdAxis = b.sub(b.i32_val(axisNumWarps - 1), warpIdAxis); + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + // Scan contiguous elements in a thread and update `srcValues`. + scanThreadContiguousElements(srcValues, rewriter, helper); + // Apply warp level scan to the last element of each chunk of contiguous + // elements. + warpScan(srcValues, rewriter, targetInfo, helper, laneIdAxis); + + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + auto elems = helper.getScratchSizeInElems(); + SmallVector smemBases = + getSmemBases(op, elems, rewriter, targetInfo); + SmallVector smemTypes(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemTypes[i] = getElementType(op, i); + } + + // Store the partial reducing for each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + smemBases, smemTypes, flatIdParallel, isRepresentative, + targetInfo); + b.barrier(triton::gpu::AddrSpace::Local); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, targetInfo, helper, smemBases, + smemTypes, warpIdAxis, laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = + std::get<0>(getMultiDimLaneId(rewriter, helper, laneId)); + multiDimLaneId[helper.getAxis()] = b.i32_val(scanDim - 1); + auto linearEncoding = helper.getEncoding(); + auto kLane = StringAttr::get(rewriter.getContext(), "lane"); + Value laneIdLast = + linearize(rewriter, loc, multiDimLaneId, linearEncoding, kLane); + AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis, + laneIdAxis, laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. + + auto transpose = [](const SmallVector> &v) { + assert(v.size() > 0 && v[0].size() > 0); + auto ret = SmallVector>(v[0].size(), + SmallVector(v.size())); + for (int i = 0; i < v.size(); ++i) { + for (int j = 0; j < v[0].size(); ++j) { + ret[j][i] = v[i][j]; + } + } + return ret; + }; + + SmallVector results(op.getNumOperands()); + if (op.getReverse()) { + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + auto valuesTransposed = transpose(srcValues); + if (!removeBroadcastRegs.isIdentity()) { + for (auto &values : valuesTransposed) { + values = broadcastAs(values, origLayout); + } + } + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto resultTy = dyn_cast(op.getResult()[i].getType()); + results[i] = packLLElements(loc, getTypeConverter(), valuesTransposed[i], + rewriter, resultTy); + } + rewriter.replaceOp(op, results); + return success(); +} +} // namespace + +void mlir::triton::populateScanOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp new file mode 100644 index 0000000000..f220ad3175 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,77 @@ +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis) + : TritonGPUToLLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), targetInfo, + analysis) {} + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, const LowerToLLVMOptions &options, + const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, options, analysis) { + addConversion([ctx](triton::PointerType type) -> std::optional { + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); + }); + addConversion([ctx](TensorDescType type) -> std::optional { + return LLVM::LLVMPointerType::get(ctx, 0); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type, targetInfo); + }); + addConversion([&](MemDescType type) -> std::optional { + return convertMemDescType(type, targetInfo); + }); + addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { + return convertAsyncTokenType(type); + }); + + convertFP8Type(); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type, const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + Type eltType = convertType(type.getElementType()); + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertMemDescType( + MemDescType type, const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + // base ptr + auto ptrType = LLVM::LLVMPointerType::get( + ctx, targetInfo.getAddressSpace(type.getMemorySpace())); + + if (isa( + type.getEncoding())) { + return ptrType; + } + + SmallVector types; + types.push_back(ptrType); + auto rank = type.getRank(); + // offsets + for (auto i = 0; i < rank; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertAsyncTokenType( + triton::gpu::AsyncTokenType type) { + return IntegerType::get(type.getContext(), 32); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/Utility.cpp new file mode 100644 index 0000000000..73d4cb25b9 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -0,0 +1,1732 @@ +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" + +#include + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_clz(unsigned x) { + unsigned long r; + _BitScanReverse(&r, x); + return static_cast(r ^ 31); +} + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +#endif + +namespace mlir { + +namespace triton::gpu { + +std::pair, SmallVector> +getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth) { + assert(bitwidth <= 128 && "bitwidth must be <= 128"); + assert(llvm::isPowerOf2_32(bitwidth) && "bitwidth must be a power of two"); + SmallVector src; + SmallVector dst; + + // ld.shared/st.shared + auto ldstshared = LocalMemOpTile{{}, {0, 1, 2}}; + src.push_back(ldstshared); + dst.push_back(ldstshared); + + if (targetInfo.supportLdMatrix() || targetInfo.supportStMatrix()) { + // ldmatrix/stmatrix + if (bitwidth <= 32) { + auto ldstmatrix = LocalMemOpTile{{0, 1}, {2, 3, 4}}; + if (targetInfo.supportStMatrix()) { + src.push_back(ldstmatrix); + } + if (targetInfo.supportLdMatrix()) { + dst.push_back(ldstmatrix); + } + } + // ldmatrix.trans/stmatrix.trans + if (bitwidth == 16) { + auto ldstmatrixtrans = LocalMemOpTile{{2, 3, 4}, {0, 1}}; + if (targetInfo.supportStMatrix()) { + src.push_back(ldstmatrixtrans); + } + if (targetInfo.supportLdMatrix()) { + dst.push_back(ldstmatrixtrans); + } + } + } + return {std::move(src), std::move(dst)}; +} + +Type getFunctionType(Type resultType, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); +} + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname /*= ""*/, + StringRef libpath /*= ""*/) { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + Operation *parent = op; + if (!isa(op)) + parent = op->getParentOfType(); + OpBuilder b(parent); + auto ret = LLVMFuncOp::create(b, op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr("libname", + StringAttr::get(op->getContext(), libname)); + ret.getOperation()->setAttr("libpath", + StringAttr::get(op->getContext(), libpath)); + return ret; +} + +Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) { + assert(A.getNumInDims() == 1); + assert(A.getNumOutDims() == 1); + auto flatten = [](const std::vector> &matrix) { + SmallVector ret; + for (const auto &row : matrix) { + ret.push_back(row[0]); + } + return ret; + }; + auto nCol = A.getTotalInDimSizeLog2(); + auto nRow = A.getTotalOutDimSizeLog2(); + SmallVector matrix = flatten(A.getBases().begin()->second); + assert(matrix.size() == nCol); + + // Row-wise popcount to detect rows that appear exactly once across columns. + uint32_t rowsUnique = 0; + { + SmallVector rowPopCnt(nRow, 0); + for (int c = 0; c < nCol; ++c) { + uint32_t colBits = matrix[c]; + for (int r = 0; r < nRow; ++r) { + if (colBits & (1u << r)) + ++rowPopCnt[r]; + } + } + for (int r = 0; r < nRow; ++r) { + if (rowPopCnt[r] == 1) + rowsUnique |= 1u << r; + } + } + + // We iterate the matrix following the diagonals and build + // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique, + // then XOR everything else. This tends to encourage mad.lo codegen. + auto getMaskAndAllRowsUnique = [&](int i) -> std::pair { + uint32_t mask = 0; + int row = i < 0 ? -i : 0; + int col = i < 0 ? 0 : i; + bool allRowsUnique = true; + while (row < nRow && col < nCol) { + uint32_t bitValue = (matrix[col] >> row) & 1u; + mask |= bitValue << col; + allRowsUnique &= ((rowsUnique >> row) & 1u) == 1u; + ++row; + ++col; + } + return {mask, allRowsUnique}; + }; + + uint32_t explicitCols = 0; + + { + SmallVector masks; + for (int i = -nRow + 1; i < nCol; i++) { + masks.push_back(std::get<0>(getMaskAndAllRowsUnique(i))); + } + bool reachedFixedPoint = false; + while (!reachedFixedPoint) { + reachedFixedPoint = true; + for (uint32_t m : masks) { + uint32_t c = m & ~explicitCols; + if (llvm::isPowerOf2_32(c)) { + // found a single-element diagonal + explicitCols |= c; + reachedFixedPoint = false; + } + } + } + } + + // handle any diagonals that have survived + SmallVector ors; + SmallVector xors; + for (int i = -nRow + 1; i < nCol; i++) { + auto [mask, allRowsUnique] = getMaskAndAllRowsUnique(i); + mask &= ~explicitCols; + if (mask == 0) + continue; + auto masked = b.and_(x, b.i32_val(mask)); + auto shifted = i >= 0 ? Value(b.lshr(masked, b.i32_val(i))) + : Value(b.shl(masked, b.i32_val(-i))); + if (allRowsUnique) { + ors.push_back(shifted); + } else { + xors.push_back(shifted); + } + } + + // handle any explicit columns: + Value zero = b.i32_val(0); + for (int i = 0; i < nCol; i++) { + if ((explicitCols >> i) & 1) { + Value bit = b.and_(x, b.i32_val(1 << i)); + Value bit_is_zero = b.icmp_eq(bit, zero); + int32_t basis = matrix[i]; + if (basis == 0) + continue; + auto select = b.select(bit_is_zero, zero, b.i32_val(basis)); + if ((rowsUnique & basis) == basis) { + ors.push_back(select); + } else { + xors.push_back(select); + } + } + } + + auto treeReduce = [&](SmallVector &terms, + std::function op) -> Value { + if (terms.empty()) + return b.i32_val(0); + while (terms.size() > 1) { + SmallVector next; + for (size_t i = 0; i + 1 < terms.size(); i += 2) + next.push_back(op(terms[i], terms[i + 1])); + if (terms.size() % 2 == 1) + next.push_back(terms.back()); + terms = std::move(next); + } + return terms[0]; + }; + + auto orPart = treeReduce( + ors, [&b](Value x, Value y) { return b.or_(x, y, /*disjoint=*/true); }); + auto xorPart = + treeReduce(xors, [&b](Value x, Value y) { return b.xor_(x, y); }); + return b.or_(orPart, xorPart, /*disjoint=*/true); +} + +} // namespace triton::gpu + +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(layout.getNumInDims() == indices.size()); + assert(llvm::equal(layout.getInDimNames(), llvm::make_first_range(indices))); + // Trivial layout + if (layout.getNumOutDims() == 0) { + return {}; + } + + // This function can emit a lot of MLIR code, which ultimately makes + // compilation slow. (We think this shouldn't be the case -- it's not *that* + // much code -- but we're not clear on how to fix the slowness, which happens + // in the bowels of MLIR.) + // + // As a result we go through some contortions to avoid emitting code where + // possible. + + // Manually constant-fold the layout where possible. + SmallVector> constantIns; + SmallVector> nonConstantIns; + for (auto [inDimName, idx] : indices) { + APInt constant; + if (matchPattern(idx, m_ConstantInt(&constant))) { + constantIns.push_back({inDimName, constant.getSExtValue()}); + } else { + constantIns.push_back({inDimName, 0}); + nonConstantIns.push_back({inDimName, idx}); + } + } + + // Compute constant part of the output and wrap it as values + Value zero = b.i32_val(0); + SmallVector> outIndices; + for (auto [outDimName, constant] : layout.apply(constantIns)) { + if (constant == 0) + outIndices.push_back({outDimName, zero}); + else + outIndices.push_back({outDimName, b.i32_val(constant)}); + } + + if (nonConstantIns.size() == 0) { + return outIndices; + } + + SmallVector inDimNames; + // Concatenate input + Value x = b.i32_val(0); + int shift = 0; + for (auto [inDimName, idx] : nonConstantIns) { + inDimNames.push_back(inDimName); + x = b.or_(x, b.shl(idx, b.i32_val(shift))); + shift += layout.getInDimSizeLog2(inDimName); + } + + for (auto &[outDimName, outIdx] : outIndices) { + // Apply flattened sublayout for this output + auto matrix = layout.sublayout(inDimNames, outDimName).flattenIns(); + auto out = triton::gpu::matrixVectorProd(b, matrix, x); + outIdx = b.xor_(outIdx, out); + } + + return outIndices; +} + +std::optional getWarpGroupStartWarpId(Block *block) { + using namespace triton::gpu; + + // Look for an enclosing `ttg.warp_specialize` op. + while (block && block->getParentOp() && + !isa(block->getParentOp())) + block = block->getParentOp()->getBlock(); + if (!block || !block->getParentOp()) + return {}; + + auto partitions = cast(block->getParentOp()); + unsigned idx = block->getParent()->getRegionNumber(); + WarpSpecializeOp ws = partitions.getParentOp(); + std::optional> startIds = ws.getWarpGroupStartIds(); + assert(startIds && "cannot get warp group ID before warp group allocation"); + int32_t warpStartId = (*startIds)[idx]; + return warpStartId; +} + +std::optional getWarpGroupStartThreadId(Block *block) { + using namespace triton::gpu; + + std::optional warpStartId = getWarpGroupStartWarpId(block); + if (!warpStartId) + return {}; + + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp( + block->getParentOp()->getParentOfType()); + return *warpStartId * threadsPerWarp; +} + +Value getThreadId(OpBuilder &rewriter, Location loc) { + Value tid = + ::mlir::gpu::ThreadIdOp::create(rewriter, loc, ::mlir::gpu::Dimension::x); + tid = arith::IndexCastOp::create(rewriter, loc, i32_ty, tid); + + Operation *lookupPt = &rewriter.getInsertionBlock()->front(); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + int numWarps = triton::gpu::lookupNumWarps(lookupPt); + int upperBound = numWarps * threadsPerWarp; + + TritonLLVMOpBuilder b(loc, rewriter); + + // If this is being created inside a warp specialize op, compute the relative + // thread ID within the warp group. + if (std::optional startId = + getWarpGroupStartThreadId(rewriter.getInsertionBlock())) { + tid = arith::SubIOp::create(rewriter, loc, tid, b.i32_val(*startId)); + } + + assert(llvm::isPowerOf2_32(upperBound)); + // help LLVM's known bits analysis: + tid = b.and_(tid, b.i32_val(upperBound - 1)); + + return tid; +} + +std::pair getLaneAndWarpId(OpBuilder &rewriter, Location loc) { + TritonLLVMOpBuilder b(loc, rewriter); + Value tid = getThreadId(rewriter, loc); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + Value warpSizeVal = b.i32_val(threadsPerWarp); + + // If there is only one warp, the warp ID is always 0. + Operation *lookupPt = &rewriter.getInsertionBlock()->front(); + Value laneId; + Value warpId; + if (triton::gpu::lookupNumWarps(lookupPt) == 1) { + laneId = tid; + warpId = b.i32_val(0); + } else { + laneId = b.urem(tid, warpSizeVal); + warpId = mlir::triton::gpu::WarpIdOp::create(rewriter, loc, + /*omitUniformHint=*/true); + } + + return {laneId, warpId}; +} + +Value getLaneId(OpBuilder &rewriter, Location loc) { + return getLaneAndWarpId(rewriter, loc).first; +} + +// Helper function: applies linear layout vectorized over register indices +SmallVector>> +applyLinearLayoutVec(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices, + ArrayRef registers) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + StringAttr kRegister = str_attr("register"); + + // Precompute the base (with register = 0) + SmallVector> indicesWithZeroReg; + for (const auto &[attr, val] : indices) { + if (attr == kRegister) + indicesWithZeroReg.emplace_back(attr, b.i32_val(0)); + else + indicesWithZeroReg.emplace_back(attr, val); + } + + auto baseIndices = + applyLinearLayout(loc, rewriter, layout, indicesWithZeroReg); + + SmallVector>> ret; + + // Iterate over registers, applying XOR trick + for (auto reg : registers) { + SmallVector> constRegIndices; + for (const auto &[attr, val] : indices) { + constRegIndices.emplace_back(attr, attr == kRegister ? reg : 0); + } + auto regIndices = layout.apply(constRegIndices); + + SmallVector> combinedIndices; + for (auto [base, regIdx] : llvm::zip(baseIndices, regIndices)) { + assert(base.first == regIdx.first); + Value combined = b.xor_(base.second, b.i32_val(regIdx.second)); + combinedIndices.emplace_back(base.first, combined); + } + + ret.push_back(combinedIndices); + } + + return ret; +} + +// Refactored emitIndices function using applyLinearLayoutVec +SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + + LinearLayout ll = triton::gpu::toLinearLayout(shape, layout); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); + + SmallVector> commonIndices = { + {kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}; + + // Vectorize over registers + SmallVector registerIndices; + for (unsigned reg = 0; reg < ll.getInDimSize(kRegister); ++reg) + registerIndices.push_back(reg); + + auto vecIndices = + applyLinearLayoutVec(loc, rewriter, ll, commonIndices, registerIndices); + + unsigned rank = shape.size(); + SmallVector> ret; + for (auto &indices : vecIndices) { + SmallVector vals; + assert(indices.size() == rank); + for (auto &idx : indices) + vals.push_back(idx.second); + ret.push_back(vals); + } + + return ret; +} + +SmallVector> +getPaddedSharedShifts(Attribute enc, unsigned bitwidth, bool offsetInBytes) { + auto padded = dyn_cast(enc); + if (!padded) + return {}; + + SmallVector> shifts; + assert(bitwidth >= 8 && (bitwidth % 8 == 0) && + "bitwidth must be a positive multiple of 8 for padding"); + uint64_t offScale = offsetInBytes ? (bitwidth / 8) : 1; + for (auto [interval, padding] : + llvm::zip_equal(padded.getIntervals(), padded.getPaddings())) { + uint64_t intervalScaled = static_cast(interval) * offScale; + uint64_t paddingScaled = static_cast(padding) * offScale; + unsigned i = llvm::Log2_64(intervalScaled); + unsigned p = llvm::Log2_64(paddingScaled); + assert(i < 32 && p < 32 && "shift amount must be < 32 for i32 offsets"); + shifts.push_back({i, p}); + } + return shifts; +} + +Value applyPadding(Location loc, RewriterBase &rewriter, Value baseOffset, + ArrayRef> shifts) { + if (shifts.empty()) + return baseOffset; + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value pad = b.i32_val(0); + for (auto [i, p] : shifts) + pad = b.add(pad, b.shl(b.lshr(baseOffset, b.i32_val(i)), b.i32_val(p))); + return b.add(baseOffset, pad); +} + +uint32_t applyPadding(uint32_t baseOffset, + ArrayRef> shifts) { + uint64_t pad = 0; + for (auto [i, p] : shifts) + pad += (static_cast(baseOffset) >> i) << p; + uint64_t out = baseOffset + pad; + assert(out <= std::numeric_limits::max() && + "padded offset must be within 32-bit range"); + return static_cast(out); +} + +SmallVector +lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt, + ArrayRef valsArray, // Input for store, output for load + Type llvmElemTy, Value smemBase, + ArrayRef> paddingShifts, + Value affineOffset, uint64_t maskSpanAffineOffset, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, + std::optional maybeMaxVecElems, Operation *localLoadOp) { + + bool isStore = !valsArray.empty(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto emitLdSt = [&](RewriterBase &rewriter, Location loc, + ArrayRef vals, Value shmemAddr, int idx, + VectorType vecTy) -> SmallVector { + auto length = vecTy.getNumElements(); + if (isStore) { + Value valsVec = + packLLVector(loc, ArrayRef(vals).slice(idx, length), rewriter); + targetInfo.storeDShared(rewriter, loc, shmemAddr, std::nullopt, valsVec, + /*pred=*/b.true_val()); + return {}; + } else { + assert(vals.empty()); + Value valsVec = + targetInfo.loadDShared(rewriter, loc, shmemAddr, std::nullopt, vecTy, + /*pred=*/b.true_val(), localLoadOp); + return unpackLLVector(loc, valsVec, rewriter); + } + }; + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase, + paddingShifts, affineOffset, maskSpanAffineOffset, laneId, + warpId, rewriter, targetInfo, maybeMaxVecElems, emitLdSt); +} + +SmallVector lowerLdSt( + Location loc, MLIRContext *ctx, LinearLayout cvt, + ArrayRef valsArray, // Input for store, output for load + Type llvmElemTy, Value smemBase, + ArrayRef> paddingShifts, Value affineOffset, + uint64_t maskSpanAffineOffset, Value laneId, Value warpId, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, + std::optional maybeMaxVecElems, + std::function(RewriterBase &, Location, ArrayRef, + Value, int, VectorType)> + lowerInst) { + auto vals = to_vector(valsArray); + bool isStore = !vals.empty(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto smemPtrTy = ptr_ty(ctx, 3); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kOffset = str_attr("offset"); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + + auto [elemsPerVec, permutation] = + largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems); + + cvt = permutation.apply(cvt); + if (isStore) { + vals = permutation.apply(vals); + } + + auto tile = LinearLayout::identity1D(elemsPerVec, kReg, kOffset); + auto quot = divideLeft(cvt, tile); + assert(quot.has_value() && "cvt must be divisible by tile"); + LinearLayout reps = zerosLike(tile) * *quot; + + LinearLayout addrLayout = + LinearLayout({{kLane, reps.getBases().lookup(kLane)}, + {kWarp, reps.getBases().lookup(kWarp)}}, + reps.getOutDims(), false); + auto [nAdditive, permStrides] = + actionAdditiveStrides(reps, addrLayout, maskSpanAffineOffset); + reps = permStrides.apply(reps); + if (isStore) { + vals = permStrides.apply(vals); + } + + // PTX expects the address increments to be done in bytes + // If we don't perform the computations in i8, the compiler would + // have to divide the computation by bitwdith / 8 and then lift this + // shl, which often it's not able to do. + auto i8Tile = + zerosLike(LinearLayout::identity1D(bitwidth / 8, kReg, kOffset)); + auto i8AddrLayout = i8Tile * addrLayout; + + auto regBaseI8 = + applyLinearLayout( + loc, rewriter, i8AddrLayout, + {{kReg, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}})[0] + .second; + + // It's fine that we don't compute the offset in bytes as affineOffset + // will be folded into a constant + auto affineOffsetI8 = b.mul(affineOffset, b.i32_val(bitwidth / 8)); + regBaseI8 = b.xor_(regBaseI8, affineOffsetI8); + SmallVector outVals; + auto vecTy = vec_ty(llvmElemTy, elemsPerVec); + for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) { + auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second; + auto regIdxI8 = regIdx * (bitwidth / 8); + Value offset = b.xor_(regBaseI8, b.i32_val(regIdxI8)); + offset = applyPadding(loc, rewriter, offset, paddingShifts); + for (int j = 0; j < nAdditive; j += elemsPerVec) { + // all these constants will go as immediate values to LDS/STS + auto regIdxAdd = + reps.apply({{kReg, j}, {kLane, 0}, {kWarp, 0}})[0].second; + auto regIdxAddI8 = regIdxAdd * (bitwidth / 8); + // `actionAdditiveStrides` forces `regIdxAddI8` and `offset` to be bitwise + // disjoint, so we can calculate their padding contributions separately. + regIdxAddI8 = applyPadding(regIdxAddI8, paddingShifts); + Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8)); + auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset, + LLVM::GEPNoWrapFlags::inbounds); + llvm::append_range(outVals, + lowerInst(rewriter, loc, vals, vecAddr, i + j, vecTy)); + } + } + + // Permute the values back if we are loading + if (!isStore) { + auto invPermStrides = permStrides.inverse(); + outVals = invPermStrides.apply(outVals); + auto invPerm = permutation.inverse(); + outVals = invPerm.apply(outVals); + } + return outVals; +} + +SmallVector +lowerLocalLdSt(Location loc, MLIRContext *ctx, + LinearLayout cvt, // Map from registers to offset + ArrayRef valsArray, // Input for store, empty for load + Type llvmElemTy, triton::gpu::MemDescType srcTy, + SharedMemoryObject smemObj, RewriterBase &rewriter, + const TargetInfoBase &targetInfo, Operation *localLoadOp) { + assert(cvt.getNumOutDims() == 1); + assert(*cvt.getOutDimNames().begin() == str_attr("offset")); + + auto isStore = !valsArray.empty(); + // Remove broadcasting in the registers + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtCvt = removeBroadcastSrc.apply(cvt); + auto inVals = to_vector(valsArray); + if (isStore) { + inVals = removeBroadcastSrc.apply(inVals); + } + auto outVals = lowerLocalLdSt(loc, ctx, prmtCvt, inVals, llvmElemTy, srcTy, + smemObj, rewriter, targetInfo, localLoadOp); + if (!isStore) { + outVals = broadcastAs(outVals, cvt); + } + return outVals; + } + auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy); + auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy); + + std::optional maybeMaxVecElems; + SmallVector> paddingShifts; + if (auto paddedEnc = dyn_cast( + srcTy.getEncoding())) { + maybeMaxVecElems = paddedEnc.getMinInterval(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + paddingShifts = getPaddedSharedShifts(paddedEnc, bitwidth, + /*offsetInBytes=*/true); + } + + return lowerLdStShared(loc, ctx, cvt, valsArray, llvmElemTy, + smemObj.getBase(), paddingShifts, affineOffset, + maskSpanAffineOffset, rewriter, targetInfo, + maybeMaxVecElems, localLoadOp); +} + +SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter) { + assert(bool(llvmStruct) && "can not unpack null values"); + if (llvmStruct.getType().isIntOrIndexOrFloat() || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType())) + return {llvmStruct}; + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector results(types.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + results[i] = b.extract_val(type, llvmStruct, i); + } + return results; +} + +Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, + ValueRange resultVals, RewriterBase &rewriter, Type type) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (!structType) { + assert(resultVals.size() == 1); + return *resultVals.begin(); + } + + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + llvm::report_fatal_error( + "size mismatch when packing elements for LLVM struct"); + } + Value llvmStruct = LLVM::UndefOp::create(rewriter, loc, structType); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (auto [i, value] : llvm::enumerate(resultVals)) { + assert(value && "unexpected null value"); + if (value.getType() != elementTypes[i]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << value); + emitError(loc) << "invalid element type in packLLElements. Expected " + << elementTypes[i] << " but got " << value.getType(); + llvm::report_fatal_error( + "element type mismatch when packing elements for LLVM struct"); + } + llvmStruct = b.insert_val(structType, llvmStruct, value, i); + } + return llvmStruct; +} + +SmallVector unpackLLVector(Location loc, Value llvmVec, + RewriterBase &rewriter) { + assert(bool(llvmVec) && "cannot unpack null value"); + if (llvmVec.getType().isIntOrIndexOrFloat() || + isa(llvmVec.getType()) || + isa(llvmVec.getType())) + return {llvmVec}; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector results; + for (int i = 0; i < cast(llvmVec.getType()).getNumElements(); + i++) { + results.push_back(b.extract_element(llvmVec, b.i32_val(i))); + } + return results; +} + +Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter) { + assert(vals.size() > 0); + auto vecType = vec_ty(vals[0].getType(), vals.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value vec = b.undef(vecType); + for (int i = 0; i < vals.size(); i++) { + vec = b.insert_element(vec, vals[i], b.i32_val(i)); + } + return vec; +} + +std::optional matchAtomicOp(RMWOp atomicOp) { + switch (atomicOp) { + case RMWOp::AND: + return LLVM::AtomicBinOp::_and; + case RMWOp::OR: + return LLVM::AtomicBinOp::_or; + case RMWOp::XOR: + return LLVM::AtomicBinOp::_xor; + case RMWOp::ADD: + return LLVM::AtomicBinOp::add; + case RMWOp::FADD: + return LLVM::AtomicBinOp::fadd; + case RMWOp::MAX: + return LLVM::AtomicBinOp::max; + case RMWOp::MIN: + return LLVM::AtomicBinOp::min; + case RMWOp::UMAX: + return LLVM::AtomicBinOp::umax; + case RMWOp::UMIN: + return LLVM::AtomicBinOp::umin; + case RMWOp::XCHG: + return LLVM::AtomicBinOp::xchg; + default: + return {}; + } +} + +std::optional getMemoryOrdering(MemSemantic memOrdering) { + switch (memOrdering) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + return {}; + } +} + +llvm::MapVector getAllFreeVarMasks(MLIRContext *ctx) { + // Mask where all elements are redundant + auto kReg = str_attr("reg"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + int32_t fullMask = -1; + llvm::MapVector ret; + for (auto dimName : {kReg, kLane, kWarp, kBlock}) { + ret[dimName] = fullMask; + } + return ret; +} + +llvm::MapVector getFreeVariableMasks(Type type) { + auto ctx = type.getContext(); + auto tensorTy = dyn_cast(type); + if (!tensorTy) { + return getAllFreeVarMasks(ctx); + } + auto ll = triton::gpu::toLinearLayout(tensorTy); + return ll.getFreeVariableMasks(); +} + +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type) { + MLIRContext *ctx = layout.getContext(); + auto shape = type.getShape(); + unsigned rank = shape.size(); + + auto ll = triton::gpu::toLinearLayout(type); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + SmallVector> offsets; + for (int i = 0; i < ll.getInDimSize(str_attr("register")); i++) { + auto idxs = ll.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + offsets.push_back( + llvm::to_vector_of(llvm::make_second_range(idxs))); + } + return offsets; +} + +namespace LLVM { +using namespace mlir::triton; +using mlir::triton::gpu::getOrder; + +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v) { + auto i1ty = rewriter.getIntegerType(1); + return LLVM::ConstantOp::create(rewriter, loc, i1ty, + IntegerAttr::get(i1ty, v)); +} + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return LLVM::ConstantOp::create(rewriter, loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v) { + auto i64ty = rewriter.getIntegerType(64); + return LLVM::ConstantOp::create(rewriter, loc, i64ty, + IntegerAttr::get(i64ty, v)); +} + +Value createConstantF16(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f16Ty(rewriter.getContext()); + return LLVM::ConstantOp::create(rewriter, loc, type, + rewriter.getF16FloatAttr(v)); +} + +Value createConstantBF16(Location loc, OpBuilder &rewriter, float v) { + APFloat apf(v); + bool ignored; + apf.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &ignored); + auto type = type::bf16Ty(rewriter.getContext()); + auto attr = FloatAttr::get(type, apf); + return LLVM::ConstantOp::create(rewriter, loc, type, attr); +} + +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return LLVM::ConstantOp::create(rewriter, loc, type, + rewriter.getF32FloatAttr(v)); +} + +Value createConstantF64(Location loc, OpBuilder &rewriter, double v) { + auto type = type::f64Ty(rewriter.getContext()); + return LLVM::ConstantOp::create(rewriter, loc, type, + rewriter.getF64FloatAttr(v)); +} + +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type) { + if (!isa(type)) { + llvm::report_fatal_error("Creating NaN constant for non-float type!"); + } + return LLVM::ConstantOp::create( + rewriter, loc, type, + APFloat::getNaN(cast(type).getFloatSemantics())); +} + +// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return LLVM::ConstantOp::create(builder, loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return LLVM::ConstantOp::create(builder, loc, ty, + builder.getIntegerAttr(ty, value)); +} + +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args) { + auto op = LLVM::CallOp::create(builder, loc, funcOp, args); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args) { + auto op = LLVM::CallIntrinsicOp::create(builder, loc, types, args); + op.getProperties().setIntrin(builder.getStringAttr(intrinsic)); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType, + ArrayRef offsets) + : base(base), baseElemType(baseElemType), + offsets(offsets.begin(), offsets.end()) {} + +SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType, + int64_t rank, Location loc, + RewriterBase &rewriter) + : base(base), baseElemType(baseElemType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + offsets.append(rank, b.i32_val(0)); +} + +SmallVector SharedMemoryObject::getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(offsets.begin(), offsets.end()); + return elems; +} + +SmallVector SharedMemoryObject::getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); + return types; +} + +Value SharedMemoryObject::getBaseBeforeSlice(int dim, Location loc, + RewriterBase &rewriter) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value cSwizzleOffset = getCSwizzleOffset(dim); + Value offset = b.sub(b.i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return b.gep(type, baseElemType, base, offset); +} + +uint64_t +SharedMemoryObject::getMaskSpanOffsets(triton::gpu::MemDescType srcTy) { + auto ctx = srcTy.getContext(); + auto shape = srcTy.getShape(); + auto allocShape = srcTy.getAllocShape(); + assert(allocShape.size() >= shape.size()); + assert(allocShape.size() - shape.size() <= 1); + allocShape = allocShape.take_back(shape.size()); + + // Early exist when there is no subview + if (allocShape == shape) { + return 0; + } + if (auto paddedEncoding = dyn_cast( + srcTy.getEncoding())) { + // Mask is used in fusion of constant part of memory operation address as + // immediate operand. Padded layout has additional address computations + // between main offset computation and actual memory access, which breaks + // constand fusing. Full mask disables this optimization. + return ~uint64_t(0); + } + auto totalLl = triton::gpu::toLinearLayout(allocShape, srcTy.getEncoding()); + auto dimNames = standardOutDimNames(ctx, shape.size()); + // Remove the kBlock dimension + auto kOffset = StringAttr::get(ctx, "offset"); + totalLl = totalLl.sublayout({kOffset}, dimNames); + // Map from dimNames to offset + auto invLl = totalLl.invert(); + SmallVector> logicalOffsets; + for (auto dim : standardOutDimNames(srcTy.getContext(), shape.size())) { + logicalOffsets.push_back({dim, 0}); + } + + auto ret = 0; + for (auto [dim, shapes] : llvm::enumerate(llvm::zip(shape, allocShape))) { + auto [shape, allocShape] = shapes; + for (int j = llvm::Log2_32(shape); j < llvm::Log2_32(allocShape); ++j) { + logicalOffsets[dim].second = 1 << j; + ret |= invLl.apply(logicalOffsets)[0].second; + } + // Reset the offset for the next dimension + logicalOffsets[dim].second = 0; + } + return ret; +} + +Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter, + triton::gpu::MemDescType srcTy) const { + auto ctx = srcTy.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // If it did not have a memdesc_subslice we don't need to compute the offset + // as it is zero + if (!isAffineSharedMemoryAccess(srcTy)) { + return b.i32_val(0); + } + + LinearLayout ll; + // We return the offset without the padding. The padding will be added in the + // lowering + if (auto paddedSharedEncoding = + dyn_cast( + srcTy.getEncoding())) { + ll = paddedSharedEncoding.getLinearComponent(); + } else { + ll = triton::gpu::toLinearLayout(srcTy); + } + + auto dimNames = standardOutDimNames(ctx, offsets.size()); + SmallVector> logicalOffsets; + for (auto [dim, offset] : llvm::zip(dimNames, offsets)) { + logicalOffsets.push_back({dim, offset}); + } + + ll = ll.sublayout({str_attr("offset")}, dimNames); + auto offset = + applyLinearLayout(loc, rewriter, ll.invert(), logicalOffsets)[0].second; + return offset; +} + +Value SharedMemoryObject::getShmemAffineBase( + Location loc, RewriterBase &rewriter, + triton::gpu::MemDescType srcTy) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value offset = getShmemOffset(loc, rewriter, srcTy); + return b.gep(base.getType(), baseElemType, base, offset); +} + +Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + // pack into struct + Value llvmStruct = LLVM::UndefOp::create(rewriter, loc, structTy); + for (const auto &v : llvm::enumerate(elems)) { + assert(v.value() && "can not insert null values"); + llvmStruct = b.insert_val(structTy, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector elems(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + elems[i] = b.extract_val(type, llvmStruct, i); + } + return {/*base=*/elems[0], + /*baseElemType=*/elemTy, + /*offsets=*/{elems.begin() + 1, elems.end()}}; +} + +Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() + kSharedMemoryOffset); + } + + auto mod = funcOp->getParentOfType(); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); + assert(globalBase); + return LLVM::AddressOfOp::create(rewriter, funcOp.getLoc(), globalBase); +} + +Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, + const TargetInfoBase &targetInfo, + FunctionOpInterface funcOp, Value allocOffset = {}) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + // Base for this function + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() + + kGlobalScratchBufferOffset); + if (!allocOffset) { + return gmemBase; + } + + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return b.gep(ptrTy, i8_ty, gmemBase, allocOffset); + } + + // Base for entire kernel + auto gmemBase = + funcOp.getArgument(funcOp.getNumArguments() + kGlobalScratchBufferOffset); + + ModuleOp mod = funcOp.getOperation()->getParentOfType(); + auto allocSizeAttr = mod.getOperation()->getAttrOfType( + "ttg.global_scratch_memory_size"); + if (!allocSizeAttr) { + return gmemBase; + } + + Value gridIdx[3]; + Value gridDim[2]; + for (int k = 0; k < 3; ++k) { + gridIdx[k] = GetProgramIdOp::create(rewriter, loc, k); + } + for (int k = 0; k < 2; ++k) { + gridDim[k] = GetNumProgramsOp::create(rewriter, loc, k); + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value linearId = gridIdx[2]; + for (int k = 0; k < 2; ++k) { + linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k])); + } + auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + if (numCTAs > 1) { + linearId = b.mul(linearId, b.i32_val(numCTAs)); + linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc)); + } + + auto allocSize = allocSizeAttr.getValue().getZExtValue(); + + Value offset = b.mul(linearId, b.i32_val(allocSize)); + if (allocOffset) { + offset = b.add(offset, allocOffset); + } + + auto *ctx = rewriter.getContext(); + auto res = + b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset); + return res; +} + +Value getProfileScratchPtr(Location loc, RewriterBase &rewriter, + FunctionOpInterface funcOp) { + // See NOTE: [Additional Function Arguments] + // FIXME(Keren): This is broken when we have device functions, we + // need to implement proper calling convention + return funcOp.getArgument(funcOp.getNumArguments() + + kProfileScratchBufferOffset); +} + +Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), + target.getSharedAddressSpace()); + auto func = op->template getParentOfType(); + if (!func) + func = cast(op); + + assert(op->hasAttr("allocation.offset")); + size_t offset = cast(op->getAttr("allocation.offset")) + .getValue() + .getZExtValue(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value offVal = b.i32_val(offset); + Value base = + b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + return base; +} + +// Extract the bits of `a` that are set in `mask` +Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(a.getType() == i32_ty && "a must be i32"); + // Handle width = 32 to avoid doing 1 << 32 + if (mask == 0xFFFFFFFF) + return a; + + // Implements the blocked algorithm from + // https://forums.developer.nvidia.com/t/pdep-and-pext-functionality-for-cuda/270973 + uint32_t mskConst = mask; + uint32_t extcnt = 0; + Value result = b.i32_val(0); + while (mskConst) { + uint32_t oldmsk = mskConst; + uint32_t bitgrplsb = mskConst & (-mskConst); + mskConst &= bitgrplsb + mskConst; + uint32_t bitgrp = mskConst ^ oldmsk; + uint32_t lsbpos = 31 - __builtin_clz(bitgrplsb); + // like popcount for a number 0..01..1..0 but portable + uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos)); + uint32_t shift = lsbpos - extcnt; + extcnt += grplen; + result = + b.or_(result, b.lshr(b.and_(b.i32_val(bitgrp), a), b.i32_val(shift))); + } + return result; +} + +// Puts the bits of `a` that are set in `mask` into the bits of `result` +Value pdep_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(a.getType() == i32_ty && "a must be i32"); + + if (mask == 0) + return b.i32_val(0); + assert(mask < 64 && "mask must be less than 64"); + + // Blocked algorithm (same grouping trick as the pext example). + uint32_t mskConst = mask; + uint32_t depcnt = 0; // how many source bits from `a` we've consumed + Value result = b.i32_val(0); + + while (mskConst) { + uint32_t oldmsk = mskConst; + + // Isolate lsb set bit, then clear the lowest contiguous run of 1s. + uint32_t bitgrplsb = mskConst & (~mskConst + 1); // m & -m + mskConst &= (bitgrplsb + mskConst); + uint32_t bitgrp = mskConst ^ oldmsk; // the cleared run (contiguous 1s) + + // Group start position and length. + uint32_t lsbpos = __builtin_ctz(bitgrplsb); + uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos)); + + // Align the next grplen bits of `a` to the group's lsb, then mask to the + // group. + uint32_t shift = + lsbpos - depcnt; // non-negative invariant for this traversal order + depcnt += grplen; + + Value deposited = b.and_(b.shl(a, b.i32_val(shift)), b.i32_val(bitgrp)); + result = b.or_(result, deposited); + } + + return result; +} + +std::tuple, Value> +delinearize(RewriterBase &rewriter, Location loc, + triton::gpu::DistributedEncodingTrait layout, + ArrayRef shape, StringAttr dimName, Value linear) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ll = triton::gpu::toLinearLayout(shape, layout); + assert(ll.hasInDim(dimName)); + int32_t freeVarMask = ll.getFreeVariableMasks()[dimName]; + auto isRepresentative = b.true_val(); + if (freeVarMask != 0) { + isRepresentative = + b.icmp_eq(b.and_(b.i32_val(freeVarMask), linear), b.i32_val(0)); + // We remove the bits of linear that are set to one in freeVarMask + int32_t nonFreeVarMask = ~freeVarMask & (ll.getInDimSize(dimName) - 1); + linear = pext_i32(rewriter, loc, linear, nonFreeVarMask); + } + + auto linearLayout = triton::gpu::LinearEncodingAttr::get( + rewriter.getContext(), std::move(ll)); + auto orderDim = linearLayout.orderPerDim(dimName, linearLayout.getOrder()); + auto shapeDim = linearLayout.basesPerDim(dimName); + auto multiDim = delinearize(rewriter, loc, linear, shapeDim, orderDim); + + return std::make_tuple(std::move(multiDim), isRepresentative); +} + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + SmallVector reorderedMultiDim(rank); + if (auto constantOp = linear.getDefiningOp()) { + unsigned intVal = mlir::cast(constantOp.getValue()) + .getValue() + .getSExtValue(); + reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered); + } else { + reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + } + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + unsigned remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + unsigned dimSize = en.value(); + multiDim[en.index()] = b.i32_val(remained % dimSize); + remained = remained / dimSize; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + Value remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + Value dimSize = b.i32_val(en.value()); + multiDim[en.index()] = b.urem(remained, dimSize); + remained = b.udiv(remained, dimSize); + } + return multiDim; +} + +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order) { + auto rank = shape.size(); + assert(order.size() == rank); + SmallVector multiDim(rank); + for (auto dim : order) { + multiDim[dim] = linear % shape[dim]; + linear /= shape[dim]; + } + assert(linear == 0); + return multiDim; +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(rewriter, loc, applyPermutation(multiDim, order), + applyPermutation(shape, order)); +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto rank = multiDim.size(); + Value linear = b.i32_val(0); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.i32_val(dimShape); + linear = b.add(b.mul(linear, dimSize), dim); + } + } + return linear; +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + triton::gpu::LinearEncodingAttr encoding, StringAttr dimName) { + auto orderDim = encoding.orderPerDim(dimName, encoding.getOrder()); + auto shapeDim = encoding.basesPerDim(dimName); + auto linear = linearize(rewriter, loc, multiDim, shapeDim, orderDim); + auto ll = encoding.getLinearLayout(); + int32_t freeVarMask = ll.getFreeVariableMasks().lookup(dimName); + if (freeVarMask != 0) { + int32_t nonFreeVarMask = ~freeVarMask & (ll.getInDimSize(dimName) - 1); + linear = pdep_i32(rewriter, loc, linear, nonFreeVarMask); + } + return linear; +} + +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + size_t linear = 0; + for (unsigned dim : llvm::reverse(order)) + linear = linear * shape[dim] + multiDim[dim]; + return linear; +} + +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto ctx = moduleOp.getContext(); + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (key + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> contentStr(content); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + + LLVM::GlobalOp global; + { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = LLVM::GlobalOp::create(rewriter, UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, + LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(contentStr)); + } + + Value zero = b.i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value stringStart = + b.gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); + return stringStart; +} + +} // namespace LLVM + +Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value ret = b.i32_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = b.add(ret, b.mul(offset, stride)); + } + return ret; +} + +// Isolated a single warp specialize op from above. +static void +makeWarpGroupsIsolatedFromAbove(triton::gpu::WarpSpecializeOp wsOp) { + SetVector captures; + auto partOp = wsOp.getPartitionOp(); + getUsedValuesDefinedAbove(partOp.getPartitionRegions(), captures); + for (Value capture : captures) { + partOp->insertOperands(partOp.getNumOperands(), capture); + for (Region ®ion : partOp.getPartitionRegions()) { + BlockArgument arg = + region.addArgument(capture.getType(), capture.getLoc()); + replaceAllUsesInRegionWith(capture, arg, region); + } + } +} + +void makeAllWarpGroupsIsolatedFromAbove(Operation *op) { + op->walk([](triton::gpu::WarpSpecializeOp wsOp) { + makeWarpGroupsIsolatedFromAbove(wsOp); + }); +} + +// TODO: Is there a better way to do this? This needs to be fixed upstream. +void fixUpLoopAnnotation(ModuleOp mod) { + mod->walk([](Operation *op) { + if (isa(op)) { + if (op->hasAttr("llvm.loop_annotation")) { + auto loopMD = dyn_cast( + op->getAttr("llvm.loop_annotation")); + if (loopMD) { + if (auto brOp = dyn_cast(op)) { + brOp.setLoopAnnotationAttr(loopMD); + } else if (auto condBrOp = dyn_cast(op)) { + condBrOp.setLoopAnnotationAttr(loopMD); + } + } + } + } + }); +} + +SmallVector inlineRegionImpl(RewriterBase &rewriter, Region ®ion, + ArrayRef args, + mlir::TypeID terminatorTypeId, + Location loc) { + // Inline regions with multiple blocks + // + // Before After + // ┌─────────┐ + // │ op1 │ + // ┌──────────┐ │ cf.br │ + // │region[0] │ └────┬────┘ + // │cf.cond_br├─┐ ┌────▼─────┐ + // └────┬─────┘ │ │region[0] │ + // │ │ │cf.cond_br├─┐ + // ┌───────┐ ┌────▼────┐ │ └────┬─────┘ │ + // │ op1 │ IP │region[1]│ │ ┌────▼────┐ │ + // │ │◄─── │yield ...│ │ │region[1]│ │ + // │ op2 │ └─────────┘ │ ┌─┤cf.br │ │ + // └───────┘ │ │ └─────────┘ │ + // ┌─────────┐ │ │ ┌─────────┐ │ + // │region[2]│◄─┘ │ │region[2]│◄─┘ + // │yield │ │ │cf.br │ + // └─────────┘ │ └────┬────┘ + // │ ┌────▼────┐ + // └►│op2 │ + // └─────────┘ + auto *curBlock = rewriter.getInsertionBlock(); + auto opPosition = rewriter.getInsertionPoint(); + auto *remainingOpsBlock = rewriter.splitBlock(curBlock, opPosition); + + IRMapping regionMap; + Region &parent = *curBlock->getParent(); + rewriter.cloneRegionBefore(region, parent, parent.end(), regionMap); + rewriter.setInsertionPointToEnd(curBlock); + LLVM::BrOp::create(rewriter, loc, args, regionMap.lookup(®ion.front())); + + ValueRange terminatorOperands; + for (Block &origBlock : region) { + Block *newBlock = regionMap.lookup(&origBlock); + rewriter.moveBlockBefore(newBlock, remainingOpsBlock); + + auto terminator = newBlock->getTerminator(); + if (terminator->getRegisteredInfo()->getTypeID() == terminatorTypeId) { + terminatorOperands = terminator->getOperands(); + rewriter.setInsertionPointAfter(terminator); + rewriter.replaceOpWithNewOp(terminator, terminatorOperands, + remainingOpsBlock); + } + } + + rewriter.setInsertionPointToStart(remainingOpsBlock); + SmallVector vals; + for (auto resultTy : terminatorOperands.getType()) { + auto val = remainingOpsBlock->addArgument(resultTy, loc); + vals.push_back(val); + } + return vals; +} + +void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy, + ConversionPatternRewriter &rewriter, + SmallVector &resultVals, + Type valueElemTy, TritonLLVMOpBuilder &b, + Value threadPred, + const TargetInfoBase &targetInfo, + const LLVMTypeConverter *typeConverter) { + auto *ctx = rewriter.getContext(); + auto loc = op->getLoc(); + Type structTy = typeConverter->convertType(tensorTy); + if (!op->hasAttr("allocation.offset")) { + // No broadcasting, just pack the values into a struct + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + return; + } + + auto dstLayout = triton::gpu::toLinearLayout(tensorTy); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + dstLayout = dstLayout.sublayout({kReg, kLane, kWarp}, + llvm::to_vector(dstLayout.getOutDimNames())); + dstLayout = dstLayout.reshapeOuts( + {{str_attr("offset"), dstLayout.getTotalOutDimSize()}}); + auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + auto emitSt = [&](RewriterBase &rewriter, Location loc, ArrayRef vals, + Value shmemAddr, int idx, + VectorType vecTy) -> SmallVector { + auto length = vecTy.getNumElements(); + Value valsVec = + packLLVector(loc, ArrayRef(vals).slice(idx, length), rewriter); + targetInfo.storeDShared(rewriter, loc, shmemAddr, std::nullopt, valsVec, + threadPred); + return {}; + }; + + auto emitLd = [&](RewriterBase &rewriter, Location loc, ArrayRef vals, + Value shmemAddr, int idx, + VectorType vecTy) -> SmallVector { + Value loadedVec = targetInfo.loadDShared(rewriter, loc, shmemAddr, + std::nullopt, vecTy, b.true_val()); + return unpackLLVector(loc, loadedVec, rewriter); + }; + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase, + /*paddingShifts=*/{}, /*affineOffset=*/b.i32_val(0), + /*maskSpanAffineOffset=*/0, laneId, warpId, rewriter, targetInfo, + /*maybeMaxVecElems=*/{}, emitSt); + b.barrier(triton::gpu::AddrSpace::Local); + + resultVals = lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase, + /*paddingShifts=*/{}, /*affineOffset=*/b.i32_val(0), + /*maskSpanAffineOffset=*/0, laneId, warpId, rewriter, + targetInfo, /*maybeMaxVecElems=*/{}, emitLd); + + // Create the result struct and replace the operation + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); +} + +// Only retain those attributes that are not constructed by +// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument +// attributes. +void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + attr.getName() == triton::gpu::AttrNumWarpsName || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } +} + +triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + // Push back two new arguments that indicate the current pointer to shared + // memory and global scratch memory. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto sharedPtrTy = + LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace()); + auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1); + auto profilePtrTy = LLVM::LLVMPointerType::get(ctx, 1); + + // 1. Modify the function type to add the new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + bool isKernel = triton::isKernel(funcOp); + if (isKernel && targetInfo.isCuda()) { + for (auto i : llvm::seq(amendedInputTy.size())) { + if (isa(amendedInputTy[i])) { + funcOp.setArgAttr(i, "tt.nv_tma_desc", + mlir::IntegerAttr::get(i32_ty, 1)); + } + } + } + if (!isKernel) { + amendedInputTy.push_back(sharedPtrTy); + } + amendedInputTy.push_back(globalPtrTy); + amendedInputTy.push_back(profilePtrTy); + auto amendedFuncTy = + FunctionType::get(ctx, amendedInputTy, funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + if (auto argAttrs = funcOp.getAllArgAttrs()) { + llvm::SmallVector amendedArgAttrs(argAttrs.begin(), + argAttrs.end()); + while (amendedArgAttrs.size() < amendedInputTy.size()) { + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + } + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + } + + // 3. Add the new arguments to the region + auto amendedFuncOp = triton::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + if (!isKernel) { + region.addArgument(sharedPtrTy, loc); + } + region.addArgument(globalPtrTy, loc); + region.addArgument(profilePtrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; +} + +void handleArgPtrDatatype(triton::FuncOp funcOp, LLVM::LLVMFuncOp &llvmFuncOp) { + // The convertion from triton::PointerType to LLVM::LLVMPointerType losts + // the pointee datatype information. + // This function add back the pointee datatype information to arg attribute. + FunctionType fty = funcOp.getFunctionType(); + for (unsigned i = 0; i < fty.getNumInputs(); ++i) { + auto argType = fty.getInput(i); + if (auto argPtrType = dyn_cast(argType)) { + auto argDType = argPtrType.getPointeeType(); + llvmFuncOp.setArgAttr(i, "tt.pointee_type", + mlir::TypeAttr::get(argDType)); + } + } +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 0000000000..15e373d87b --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,645 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +namespace { + +Value bitOrPtrCast(Value val, Type type, TritonLLVMOpBuilder &b) { + if (isa(val.getType()) && + !isa(type)) { + return b.ptrtoint(type, val); + } else { + return b.bitcast(val, type); + } +} + +struct SplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a + // LLVM::StructType value. + // + // @elemType: the element type in operand. + // @resType: the return type of the Splat-like op. + // @constVal: a LLVM::ConstantOp or other scalar value. + static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto tensorTy = cast(resType); + // Check the converted type for the tensor as depending on the encoding the + // converter may pick different element types. + auto srcType = typeConverter->convertType(tensorTy); + if (auto structTy = dyn_cast(srcType)) + srcType = structTy.getBody()[0]; + // If the type sizes don't match we need to pack constants. + if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != + srcType.getIntOrFloatBitWidth()) { + unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth(); + unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); + assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0); + unsigned ratio = srcBitWidth / cstBitWidth; + Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); + VectorType vecType = VectorType::get(ratio, intTy); + Value intCst = bitOrPtrCast(constVal, intTy, b); + Value vec = b.undef(vecType); + for (unsigned i = 0; i < ratio; ++i) + vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i)); + constVal = vec; + } + Value llSrc = bitOrPtrCast(constVal, srcType, b); + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + llvm::SmallVector elems(elemsPerThread, llSrc); + return packLLElements(loc, typeConverter, elems, rewriter, resType); + } + LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto src = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + typeConverter, rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; + +struct UnsplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult matchAndRewrite(triton::UnsplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto scrVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + rewriter.replaceOp(op, scrVals[0]); + return success(); + } +}; + +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. +struct ArithConstantSplatOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return failure(); + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + auto loc = op->getLoc(); + LLVM::ConstantOp arithConstantOp; + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + Attribute val; + if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + // Lower FP8 constant to int8 constant since FP8 types are not supported on + // LLVM IR. + if (type::isFloat8(elemType)) + elemType = rewriter.getIntegerType(8); + auto constOp = LLVM::ConstantOp::create(rewriter, loc, elemType, val); + auto typeConverter = getTypeConverter(); + auto llStruct = SplatOpConversion::convertSplatLikeOp( + elemType, op.getType(), constOp, typeConverter, rewriter, loc); + rewriter.replaceOp(op, llStruct); + return success(); + } +}; + +// Convert arith::ConstantOp with an array DenseElementsAttr to a +// LLVM::StructType value. +struct ArithConstantArrayOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return failure(); + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + if (mlir::isa(value)) + return failure(); + auto tensorTy = cast(op.getType()); + auto loc = op->getLoc(); + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + SmallVector llVals; + for (auto v : values.getValues()) { + auto ll = LLVM::ConstantOp::create(rewriter, loc, elemType, v); + llVals.push_back(ll); + } + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + + if (elemsPerThread != llVals.size()) { + op->emitError( + "Right now we only support constant arrays with the same number of " + "elements as the number of threads per warp"); + return failure(); + } + auto llStruct = + packLLElements(loc, getTypeConverter(), llVals, rewriter, op.getType()); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; + +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + // Note: We must explicitly handle broadcasted registers. The LLVM lowering + // generally represents broadcasted register bits by *duplicating* elements + // in the LLVM struct. Many conversions operate on a "stripped" (no-bcast) + // view and then re-introduce broadcasting at the end (see + // ConvertLayoutOpConversion). + StringAttr kReg = StringAttr::get(rewriter.getContext(), "register"); + + // Unpack input values. + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + + // Strip broadcasted registers from inputs. + auto lhsTy = cast(op.getLhs().getType()); + auto rhsTy = cast(op.getRhs().getType()); + auto lhsLayout = toLinearLayout(lhsTy); + auto rhsLayout = toLinearLayout(rhsTy); + auto removeBroadcastLhs = actionRemoveBroadcastedRegs(lhsLayout); + auto removeBroadcastRhs = actionRemoveBroadcastedRegs(rhsLayout); + if (!removeBroadcastLhs.isIdentity()) + lhsVals = removeBroadcastLhs.apply(lhsVals); + if (!removeBroadcastRhs.isIdentity()) + rhsVals = removeBroadcastRhs.apply(rhsVals); + + // Compute the expected non-broadcast register count for the result. + auto dstLayout = toLinearLayout(resultTy); + auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout); + auto strippedDstLayout = removeBroadcastDst.apply(dstLayout); + + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + + if (retVals.size() != strippedDstLayout.getInDimSize(kReg)) { + return op->emitError() + << "tt.cat lowering expected " + << strippedDstLayout.getInDimSize(kReg) + << " (non-broadcast) register values for the result, but got " + << retVals.size() + << ". (hint: this usually means the operands/result encodings are " + "incompatible for the current CatOp lowering)"; + } + + // Re-introduce broadcasting if the destination expects it. + if (!removeBroadcastDst.isIdentity()) + retVals = broadcastAs(retVals, dstLayout); + + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct JoinOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename JoinOp::Adaptor; + explicit JoinOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The last dimension (the one we're joining) is also the most minor + // dimension. + // - The input and output encodings are the same, except the output has + // 2 elements per thread in the last dim. + // + // With these invariants, join is trivial: We can count how many contiguous + // registers belong to the same chunk then we merge the registers between + // two different chunks. + Location loc = op->getLoc(); + RankedTensorType dstTy = op.getType(); + auto ll = toLinearLayout(dstTy); + int splitDim = dstTy.getRank() - 1; + auto kReg = mlir::StringAttr::get(dstTy.getContext(), "register"); + const auto &bases = ll.getBases(); + const auto ®s = bases.find(kReg)->second; + int numContiguousValues = 1; + bool found = false; + for (const auto ® : regs) { + if (reg[splitDim] == 1) { + found = true; + break; + } + numContiguousValues *= 2; + } + assert(found && "Join dimension is not distributed along registers."); + SmallVector lhsVals = + unpackLLElements(loc, adaptor.getLhs(), rewriter); + SmallVector rhsVals = + unpackLLElements(loc, adaptor.getRhs(), rewriter); + assert(lhsVals.size() == rhsVals.size()); + SmallVector joinedVals; + joinedVals.resize(lhsVals.size() * 2); + for (int i = 0; i < lhsVals.size(); i += numContiguousValues) { + for (int j = 0; j < numContiguousValues; j++) { + joinedVals[2 * i + j] = lhsVals[i + j]; + joinedVals[2 * i + numContiguousValues + j] = rhsVals[i + j]; + } + } + auto typeConverter = getTypeConverter(); + Value ret = packLLElements(loc, typeConverter, joinedVals, rewriter, dstTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct SplitOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename SplitOp::Adaptor; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The layout distribute the last dimension along registers + // - The last dimension (the one we're splitting) has sizePerThread=2, + // threadPerWarp=1 and warpPerBlock=1. + // + // With these invariants, split is trivial: We can count how many contiguous + // registers belong to the same chunk then we separate the registers between + // two different chunks. + auto srcTy = cast(op.getSrc().getType()); + auto ll = toLinearLayout(srcTy); + int splitDim = srcTy.getRank() - 1; + auto kReg = mlir::StringAttr::get(srcTy.getContext(), "register"); + const auto &bases = ll.getBases(); + const auto ®s = bases.find(kReg)->second; + int numContiguousValues = 1; + bool found = false; + for (const auto ® : regs) { + if (reg[splitDim] == 1) { + found = true; + break; + } + numContiguousValues *= 2; + } + assert(found && "Split dimension is not distributed along registers."); + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + SmallVector srcVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(srcVals.size() % 2 == 0); + SmallVector outLhsVals; + SmallVector outRhsVals; + for (int i = 0; i < srcVals.size(); i += 2 * numContiguousValues) { + for (int j = 0; j < numContiguousValues; j++) { + outLhsVals.push_back(srcVals[i + j]); + outRhsVals.push_back(srcVals[i + numContiguousValues + j]); + } + } + auto resultTy = cast(op.getResult(0).getType()); + Value retLhs = + packLLElements(loc, typeConverter, outLhsVals, rewriter, resultTy); + Value retRhs = + packLLElements(loc, typeConverter, outRhsVals, rewriter, resultTy); + rewriter.replaceOp(op, {retLhs, retRhs}); + return success(); + } +}; +struct ReshapeOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ReshapeOp::Adaptor; + explicit ReshapeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) { + return emitOptionalError(loc, + "expensive view not supported on reshape op"); + } + auto resultTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto typeConverter = getTypeConverter(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion( + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.at(offset)); + } + Value ret = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct MemDescTransOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.getBase(), srcSmemObj.getBaseElemType(), + /*offsets=*/applyPermutation(srcSmemObj.getOffsets(), op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct MemDescReshapeOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(MemDescReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + SmallVector offsets = srcSmemObj.getOffsets(); + // FIXME: This should be done by composing a linear layout with its + // reshaped counterpart. + SmallVector srcShape; + for (int64_t d : op.getSrc().getType().getShape()) + srcShape.push_back(d); + SmallVector dstShape; + for (int64_t d : op.getType().getShape()) + dstShape.push_back(d); + Value linearOffset = LLVM::linearize(rewriter, loc, offsets, srcShape); + SmallVector delinearizedOffset = + LLVM::delinearize(rewriter, loc, linearOffset, dstShape); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.getBase(), srcSmemObj.getBaseElemType(), delinearizedOffset); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct TransOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By construction, TransOp::inferReturnTypes ensures that the src encoding + // is the same as the dst encoding so that this op is a no-op. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } +}; + +struct BroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct MemDescIndexOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescIndexOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto *ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + + // getAllocationShapePerCTA returns the correct number fp4 elements that we + // need to skip when we have fp4Padded=True. getShapePerCTA does not account + // for this + auto stride = product( + getAllocationShapePerCTA(dstTy.getEncoding(), dstTy.getShape())); + Value offset = b.mul(op.getIndex(), b.i32_val(stride)); + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto base = smemObj.getBase(); + auto elemPtrTy = base.getType(); + auto prevOffsets = smemObj.getOffsets(); + SmallVector offsetVals(prevOffsets.end() - dstTy.getRank(), + prevOffsets.end()); + + // Apply padding based on the amount we move the base ptr + if (auto padEnc = dyn_cast(dstTy.getEncoding())) { + auto bitwidth = dstTy.getElementTypeBitWidth(); + auto paddingShifts = getPaddedSharedShifts(padEnc, bitwidth, + /*offsetInBytes=*/false); + offset = applyPadding(loc, rewriter, offset, paddingShifts); + } + + // Advance the pointer and keep the opOffsets as the new shape + smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset), + llvmElemTy, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct MemDescSubsliceOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubsliceOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubsliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto *ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto destTy = op.getResult().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto layoutOrder = getOrder(srcTy); + auto enc = srcTy.getEncoding(); + + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto opOffsetVals = op.getOffsets(); + + auto base = smemObj.getBase(); + auto elemPtrTy = base.getType(); + // Accumulate the logical offsets + SmallVector offsetVals; + for (auto [oldOffVal, opOff] : + llvm::zip(smemObj.getOffsets(), opOffsetVals)) { + offsetVals.push_back(b.add(oldOffVal, b.i32_val(opOff))); + } + smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct MemDescReinterpretOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite(MemDescReinterpretOp op, OpAdaptor adaptor, + ConversionPatternRewriter &b) const override { + Location loc = op.getLoc(); + MemDescType srcTy = op.getSrc().getType(); + MemDescType dstTy = op.getType(); + Type srcElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + Type dstElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + + auto smemObj = + getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), srcElemTy, b); + Value newBase = smemObj.getShmemAffineBase(loc, b, srcTy); + SharedMemoryObject newObj(newBase, dstElemTy, dstTy.getRank(), loc, b); + b.replaceOp(op, getStructFromSharedMemoryObject(loc, newObj, b)); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add( + typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add( + typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.cpp new file mode 100644 index 0000000000..8c86c0ba8f --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.cpp @@ -0,0 +1,376 @@ +#include "triton/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OperationSupport.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +//===----------------------------------------------------------------------===// +// convertOpTypes +//===----------------------------------------------------------------------===// + +void mlir::triton::convertOpTypes(Operation *op, + const TypeConverter &typeConverter) { + ImplicitLocOpBuilder b(op->getLoc(), op); + // WarpSpecializePartitionsOp exists in a region that must only contain a + // single op. This also means that we know that its operands always dominate + // the enclosing WarpSpecializeOp, so we can insert the casts there instead. + if (isa(op)) + b.setInsertionPoint(op->getParentOp()); + SmallVector operands = llvm::to_vector(op->getOperands()); + for (Value &operand : operands) { + Type type = typeConverter.convertType(operand.getType()); + if (type != operand.getType()) { + operand = + UnrealizedConversionCastOp::create(b, type, operand).getResult(0); + } + } + op->setOperands(operands); + + for (Region ®ion : op->getRegions()) { + b.setInsertionPointToStart(®ion.front()); + for (BlockArgument arg : llvm::to_vector(region.getArguments())) { + Type type = typeConverter.convertType(arg.getType()); + BlockArgument newArg = region.addArgument(type, arg.getLoc()); + auto cast = UnrealizedConversionCastOp::create(b, arg.getType(), newArg); + arg.replaceAllUsesWith(cast.getResult(0)); + region.eraseArgument(0); + } + } + + SmallVector resultTypes; + (void)typeConverter.convertTypes(op->getResultTypes(), resultTypes); + if (TypeRange(resultTypes) == op->getResultTypes()) + return; + OperationState state(op->getLoc(), op->getName(), op->getOperands(), + resultTypes, op->getAttrs()); + for (Region ®ion : op->getRegions()) + state.addRegion()->takeBody(region); + b.setInsertionPoint(op); + Operation *newOp = b.create(state); + + SmallVector results; + for (auto [i, result, type] : + llvm::enumerate(newOp->getResults(), op->getResultTypes())) { + auto cast = UnrealizedConversionCastOp::create(b, type, result); + op->getResult(i).replaceAllUsesWith(cast.getResult(0)); + } + op->erase(); +} + +//===----------------------------------------------------------------------===// +// elideTrivialCaptures +//===----------------------------------------------------------------------===// + +static LogicalResult findTrivialSubcomputation(LLVM::LLVMFuncOp func, + Value capture, + SetVector &ops) { + SetVector worklist; + worklist.insert(capture); + for (unsigned i = 0; i != worklist.size(); ++i) { + Value capture = worklist[i]; + // Check for a kernel argument. + if (auto arg = dyn_cast(capture)) { + if (arg.getOwner() == &func.getBody().front()) + continue; + // Otherwise, this is some other block argument that cannot be elided. + return failure(); + } + + Operation *op = capture.getDefiningOp(); + // Check if the defining op can be rematerialized. At the LLVM level, + // checking for pure is probably a good enough heuristic. + if (isPure(op)) { + ops.insert(op); + worklist.insert(op->operand_begin(), op->operand_end()); + continue; + } + // The op cannot be rematerialized. + return failure(); + } + + // Cap the number of ops that can be rematerialized. + // FIXME: This is arbitrary. + return success(ops.size() <= 16); +} + +void mlir::triton::elideTrivialCaptures(LLVM::LLVMFuncOp func, + ArrayRef wsOps) { + // The goal is to completely eliminate captures by hoisting or rematerializing + // computations. We could minimize captures by rematerializing + // subcomputations, but that is much more complicated. Prefer rematerializing + // because that reduces liveranges. If subgraphs are duplicated more than + // once, we will rely on CSE to clean them up. + SetVector subgraph; + for (WarpSpecializeOp wsOp : wsOps) { + auto partOp = wsOp.getPartitionOp(); + llvm::BitVector toErase(partOp.getNumOperands()); + for (auto [i, capture] : llvm::enumerate(partOp.getExplicitCaptures())) { + subgraph.clear(); + if (failed(findTrivialSubcomputation(func, capture, subgraph))) + continue; + toErase.set(i); + subgraph = topologicalSort(subgraph); + + for (Region *region : wsOp.getPartitionRegions()) { + OpBuilder b(region); + IRMapping mapping; + for (Operation *op : subgraph) { + b.clone(*op, mapping); + } + Value remat = capture; + if (!subgraph.empty()) { + unsigned resultIdx = cast(capture).getResultNumber(); + remat = mapping.lookup(subgraph.back())->getResult(resultIdx); + } + region->getArgument(i).replaceAllUsesWith(remat); + } + } + + partOp->eraseOperands(toErase); + for (Region *region : wsOp.getPartitionRegions()) { + region->front().eraseArguments(toErase); + } + } +} + +/// Disable LICM (Loop Invariant Code Motion) for a loop. This prevents LLVM +/// from hoisting code out of the switch loop generated by the +/// `ttg.warp_specialize` lowering, which could result in long liveranges and +/// cause register spilling in partition regions. +static void disableLICM(LLVM::BrOp latchBr) { + Builder b(latchBr.getContext()); + MLIRContext *ctx = b.getContext(); + auto licmMD = LLVM::LoopLICMAttr::get(ctx, b.getBoolAttr(true), {}); + auto loopMD = + LLVM::LoopAnnotationAttr::get(b.getContext(), {}, {}, {}, {}, {}, licmMD, + {}, {}, {}, {}, {}, {}, {}, {}, {}); + latchBr.setLoopAnnotationAttr(loopMD); +} + +//===----------------------------------------------------------------------===// +// lowerWarpSpecializeCommon +//===----------------------------------------------------------------------===// + +static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop, + const TargetInfoBase &targetInfo, + const WarpSpecializeCallbacks &callbacks, + unsigned switchLoopBarrierIdx) { + TritonLLVMIRRewriter b(ws.getLoc(), ws.getContext()); + for (Region *partition : ws.getPartitionRegions()) { + // Load the explicit captures from shared memory and replace the block args + // if there are any. + b.setInsertionPointToStart(&partition->front()); + + callbacks.reallocRegisters(b, ws, + RegisterReallocPhase::WorkerPartitionStart, + partition->getRegionNumber()); + + if (partition->getNumArguments()) { + auto captureType = LLVM::LLVMStructType::getLiteral( + b.getContext(), llvm::to_vector(partition->getArgumentTypes()), + /*isPacked=*/true); + Value capturePtr = + LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws); + LLVM::LLVMPointerType ptrTy = ptr_ty(b.getContext(), 3); + for (auto [i, arg] : + llvm::zip(llvm::seq(partition->getNumArguments()), + partition->getArguments())) { + Value ptr = + b.gep(ptrTy, captureType, capturePtr, ArrayRef{0, i}); + // Each thread in the warp group needs a copy of the value. + Value value = b.load(arg.getType(), ptr, /*align=*/1); + arg.replaceAllUsesWith(value); + } + partition->front().eraseArguments([](auto) { return true; }); + } + + // The shared memory is only live for the entry into the region, so put + // another barrier here. + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + + // Rewrite all warp returns. + partition->walk([&](WarpReturnOp op) { + TritonLLVMIRRewriter b(op.getLoc(), op); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + callbacks.reallocRegisters(b, ws, + RegisterReallocPhase::WorkerPartitionEnd, + partition->getRegionNumber()); + b.replaceOpWithNewOp(op, switchLoop); + }); + } +} + +LogicalResult mlir::triton::lowerWarpSpecializeCommon( + LLVM::LLVMFuncOp func, ArrayRef wsOps, Block *entry, + Block *header, Block *switchLoop, Value wid, MLIRContext *ctx, + unsigned defaultNumWarps, unsigned totalNumWarps, + const TargetInfoBase &targetInfo, const WarpSpecializeCallbacks &callbacks, + unsigned switchLoopBarrierIdx) { + + TritonLLVMIRRewriter b(func.getLoc(), ctx); + Type int8Type = b.getIntegerType(8); + LLVM::LLVMPointerType ptrTy = ptr_ty(ctx, 3); + + b.setInsertionPointToStart(switchLoop); + callbacks.reallocRegisters(b, wsOps[0], RegisterReallocPhase::SwitchLoopStart, + 0); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func); + Value relWid = b.sub(wid, b.i32_val(defaultNumWarps)); + + // The default warp group will populate the state pointer with the state ID + // for all warps. + // %warp_state_ptr = getelementptr ptr %state_tr[%rel_wid] + // %warp_state = load i8 %warp_state_ptr + Value warpStatePtr = b.gep(ptrTy, int8Type, statePtr, relWid); + // All threads in a warp reading from the same smem address will not create + // bank conflicts and is better than predicated load. + Value warpState = b.load(int8Type, warpStatePtr); + + // Pull the partition regions out. Switch based on the state ID to the right + // partition. + SmallVector partitionBlocks; + SmallVector partitionStates; + int32_t partitionStateCounter = 0; + // This represents the data that the default warp group will fill into the + // state pointer before entering each `warp_specialize` region, which maps + // a warp ID to a state ID in the switch. + int32_t maxNumWarps = totalNumWarps - defaultNumWarps; + SmallVector> warpToState( + wsOps.size(), SmallVector(maxNumWarps, -1)); + + for (size_t i = 0; i < wsOps.size(); ++i) { + WarpSpecializeOp op = wsOps[i]; + auto &stateMap = warpToState[i]; + rewritePartitionRegions(op, switchLoop, targetInfo, callbacks, + switchLoopBarrierIdx); + for (auto [partition, partitionNumWarps, startId] : + llvm::zip(op.getPartitionRegions(), op.getPartitionNumWarps(), + *op.getWarpGroupStartIds())) { + partitionStates.push_back(partitionStateCounter++); + partitionBlocks.push_back(&partition->front()); + for (int32_t &stateId : MutableArrayRef(stateMap).slice( + startId - defaultNumWarps, partitionNumWarps)) + stateId = partitionStates.back(); + } + } + + if (partitionStateCounter > std::numeric_limits::max()) { + return mlir::emitError(func.getLoc(), + "FIXME: too many warp group partitions"); + } + + // Splice them in reverse order so the IR is easier to read. + Region::BlockListType &funcBlocks = func.getBody().getBlocks(); + for (Block *block : llvm::reverse(partitionBlocks)) { + Region *region = block->getParent(); + funcBlocks.splice(std::next(switchLoop->getIterator()), + region->getBlocks()); + } + + // Default destination. + Block *defaultBlock = new Block; + funcBlocks.insert(std::next(switchLoop->getIterator()), defaultBlock); + b.setInsertionPointToStart(defaultBlock); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + auto latchBr = LLVM::BrOp::create(b, b.getLoc(), switchLoop); + disableLICM(latchBr); + + // Exit state. + Block *switchExit = new Block; + funcBlocks.insert(std::next(defaultBlock->getIterator()), switchExit); + partitionBlocks.push_back(switchExit); + partitionStates.push_back(partitionStateCounter); + + // Create the switch. + b.setInsertionPointToEnd(switchLoop); + SmallVector caseValues; + for (int32_t state : partitionStates) + caseValues.push_back(APInt(8, state)); + LLVM::SwitchOp::create(b, b.getLoc(), warpState, defaultBlock, ValueRange(), + caseValues, partitionBlocks, + SmallVector(partitionBlocks.size())); + + // Now add synchronization around the default regions. + for (size_t i = 0; i < wsOps.size(); ++i) { + WarpSpecializeOp ws = wsOps[i]; + auto &stateMap = warpToState[i]; + Block *before = ws->getBlock(); + Block *after = b.splitBlock(before, ws->getIterator()); + TritonLLVMIRRewriter b(ws.getLoc(), OpBuilder::atBlockEnd(before)); + Type int8Type = b.getIntegerType(8); + Value statePtrWs = + LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func); + for (auto [j, state] : llvm::enumerate(stateMap)) { + Value stateVal = b.i8_val(state); + b.store(stateVal, b.gep(ptrTy, int8Type, statePtrWs, LLVM::GEPArg(j))); + } + + // Store the captures if there are any. + auto partOp = ws.getPartitionOp(); + if (partOp.getNumOperands()) { + auto captureType = LLVM::LLVMStructType::getLiteral( + b.getContext(), llvm::to_vector(partOp.getOperandTypes()), + /*isPacked=*/true); + Value capturePtr = + LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws); + for (auto [j, arg] : + llvm::zip(llvm::seq(partOp.getNumOperands()), + partOp.getOperands())) { + Value ptr = + b.gep(ptrTy, captureType, capturePtr, ArrayRef{0, j}); + b.store(arg, ptr, /*align=*/1); + } + } + + // First barrier releases the waiting warpgroups. The second barrier ensures + // they have read the captures before the memory is released upon entry. + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + callbacks.reallocRegisters(b, ws, + RegisterReallocPhase::DefaultPartitionStart, 0); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + LLVM::BrOp::create(b, b.getLoc(), &ws.getDefaultRegion().front()); + + ws.getDefaultRegion().walk([&, ws = ws](WarpYieldOp op) mutable { + TritonLLVMIRRewriter b(op.getLoc(), op); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + callbacks.reallocRegisters(b, ws, + RegisterReallocPhase::DefaultPartitionEnd, 0); + b.replaceOpWithNewOp(op, op.getOperands(), after); + }); + after->getParent()->getBlocks().splice(after->getIterator(), + ws.getDefaultRegion().getBlocks()); + + // Replace the results. + auto outputs = after->addArguments( + ws.getResultTypes(), + SmallVector(ws.getNumResults(), ws.getLoc())); + ws.replaceAllUsesWith(outputs); + ws.erase(); + } + + // Signal all warp groups to exit. + func.walk([&](LLVM::ReturnOp op) { + TritonLLVMIRRewriter b(op.getLoc(), op); + Type int8Type = b.getIntegerType(8); + Value statePtrExit = + LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func); + Value cst = b.i8_val(partitionStateCounter); + for (int32_t i : llvm::seq(maxNumWarps)) + b.store(cst, b.gep(ptrTy, int8Type, statePtrExit, LLVM::GEPArg(i))); + callbacks.createAllBarrier(b, switchLoopBarrierIdx); + }); + b.setInsertionPointToStart(switchExit); + LLVM::ReturnOp::create(b, b.getLoc(), ValueRange()); + + return success(); +} diff --git a/third_party/mthreads/lib/Conversion/TritonInstrumentToLLVM/CMakeLists.txt b/third_party/mthreads/lib/Conversion/TritonInstrumentToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..5a3c379304 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonInstrumentToLLVM/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(TritonInstrumentToLLVM + InstrumentationToLLVM.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TritonIR + TritonGPUIR + TritonInstrumentIR + TritonNvidiaGPUIR + NVGPUIR +) diff --git a/third_party/mthreads/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp new file mode 100644 index 0000000000..75cac06ec2 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -0,0 +1,371 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" +// #include "triton/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +// #include "third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include + +namespace { + +namespace tt = mlir::triton; +namespace ttg = tt::gpu; +namespace tti = mlir::triton::instrument; +namespace ttng = mlir::triton::nvidia_gpu; + +//////////////////////////////////////////// +// Utility functions +//////////////////////////////////////////// + +Value createMemDescToI32(RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, + ttg::MemDescType memDescTy, Value sharedMemStruct) { + TritonLLVMOpBuilder b(loc, rewriter); + auto i32Ty = rewriter.getIntegerType(32); + if (isa(memDescTy.getMemorySpace())) { + return b.ptrtoint(i32Ty, sharedMemStruct); + } + assert(isa(memDescTy.getEncoding()) && + "Unsupported memory encoding"); + Type srcElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, sharedMemStruct, + srcElemTy, rewriter); + auto offset = smemObj.getShmemOffset(loc, rewriter, memDescTy); + auto elemSize = srcElemTy.getIntOrFloatBitWidth() / 8; + offset = b.mul(offset, b.i32_val(elemSize)); + return b.add(offset, b.ptrtoint(i32Ty, smemObj.getBase())); +} + +std::tuple +createIfBlock(ConversionPatternRewriter &b, Location loc, Value cnd) { + // #prevBlock + // if (condition) { + // #ifBlock + // } + // #thenBlock + Block *prevBlock = b.getInsertionBlock(); + Block *ifBlock = b.splitBlock(prevBlock, b.getInsertionPoint()); + + // Split a block after the call. + Block *thenBlock = b.splitBlock(ifBlock, ifBlock->begin()); + b.setInsertionPointToEnd(ifBlock); + LLVM::BrOp::create(b, loc, thenBlock); + b.setInsertionPointToEnd(prevBlock); + LLVM::CondBrOp::create(b, loc, cnd, ifBlock, thenBlock); + b.setInsertionPointToStart(thenBlock); + + return {prevBlock, ifBlock, thenBlock}; +} + +//////////////////////////////////////////// +// Patterns +//////////////////////////////////////////// + +struct AssertInThreadOpConversion + : public ConvertOpToLLVMPattern { + explicit AssertInThreadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(tti::ExperimentalAssertInThreadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector condElems = + unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto condTy = condElems[0].getType(); + bool check_any = adaptor.getCheckAny(); + + // TODO: Check that all the values are available in the current thread + + Value condition = check_any ? b.int_val(condTy.getIntOrFloatBitWidth(), 0) + : b.int_val(condTy.getIntOrFloatBitWidth(), 1); + + assert(condTy.isSignedInteger() || + condTy.isSignlessInteger() && + "Unsupported type for assert_in_thread"); + Value zero = LLVM::ConstantOp::create(rewriter, loc, condTy, + rewriter.getZeroAttr(condTy)); + for (auto elem : condElems) { + if (check_any) { + condition = b.or_(condition, elem); + } else { + condition = b.and_(condition, elem); + } + } + + // Invert the condition - assert will be hit if the condition is true + condition = b.xor_(condition, b.int_val(condTy.getIntOrFloatBitWidth(), 1)); + + llAssert(op, condition, adaptor.getMessage(), rewriter); + if (isa(op.getCondition().getType())) { + // Add a barrier to avoid a race condition in case an assert is followed + // by an op that may trap if the assert condition is true. Since the + // tensor in those two operations may have different layout we need to + // make sure all the threads are done executing the assert before going to + // the next op. + + b.barrier(ttg::AddrSpace::None); + } + rewriter.eraseOp(op); + return success(); + } + + void llAssert(Operation *op, Value condition, StringRef message, + ConversionPatternRewriter &rewriter) const { + + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + StringRef file = "unknown"; + StringRef func = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + while (auto nameLoc = dyn_cast(loc)) + loc = nameLoc.getChildLoc(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + file = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + + // Print the message only for the first thread + Value threadId = getThreadId(*b.builder, loc); + Value zero = b.int_val(threadId.getType().getIntOrFloatBitWidth(), 0); + Value threadIdIsZero = b.icmp_eq(threadId, zero); + condition = b.and_(condition, threadIdIsZero); + + auto [prevBlock, ifBlock, thenBlock] = + createIfBlock(rewriter, loc, condition); + + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + rewriter.setInsertionPointToStart(thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct BufferDescriptorsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(tti::ExperimentalBufferDescriptorsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto encoding = + cast(op.getResult().getType().getEncoding()); + auto offsets = adaptor.getOffsets(); + auto lengths = adaptor.getLengths(); + assert(offsets.size() == lengths.size() && "Mismatched descriptor arrays"); + + auto tensorType = cast(op.getResult().getType()); + + SmallVector offsetVals; + offsetVals.reserve(offsets.size()); + for (int32_t offset : offsets) + offsetVals.push_back(static_cast(offset)); + Value pointerTensor = + createInitializedIntArrayTensor(rewriter, loc, encoding, offsetVals); + + TritonLLVMOpBuilder b(loc, rewriter); + auto i64Ty = rewriter.getIntegerType(64); + Value baseTensor = nullptr; + if (op.getMemType() == tti::MemType::SHARED_MEM) { + auto func = op->getParentOfType(); + Value base = getSharedMemoryBase(rewriter, func); + baseTensor = triton::SplatOp::create(rewriter, loc, tensorType, base); + } else { + assert(op.getMemType() == tti::MemType::TENSOR_MEM && + "Unsupported memory type"); + Value basePtr = nvgpu::TensorMemoryBaseAddress::create(rewriter, loc); + Value base = b.ptrtoint(i64Ty, basePtr); + baseTensor = triton::SplatOp::create(rewriter, loc, tensorType, base); + } + + pointerTensor = arith::AddIOp::create( + rewriter, loc, pointerTensor.getType(), pointerTensor, baseTensor); + + SmallVector maskVals(offsets.size(), 0xffffffffu); + Value maskTensor = + createInitializedIntArrayTensor(rewriter, loc, encoding, maskVals); + Value trimmedPointers = arith::AndIOp::create( + rewriter, loc, pointerTensor.getType(), pointerTensor, maskTensor); + + SmallVector lengthVals; + lengthVals.reserve(lengths.size()); + for (int32_t length : lengths) + lengthVals.push_back(static_cast(static_cast(length)) + << 32); + Value lengthTensor = + createInitializedIntArrayTensor(rewriter, loc, encoding, lengthVals); + + auto bufDescriptors = + arith::OrIOp::create(rewriter, loc, trimmedPointers.getType(), + trimmedPointers, lengthTensor); + rewriter.replaceOp(op, bufDescriptors); + return success(); + } + + Value createInitializedIntArrayTensor(OpBuilder &builder, Location loc, + BlockedEncodingAttr encoding, + ArrayRef values) const { + int64_t size = values.size(); + assert(llvm::isPowerOf2_64(size) && "Expected power of 2"); + auto tensorType = + RankedTensorType::get({size}, builder.getIntegerType(64), encoding); + SmallVector apInts = llvm::to_vector( + llvm::map_range(values, [](uint64_t v) { return APInt(64, v); })); + auto denseAttr = DenseElementsAttr::get(tensorType, apInts); + return arith::ConstantOp::create(builder, loc, tensorType, denseAttr); + } + + Value getSharedMemoryBase(ConversionPatternRewriter &rewriter, + FunctionOpInterface func) const { + Location loc = func.getLoc(); + Value basePtr = LLVM::getStackPointer(rewriter, func); + auto i64Ty = rewriter.getIntegerType(64); + TritonLLVMOpBuilder b(loc, rewriter); + return b.ptrtoint(i64Ty, basePtr); + } +}; + +// struct LockAcquireOpConversion +// : public ConvertOpToLLVMPattern { +// using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +// LogicalResult matchAndRewrite(tti::ExperimentalLockAcquireOp op, +// OpAdaptor adaptor, +// ConversionPatternRewriter &b) const override +// { +// Location loc = op.getLoc(); +// b.setInsertionPoint(op); +// Value lock = op.getLock(); + +// Type elType = cast(lock.getType()).getPointeeType(); +// assert(elType == b.getI32Type() && "Expected i32 lock element type"); + +// // Build: do { old = atom.global.acquire.cas.b32 [lock], 0, 1; } while +// (old +// // != 0); +// Block *prevBlock2 = b.getInsertionBlock(); +// Block *whileBlock = b.splitBlock(prevBlock2, b.getInsertionPoint()); +// Block *endBlock = b.splitBlock(whileBlock, whileBlock->begin()); +// b.setInsertionPointToEnd(prevBlock2); +// Value elect = mlir::LLVM::NVIDIA::createElectPredicateWarp0(loc, b); +// if (op.getPred()) { +// elect = arith::AndIOp::create(b, loc, elect, op.getPred()); +// } +// LLVM::CondBrOp::create(b, loc, elect, whileBlock, endBlock); + +// b.setInsertionPointToEnd(whileBlock); + +// auto i32 = b.getI32Type(); +// Value zero = +// arith::ConstantOp::create(b, loc, i32, b.getIntegerAttr(i32, 0)); +// Value one = +// arith::ConstantOp::create(b, loc, i32, b.getIntegerAttr(i32, 1)); + +// // Inline PTX CAS: old = atom.global.acquire.gpu.cas.b32 [lock], 0, 1 +// // Use converted lock pointer from adaptor for addressing +// PTXBuilder ptx; +// auto *dstOpr = ptx.newOperand("=r", /*init=*/true); +// auto *ptrOpr = ptx.newAddrOperand(adaptor.getLock(), "l"); +// auto *cmpOpr = ptx.newOperand(zero, "r"); +// auto *valOpr = ptx.newOperand(one, "r"); +// auto &atom = *ptx.create("atom"); +// atom.global().o("acquire").o("gpu").o("cas").o("b32"); +// atom(dstOpr, ptrOpr, cmpOpr, valOpr); +// Value old = ptx.launch(b, loc, i32); + +// // while (old != 0) loop +// Value cond = +// arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, old, zero); +// LLVM::CondBrOp::create(b, loc, cond, whileBlock, endBlock); + +// b.setInsertionPointToStart(endBlock); +// triton::gpu::BarrierOp::create(b, loc, +// triton::gpu::AddrSpace::GlobalRead | +// triton::gpu::AddrSpace::GlobalWrite); +// b.eraseOp(op); +// return success(); +// } +// }; + +struct LockReleaseOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult matchAndRewrite(tti::ExperimentalLockReleaseOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &b) const override { + Location loc = op.getLoc(); + b.setInsertionPoint(op); + Value lock = op.getLock(); + if (op.getPred()) { + auto [prevBlock, ifBlock, thenBlock] = + createIfBlock(b, loc, op.getPred()); + b.setInsertionPointToStart(ifBlock); + } + + Type elType = cast(lock.getType()).getPointeeType(); + assert(elType == b.getI32Type() && "Expected i32 lock element type"); + + triton::gpu::BarrierOp::create(b, loc, + triton::gpu::AddrSpace::GlobalRead | + triton::gpu::AddrSpace::GlobalWrite); + Value zero = + arith::ConstantOp::create(b, loc, elType, b.getIntegerAttr(elType, 0)); + triton::AtomicRMWOp::create(b, loc, elType, RMWOp::XCHG, lock, zero, + nullptr, MemSemantic::ACQUIRE_RELEASE, + MemSyncScope::GPU); + b.eraseOp(op); + return success(); + } +}; + +struct MemDescToI32OpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + tti::ExperimentalMemDescToI32Op>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(tti::ExperimentalMemDescToI32Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value converted = + createMemDescToI32(rewriter, loc, getTypeConverter(), + op.getMemdesc().getType(), adaptor.getMemdesc()); + rewriter.replaceOp(op, converted); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateInstrumentationToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter); + // patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); +} diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 0000000000..ed879c7dd5 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TritonToTritonGPU + RelayoutTritonGPU.cpp + TritonGPUConversion.cpp + TritonToTritonGPUPass.cpp + + DEPENDS + TritonConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR + ProtonIR + TritonGPUIR +) diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp new file mode 100644 index 0000000000..7ea2e94a99 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp @@ -0,0 +1,132 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_RELAYOUTTRITONGPU +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +namespace ttng = triton::nvidia_gpu; + +// Given a tensor and its representation in tensor memory, determine its +// distributed layout. +RankedTensorType getTMEMTensorLayout(const TypeConverter *tc, + RankedTensorType type, MemDescType memdesc, + unsigned numWarps) { + type = cast(tc->convertType(type)); + auto cgaLayout = getCGALayout(type.getEncoding()); + auto encoding = + ttng::getDefaultLayoutForTmemLdSt(memdesc, numWarps, cgaLayout); + return type.cloneWithEncoding(encoding); +} + +struct TMEMLoadOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttng::TMEMLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = getTypeConverter()->convertType(op.getType()); + RankedTensorType type = getTMEMTensorLayout( + typeConverter, op.getType(), op.getSrc().getType(), lookupNumWarps(op)); + rewriter.modifyOpInPlace(op, [&] { op.getResult().setType(type); }); + if (type == resultType) + return success(); + + rewriter.setInsertionPointAfter(op); + auto cvt = ConvertLayoutOp::create(rewriter, op.getLoc(), resultType, + op.getResult()); + // Bypass the rewriter to avoid issues with the conversion framework's + // tracking of conditional replacements. + // See https://github.com/llvm/llvm-project/commit/504b50789602 + op.getResult().replaceAllUsesExcept(cvt, cvt); + return success(); + } +}; + +struct TMEMStoreOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttng::TMEMStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType type = + getTMEMTensorLayout(typeConverter, op.getSrc().getType(), + op.getDst().getType(), lookupNumWarps(op)); + Value src = + ConvertLayoutOp::create(rewriter, op.getLoc(), type, adaptor.getSrc()); + rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); }); + return success(); + } +}; + +struct TMEMAllocOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttng::TMEMAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getSrc()) + return success(); + RankedTensorType type = getTMEMTensorLayout( + typeConverter, op.getSrc().getType(), op.getType(), lookupNumWarps(op)); + Value src = + ConvertLayoutOp::create(rewriter, op.getLoc(), type, adaptor.getSrc()); + rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); }); + return success(); + } +}; + +class RelayoutTritonGPU + : public triton::impl::RelayoutTritonGPUBase { +public: + using RelayoutTritonGPUBase::RelayoutTritonGPUBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + int numWarps = lookupNumWarps(mod); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs, /*enableSourceRemat=*/true); + TritonGPUConversionTarget target(*context, typeConverter); + target.addDynamicallyLegalDialect( + [&](Operation *op) { + return TritonGPUConversionTarget::isDynamicallyLegal(op, + typeConverter); + }); + + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + patterns.insert< + // clang-format off + GatherScatterOpPattern, + GatherScatterOpPattern, + TMEMLoadOpPattern, + TMEMStoreOpPattern, + TMEMAllocOpPattern + // clang-format on + >(typeConverter, context); + + ConversionConfig config; + config.allowPatternRollback = false; + if (failed( + applyPartialConversion(mod, target, std::move(patterns), config))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp new file mode 100644 index 0000000000..129b86fba6 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -0,0 +1,186 @@ +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +#include +#include + +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +// +// TypeConverter +// +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + int numWarps, int threadsPerWarp, + int numCTAs, + bool enableSourceRemat) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // types with encoding are already in the right format + // TODO: check for layout encodings more specifically + if (tensorType.getEncoding()) + return tensorType; + ArrayRef shape = tensorType.getShape(); + triton::gpu::BlockedEncodingAttr encoding = + getDefaultBlockedEncoding(this->context, shape, this->numWarps, + this->threadsPerWarp, this->numCTAs); + return tensorType.cloneWithEncoding(encoding); + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + dyn_cast(ptrType.getPointeeType()); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + if (enableSourceRemat) { + addSourceMaterialization([](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) -> Value { + return UnrealizedConversionCastOp::create(builder, loc, tensorType, + inputs) + .getResult(0); + }); + } + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + triton::gpu::ConvertLayoutOp::create(builder, loc, tensorType, inputs); + return cast.getResult(); + }); +} + +// +// TritonGPUConversion +// +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonGPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) { return isDynamicallyLegal(op, typeConverter); }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + cast(dotOp.getA().getType()).getEncoding(); + Attribute bEncoding = + cast(dotOp.getB().getType()).getEncoding(); + if (aEncoding && isa(aEncoding) && + bEncoding && isa(bEncoding)) + return true; + return false; + }); + addDynamicallyLegalOp([](triton::FuncOp funcOp) -> bool { + for (auto arg : funcOp.getArguments()) { + if (auto tensor = dyn_cast(arg.getType())) { + if (!tensor.getEncoding()) + return false; + } + } + return true; + }); +} + +bool TritonGPUConversionTarget::isDynamicallyLegal( + Operation *op, const TypeConverter &typeConverter) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; +} + +// This function returns the layout to use for gather/scatter indices. The +// `gather4` and `scatter4` TMA instructions require 4 consecutive indices. +// Thus, threads issuing these instructions must have all 4 index elements +// available. +static RankedTensorType getNewIndicesType(RankedTensorType type, + unsigned numThreads, + unsigned numWarps, unsigned numCTAs) { + assert(type.getRank() == 1); + auto enc = cast(type.getEncoding()); + auto ctx = type.getContext(); + + // Technically any layout where we have a pack of 4 neighbouring elements plus + // broadcasted over the warp dimension is okay but for now we just pick a + // layout. + std::array sizePerThread{1, 4}; + std::array threadsPerWarp = {numThreads, 1}; + std::array order = {1, 0}; + std::array warpsPerCta = {1, numWarps}; + auto cgaLayout = + CGAEncodingAttr::fromSplitParams(ctx, {1, numCTAs}, {1, numCTAs}, order); + + auto parentEncoding = BlockedEncodingAttr::get( + ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, cgaLayout); + auto newEncoding = SliceEncodingAttr::get(ctx, /*dim=*/0, parentEncoding); + if (enc == newEncoding) + return {}; + + return type.cloneWithEncoding(newEncoding); +} + +// Function for converting any gather or scatter op that requires a specific +// index layout. This also handles converting result types if there are any. +static LogicalResult convertGatherScatterIndices(Operation *op, + OpOperand &indices, + ConversionPatternRewriter &b) { + auto type = cast(indices.get().getType()); + RankedTensorType newType = getNewIndicesType( + type, lookupThreadsPerWarp(b), lookupNumWarps(op), lookupNumCTAs(op)); + if (!newType) + return failure(); + Value index = + ConvertLayoutOp::create(b, op->getLoc(), newType, indices.get()); + indices.set(index); + return success(); +} + +LogicalResult impl::convertGatherScatterOp( + Operation *op, ValueRange operands, OpOperand &xOffsetsMutable, + const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + LogicalResult result = success(); + rewriter.modifyOpInPlace(op, [&] { + for (auto [operand, value] : llvm::zip(op->getOpOperands(), operands)) + operand.set(value); + for (OpResult result : op->getOpResults()) + result.setType(typeConverter.convertType(result.getType())); + result = convertGatherScatterIndices(op, xOffsetsMutable, rewriter); + }); + return result; +} diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp new file mode 100644 index 0000000000..780d0fd5a0 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -0,0 +1,839 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_CONVERTTRITONTOTRITONGPU +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (isa(retShapedType)) { + assert(value && "expected a dense elements attribute"); + // This is a hack. We just want to add encoding. + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + auto newRank = retShape.size(); + // return encoding + auto retSizePerThread = llvm::to_vector(argEncoding.getSizePerThread()); + retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); + auto retThreadsPerWarp = to_vector(argEncoding.getThreadsPerWarp()); + retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1); + auto retWarpsPerCTA = to_vector(argEncoding.getWarpsPerCTA()); + retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + auto ctaLl = argEncoding.getCGALayout().getLinearLayout(); + auto kBlock = *ctaLl.getInDimNames().begin(); + auto *ctx = kBlock.getContext(); + auto newDim = standardOutDimNames(ctx, newRank)[newRank - 1]; + ctaLl *= LinearLayout::identity1D(1, kBlock, newDim); + // Move last dim to op.getAxis(). nb is this a std::rotate? + auto newOrder = to_vector(llvm::seq(newRank)); + for (int i = newRank - 1; i >= op.getAxis() + 1; --i) { + std::swap(newOrder[i], newOrder[i - 1]); + } + ctaLl = transposeLinearLayout(ctaLl, newOrder); + auto retCGALayout = CGAEncodingAttr::get(ctx, std::move(ctaLl)); + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, + retThreadsPerWarp, retWarpsPerCTA, + retOrder, retCGALayout); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding); + RankedTensorType newArgType = argType.cloneWithEncoding(newArgEncoding); + // construct new op + auto newSrc = triton::gpu::ConvertLayoutOp::create( + rewriter, op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType origType = op.getType(); + auto origShape = origType.getShape(); + auto typeConverter = getTypeConverter(); + int numWarps = typeConverter->getNumWarps(); + int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); + auto rank = origShape.size(); + SmallVector retSizePerThread(rank, 1); + auto numElements = product(origShape); + if (numElements / (numWarps * threadsPerWarp) >= 4) { + retSizePerThread[rank - 1] = 2; + retSizePerThread[rank - 2] = 2; + } + if (numElements / (numWarps * threadsPerWarp) >= 16) { + retSizePerThread[rank - 1] = 4; + retSizePerThread[rank - 2] = 4; + } + retSizePerThread[rank - 1] = std::min( + retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); + retSizePerThread[rank - 2] = std::min( + retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); + + SmallVector retOrder(rank); + for (unsigned i = 0; i < rank; ++i) + retOrder[i] = rank - 1 - i; + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + getContext(), origShape, retSizePerThread, retOrder, numWarps, + threadsPerWarp, numCTAs); + RankedTensorType retType = origType.cloneWithEncoding(dEncoding); + // a & b must be of smem layout + auto aType = cast(adaptor.getA().getType()); + auto bType = cast(adaptor.getB().getType()); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); + Attribute aEncoding = aType.getEncoding(); + Attribute bEncoding = bType.getEncoding(); + if (!aEncoding || !bEncoding) + return failure(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + if (!mlir::isa(aEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = aType.cloneWithEncoding(encoding); + a = triton::gpu::ConvertLayoutOp::create(rewriter, a.getLoc(), dstType, + a); + } + if (!mlir::isa(bEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = bType.cloneWithEncoding(encoding); + b = triton::gpu::ConvertLayoutOp::create(rewriter, b.getLoc(), dstType, + b); + } + c = triton::gpu::ConvertLayoutOp::create(rewriter, c.getLoc(), retType, c); + + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, a, b, c, adaptor.getInputPrecision(), + adaptor.getMaxNumImpreciseAcc()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = llvm::to_vector(retEncoding.getSizePerThread()); + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCGALayout()); + auto newRetType = retType.cloneWithEncoding(newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = dyn_cast(srcTy.getEncoding()); + int rank = srcEnc.getOrder().size(); + auto typeConverter = getTypeConverter(); + + // The operand to split must have: + // - a blocked layout, with + // - sizePerThread = 2 in the last dimension, + // - threadsPerWarp, warpsPerCTA, and CTAsPerCGA = 1 in the last dim, and + // - the last dimension minor. + // If that's not the case, add a convert before the split. + if (!srcEnc || srcEnc.getSizePerThread().back() != 2 || + srcEnc.getOrder().front() != rank - 1) { + // If we take the default encoding for the op's result (i.e. post-split) + // and add 1 to the end of each dim, that gives us what we want. Other + // than making a legal src encoding, our choice of layout doesn't matter; + // it'll get fixed by RemoveLayoutConversions. + auto defaultEnc = getDefaultBlockedEncoding( + getContext(), + cast(op.getResult(0).getType()).getShape(), + typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(), + typeConverter->getNumCTAs()); + + auto append = [&](ArrayRef vals, unsigned val) { + SmallVector res(vals); + res.push_back(val); + return res; + }; + auto prepend = [&](ArrayRef vals, unsigned val) { + SmallVector res; + res.push_back(val); + res.append(vals.begin(), vals.end()); + return res; + }; + + auto layout = defaultEnc.getCGALayout().getLinearLayout(); + auto kBlock = StringAttr::get(getContext(), "block"); + auto newDim = standardOutDimNames(getContext(), rank)[rank - 1]; + layout *= LinearLayout::identity1D(1, kBlock, newDim); + srcEnc = BlockedEncodingAttr::get( + getContext(), append(defaultEnc.getSizePerThread(), 2), + append(defaultEnc.getThreadsPerWarp(), 1), + append(defaultEnc.getWarpsPerCTA(), 1), + prepend(defaultEnc.getOrder(), rank - 1), + CGAEncodingAttr::get(getContext(), std::move(layout))); + srcTy = srcTy.cloneWithEncoding(srcEnc); + src = ConvertLayoutOp::create(rewriter, op.getLoc(), srcTy, src); + } + + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = op.getType().cloneWithEncoding(srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = triton::ReduceOp::create( + rewriter, op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newScan = + triton::ScanOp::create(rewriter, op.getLoc(), adaptor.getOperands(), + adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +struct TritonMapElementwisePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MapElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + SmallVector resultTys; + auto err = converter->convertTypes(op.getResults().getType(), resultTys); + if (failed(err)) { + return err; + } + + auto newMapOp = triton::MapElementwiseOp::create( + rewriter, op.getLoc(), resultTys, adaptor.getOperands(), op.getPack()); + addNamedAttrs(newMapOp, adaptor.getAttributes()); + + auto &newScalarOp = newMapOp.getScalarOp(); + rewriter.cloneRegionBefore(op.getScalarOp(), newScalarOp, + newScalarOp.end()); + rewriter.replaceOp(op, newMapOp.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + TypeConverter::SignatureConversion result(op.getNumArguments()); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + // Convert just the entry block. The remaining unstructured control flow is + // converted by br patterns. + if (!newOp.getBody().empty()) + rewriter.applySignatureConversion(&newOp.getBody().front(), result, + converter); + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, unsigned numCTAs) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + // clang-format off + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + TritonBroadcastPattern, + TritonCatPattern, + TritonJoinOpPattern, + TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + TritonReducePattern, + GenericOpPattern, + TritonScanPattern, + GenericOpPattern, + GenericOpPattern, + TritonExpandDimsPattern, + TritonTransPattern, + TritonDotPattern, + TritonMapElementwisePattern, + GatherScatterOpPattern, + GatherScatterOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + // this assumes the right layout will be set later for dot scaled. + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + TritonFuncOpPattern + // clang-format on + >(typeConverter, context); +} +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = scf::WhileOp::create(rewriter, op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} + +class ConvertTritonToTritonGPU + : public triton::impl::ConvertTritonToTritonGPUBase< + ConvertTritonToTritonGPU> { +public: + using ConvertTritonToTritonGPUBase::ConvertTritonToTritonGPUBase; + + void runOnOperation() override { + if (target.getValue().empty()) { + mlir::emitError( + getOperation().getLoc(), + "'convert-triton-to-tritongpu' requires 'target' option to be set"); + return signalPassFailure(); + } + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs, enableSourceRemat); + TritonGPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns, numCTAs); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + patterns.insert>(typeConverter, context); + + Builder b(&getContext()); + mod->setAttr(AttrNumWarpsName, b.getI32IntegerAttr(numWarps)); + mod->setAttr(AttrNumThreadsPerWarp, b.getI32IntegerAttr(threadsPerWarp)); + mod->setAttr(AttrNumCTAsName, b.getI32IntegerAttr(numCTAs)); + mod->setAttr(AttrTargetName, b.getStringAttr(this->target.getValue())); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/mthreads/lib/Dialect/CMakeLists.txt b/third_party/mthreads/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..19ca22ec3b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) +add_subdirectory(TritonInstrument) +add_subdirectory(Gluon) +add_subdirectory(NVGPU) +add_subdirectory(NVWS) diff --git a/third_party/mthreads/lib/Dialect/Gluon/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Gluon/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/Gluon/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Gluon/IR/CMakeLists.txt new file mode 100644 index 0000000000..315f033e22 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(GluonIR + Dialect.cpp + + DEPENDS + GluonTableGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/third_party/mthreads/lib/Dialect/Gluon/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/Gluon/IR/Dialect.cpp new file mode 100644 index 0000000000..0a18ec8522 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/IR/Dialect.cpp @@ -0,0 +1,138 @@ +#include "triton/Dialect/Gluon/IR/Dialect.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::triton::gpu; +namespace gluon = mlir::triton::gluon; + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/Gluon/IR/Dialect.cpp.inc" +#include "triton/Dialect/Gluon/IR/GluonAttrDefs.cpp.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/Gluon/IR/Ops.cpp.inc" + +namespace { + +// Layout inference for AutoEncodingAttr -> always propagate AutoEncodingAttr to +// results +struct GluonInferLayoutInterface : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult inferAutoEncoding(Attribute operandEncoding, + Attribute &resultEncoding) const { + if (!isa( + operandEncoding)) + return failure(); + resultEncoding = operandEncoding; + return success(); + } + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional loc) const override { + return inferAutoEncoding(operandEncoding, resultEncoding); + } + + LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef shape, + ArrayRef order, Attribute &resultEncoding, + std::optional loc) const override { + return inferAutoEncoding(operandEncoding, resultEncoding); + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + return inferAutoEncoding(operandEncoding, resultEncoding); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute resultEncoding, + std::optional location) const override { + return inferAutoEncoding(operandEncoding, resultEncoding); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + return success(); + } + + LogicalResult + verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, + Attribute got, + std::optional loc) const override { + return success(expected == got); + } + + LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + return inferAutoEncoding(srcEnc, dstEnc); + } + + LogicalResult + inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + return inferAutoEncoding(srcEnc, dstEnc); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + return inferAutoEncoding(srcEnc, dstEnc); + } + + LogicalResult + inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute srcEnc, + Attribute &dstEnc, bool fwdInference, + std::optional loc) const override { + return inferAutoEncoding(srcEnc, dstEnc); + } +}; +} // namespace + +namespace mlir::triton::gluon { + +void GluonDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/Gluon/IR/GluonAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Gluon/IR/Ops.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); +} + +void SetAutoLayoutOp::build(OpBuilder &builder, OperationState &state, + Attribute enc, Value value) { + auto resTy = cast(value.getType()).cloneWithEncoding(enc); + return build(builder, state, resTy, value); +} + +LogicalResult SetAutoLayoutOp::verify() { + if (!isa(getSrc().getType().getEncoding())) { + return emitOpError("input tensor must have an auto layout type"); + } + auto dstEncoding = getType().getEncoding(); + if (!dstEncoding) + return emitOpError("result tensor must have an encoding"); + if (isa(dstEncoding)) + return emitOpError("result type must not be auto layout"); + return success(); +} + +} // namespace mlir::triton::gluon diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Gluon/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..0e43d594c2 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(GluonTransforms + Canonicalize.cpp + Inline.cpp + ResolveAutoEncodings.cpp + SimplifyControlFlow.cpp + InferCoalescedEncodings.cpp + InferLayoutUtils.cpp + + DEPENDS + GluonTransformsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + GluonIR + MLIRTransformUtils +) diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/Canonicalize.cpp b/third_party/mthreads/lib/Dialect/Gluon/Transforms/Canonicalize.cpp new file mode 100644 index 0000000000..6b8514df75 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/Canonicalize.cpp @@ -0,0 +1,64 @@ +#include "mlir/IR/OperationSupport.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; +namespace gluon = mlir::triton::gluon; + +namespace mlir::triton::gluon { +#define GEN_PASS_DEF_GLUONCANONICALIZE +#include "triton/Dialect/Gluon/Transforms/Passes.h.inc" +} // namespace mlir::triton::gluon + +namespace { +struct Canonicalize : public gluon::impl::GluonCanonicalizeBase { + void runOnOperation() override; +}; +} // namespace + +void Canonicalize::runOnOperation() { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(&getContext()); + + // Populate `arith` and `scf` canonicalizers. + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect( + arith::ArithDialect::getDialectNamespace())) + op.getCanonicalizationPatterns(patterns, ctx); + for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect( + scf::SCFDialect::getDialectNamespace())) + op.getCanonicalizationPatterns(patterns, ctx); + for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect( + cf::ControlFlowDialect::getDialectNamespace())) + op.getCanonicalizationPatterns(patterns, ctx); + populateForOpDeadArgumentElimination(patterns); + + // Populate select Triton canonicalization patterns. The important patterns to + // EXCLUDE are those that modify layouts, especially `ConvertLayoutOp` + // patterns. + LoadOp::getCanonicalizationPatterns(patterns, ctx); + StoreOp::getCanonicalizationPatterns(patterns, ctx); + BroadcastOp::getCanonicalizationPatterns(patterns, ctx); + ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx); + ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx); + ttg::WarpSpecializePartitionsOp::getCanonicalizationPatterns(patterns, ctx); + + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); +} diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp b/third_party/mthreads/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp new file mode 100644 index 0000000000..3544064ac7 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp @@ -0,0 +1,112 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Visitors.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Gluon/Transforms/InferLayoutUtils.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/PriorityWorklist.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/xxhash.h" + +#define DEBUG_TYPE "gluon-infer-coalesced-encodings" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace ttg = mlir::triton::gpu; + +namespace mlir::triton::gluon { + +#define GEN_PASS_DEF_GLUONINFERCOALESCEDENCODINGSPASS +#include "triton/Dialect/Gluon/Transforms/Passes.h.inc" + +namespace { + +ttg::CGAEncodingAttr getDefaultCGALayout(RankedTensorType refTensorType, + int numCTAs) { + // TODO support numCTAs > 1 + assert(numCTAs == 1 && "only numCTAs == 1 is supported for now"); + return ttg::CGAEncodingAttr::get1CTALayout(refTensorType.getContext(), + refTensorType.getShape().size()); +} + +bool isCoalescedEncodingTensorType(Type ty) { + auto tensorTy = dyn_cast(ty); + return tensorTy && isa(tensorTy.getEncoding()); +} + +LogicalResult inferCoalescedLayout(ModuleOp &mod) { + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + + // infer function-level coalesced layout + for (auto &op : *mod.getBody()) { + auto func = dyn_cast(&op); + if (!func) + continue; + + // 1. for every load/store with coalesced encoding, + // infer coalesced encoding for ptrs + // + llvm::SmallVector> seedEncodings; + func.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + // we only consider those with coalesced encoding + if (!isCoalescedEncodingTensorType(ptr.getType())) + return; + + // build a coalesced encoding + int numWarps = ttg::lookupNumWarps(curr); + int numCTAs = ttg::lookupNumCTAs(curr); + auto tensorType = cast(ptr.getType()); + auto cgaLayout = getDefaultCGALayout(tensorType, numCTAs); + auto shapePerCTA = ttg::getShapePerCTA(cgaLayout.getCTASplitNum(), + tensorType.getShape()); + auto layout = + ttg::buildCoalescedEncoding(axisInfoAnalysis, curr, numWarps, + threadsPerWarp, cgaLayout, shapePerCTA); + // set seed value + for (auto value : curr->getOperands()) + seedEncodings.push_back({value, layout}); + }); + + // 2. propagate Coalesced Layout forward/backward + // + // for backward slice, it doesn't cross the set_auto_layout boundary + // i.e. gl.set_auto_layout(val, gl.CoalescedLayout()) + // -> gl.set_auto_layout(val, a concrete coalesced layout) + // then ResolveAutoLayoutPass will handle the rest + // + if (failed(inferLayout(func, isCoalescedEncodingTensorType, seedEncodings))) + return failure(); + } + return success(); +} + +} // anonymous namespace + +class GluonInferCoalescedEncodingsPass + : public impl::GluonInferCoalescedEncodingsPassBase< + GluonInferCoalescedEncodingsPass> { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + + if (failed(inferCoalescedLayout(moduleOp))) + return signalPassFailure(); + + if (failed(doubleCheckEncodings(moduleOp, isCoalescedEncodingTensorType))) + return signalPassFailure(); + } +}; +} // namespace mlir::triton::gluon diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp b/third_party/mthreads/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp new file mode 100644 index 0000000000..bff4e64a4b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp @@ -0,0 +1,251 @@ +#include "triton/Dialect/Gluon/Transforms/InferLayoutUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PriorityWorklist.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/xxhash.h" + +#define DEBUG_TYPE "gluon-infer-layout-utils" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton::gluon { + +namespace { +struct LayoutInfo { + Attribute encoding; + // Some operations can infer one of many encodings, + // we model this by setting the mayVary flag on encodings + // derived from these ops. + // If "may vary" is set then we allow conflicts, and when + // resolving conflicts we prefer encodings that are not allowed to vary. + bool mayVary = false; + + operator bool() { return bool(encoding); } +}; + +uint64_t hashWithMemo(Attribute attr, + llvm::MapVector &hashMemo) { + auto it = hashMemo.find(attr); + if (it != hashMemo.end()) { + return it->second; + } + + // llvm::hash_value is not stable, so instead we hash the string repr of the + // attribute + std::string str; + llvm::raw_string_ostream os(str); + attr.print(os); + auto hash = llvm::xxh3_64bits(str); + hashMemo.try_emplace(attr, hash); + return hash; +} + +bool compare(Attribute a, Attribute b, + llvm::MapVector &hashMemo) { + if (a == b) + return false; + + return hashWithMemo(a, hashMemo) > hashWithMemo(b, hashMemo); +} + +LayoutInfo combineInfo(LayoutInfo lhs, LayoutInfo rhs, Operation *op, + llvm::MapVector &hashMemo) { + // Sort inputs so this operation is commutative + if (compare(lhs.encoding, rhs.encoding, hashMemo)) { + std::swap(lhs, rhs); + } + if (lhs.mayVary) + return rhs; + if (rhs.mayVary) + return lhs; + if (lhs.encoding == rhs.encoding) + return lhs; + op->emitOpError("found conflicting encodings for value:\n ") + << lhs.encoding << "\nand\n " << rhs.encoding; + return {}; +} + +bool encodingsMayVary(Operation *op) { + return isa(op); +} + +LogicalResult +updateEncoding(ArrayRef values, LayoutInfo info, FuncOp *func, + llvm::MapVector &valueToEncoding, + llvm::PriorityWorklist &worklist, + llvm::MapVector &hashMemo) { + for (auto value : values) { + auto [it, inserted] = valueToEncoding.insert({value, info}); + if (!inserted) { + auto defOp = value.getDefiningOp(); + auto op = defOp ? defOp : func->getOperation(); + auto combine = combineInfo(it->second, info, op, hashMemo); + if (!combine) + return failure(); + if (combine == it->second) + continue; + it->second = combine; + } + LLVM_DEBUG({ + DBGS() << "Setting value:\n\t" << value << "\nto encoding:\n\t" + << it->second.encoding << "\n"; + }); + worklist.insert(value); + } + return success(); +} +} // namespace + +LogicalResult inferLayout( + FuncOp func, llvm::function_ref typeCheck, + const llvm::SmallVector> &seedEncodings) { + // Disallow auto encoding accross function call boundaries + for (auto argTy : func.getArgumentTypes()) { + if (typeCheck(argTy)) { + return func->emitError( + "Functions taking auto encoding must be fully inlined"); + } + } + for (auto resultTy : func.getResultTypes()) { + if (typeCheck(resultTy)) + return func->emitError( + "Functions returning auto encoding must be fully inlined"); + } + + // set seed + llvm::MapVector valueToEncoding; + llvm::PriorityWorklist worklist; + llvm::MapVector hashMemo; + for (auto &[value, encoding] : seedEncodings) { + if (failed(updateEncoding({value}, LayoutInfo{encoding, false}, &func, + valueToEncoding, worklist, hashMemo))) + return failure(); + } + + // Propagate encodings through the graph until fixed point, or conflict + while (!worklist.empty()) { + auto val = worklist.pop_back_val(); + auto info = valueToEncoding[val]; + assert(info); + + // Propagate to users + for (OpOperand &use : val.getUses()) { + auto op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + auto tiedArgs = getTiedArgs(op, use.getOperandNumber() - offset); + if (failed(updateEncoding(tiedArgs, info, &func, valueToEncoding, + worklist, hashMemo))) + return failure(); + } else if (isa(op)) { + auto tiedArgs = getTiedArgs(op, use.getOperandNumber()); + if (failed(updateEncoding(tiedArgs, info, &func, valueToEncoding, + worklist, hashMemo))) + return failure(); + } else { + auto dstEnc = inferDstEncoding(op, info.encoding); + if (dstEnc) { + bool mayVary = info.mayVary || encodingsMayVary(op); + LayoutInfo dstInfo{dstEnc, mayVary}; + if (failed(updateEncoding(llvm::to_vector_of(op->getResults()), + dstInfo, &func, valueToEncoding, worklist, + hashMemo))) + return failure(); + } + } + } + + // Propagate to defining ops + if (auto opResult = dyn_cast(val)) { + auto definingOp = opResult.getOwner(); + if (isa(definingOp)) { + auto tiedArgs = getTiedArgs(definingOp, opResult.getResultNumber()); + if (failed(updateEncoding(tiedArgs, info, &func, valueToEncoding, + worklist, hashMemo))) + return failure(); + } else { + auto srcEncoding = inferSrcEncoding(definingOp, info.encoding); + if (srcEncoding) { + bool mayVary = info.mayVary || encodingsMayVary(definingOp); + LayoutInfo srcInfo{srcEncoding, mayVary}; + llvm::SmallVector tensorOperands; + for (auto operand : definingOp->getOperands()) + if (isa(operand.getType())) + tensorOperands.push_back(operand); + + if (failed(updateEncoding(tensorOperands, srcInfo, &func, + valueToEncoding, worklist, hashMemo))) + return failure(); + } + } + } else if (auto blockArg = dyn_cast(val)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + auto tiedArgs = getTiedArgs(parentOp, blockArg.getArgNumber() - offset); + if (failed(updateEncoding(tiedArgs, info, &func, valueToEncoding, + worklist, hashMemo))) + return failure(); + } + } + } + + // Transfer propagated encodings into the graph + auto ctx = func.getContext(); + for (auto &[val, info] : valueToEncoding) { + assert(typeCheck(val.getType())); + auto existingTy = cast(val.getType()); + auto ty = existingTy.cloneWithEncoding(info.encoding); + val.setType(ty); + + if (auto opResult = dyn_cast(val)) { + if (auto constantOp = dyn_cast(opResult.getOwner())) { + auto value = cast(constantOp.getValueAttr()); + auto newValue = + SplatElementsAttr::get(ty, value.getSplatValue()); + constantOp.setValueAttr(newValue); + } + } + } + return success(); +} + +LogicalResult doubleCheckEncodings(ModuleOp &mod, + llvm::function_ref typeCheck) { + auto res = mod.walk([&](Operation *op) -> WalkResult { + for (auto resTy : op->getResultTypes()) { + if (typeCheck(resTy)) { + return op->emitOpError("Failed to infer return type"); + } + } + return success(); + }); + if (res.wasInterrupted()) + return failure(); + + res = mod.walk([&](Block *block) -> WalkResult { + for (auto argTy : block->getArgumentTypes()) { + if (typeCheck(argTy)) { + return block->getParentOp()->emitError( + "Failed to infer block argument type"); + } + } + return success(); + }); + if (res.wasInterrupted()) + return failure(); + return success(); +} + +} // namespace mlir::triton::gluon diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/Inline.cpp b/third_party/mthreads/lib/Dialect/Gluon/Transforms/Inline.cpp new file mode 100644 index 0000000000..0dd7d26c73 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/Inline.cpp @@ -0,0 +1,29 @@ +#include "triton/Dialect/Gluon/Transforms/Passes.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace triton; +namespace gluon = mlir::triton::gluon; + +namespace mlir::triton::gluon { +#define GEN_PASS_DEF_GLUONINLINE +#include "triton/Dialect/Gluon/Transforms/Passes.h.inc" +} // namespace mlir::triton::gluon + +namespace { +struct Inline : public gluon::impl::GluonInlineBase { + void runOnOperation() override; +}; +} // namespace + +void Inline::runOnOperation() { + mlir::PassManager pm(&getContext()); + pm.addPass(createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) { + pm.addPass(gluon::createGluonSimplifyControlFlow()); + })); + if (failed(pm.run(getOperation()))) + return signalPassFailure(); +} diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp b/third_party/mthreads/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp new file mode 100644 index 0000000000..c7b775cb7a --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp @@ -0,0 +1,71 @@ +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Gluon/Transforms/InferLayoutUtils.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PriorityWorklist.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" + +namespace ttg = mlir::triton::gpu; + +namespace mlir::triton::gluon { + +#define GEN_PASS_DEF_GLUONRESOLVEAUTOENCODINGSPASS +#include "triton/Dialect/Gluon/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "gluon-resolve-auto-encodings" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { +bool isAutoEncodingTensorType(Type ty) { + auto tensorTy = dyn_cast(ty); + return tensorTy && isa(tensorTy.getEncoding()); +} +LogicalResult inferAutoLayout(ModuleOp &mod) { + for (auto &op : *mod.getBody()) { + auto func = dyn_cast(&op); + if (!func) + continue; + + // Set seed values from set_auto_layout ops + llvm::SmallVector> seedEncodings; + func.walk([&](gluon::SetAutoLayoutOp op) { + seedEncodings.push_back({op.getSrc(), op.getType().getEncoding()}); + }); + + if (failed(inferLayout(func, isAutoEncodingTensorType, seedEncodings))) + return failure(); + } + return success(); +} +} // anonymous namespace + +class GluonResolveAutoEncodingsPass + : public impl::GluonResolveAutoEncodingsPassBase< + GluonResolveAutoEncodingsPass> { +public: + using BaseT = + impl::GluonResolveAutoEncodingsPassBase; + using BaseT::BaseT; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // Do layout inference + if (failed(inferAutoLayout(m))) + return signalPassFailure(); + + // Cleanup set_auto_layout ops + m.walk([&](gluon::SetAutoLayoutOp op) { + assert(op.getSrc().getType() == op.getType()); + op.getResult().replaceAllUsesWith(op.getSrc()); + op->erase(); + }); + + if (failed(doubleCheckEncodings(m, isAutoEncodingTensorType))) + return signalPassFailure(); + } +}; +} // namespace mlir::triton::gluon diff --git a/third_party/mthreads/lib/Dialect/Gluon/Transforms/SimplifyControlFlow.cpp b/third_party/mthreads/lib/Dialect/Gluon/Transforms/SimplifyControlFlow.cpp new file mode 100644 index 0000000000..c0a6b40f68 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Gluon/Transforms/SimplifyControlFlow.cpp @@ -0,0 +1,49 @@ +#include "mlir/IR/OperationSupport.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" + +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace triton; + +namespace mlir::triton::gluon { +#define GEN_PASS_DEF_GLUONSIMPLIFYCONTROLFLOW +#include "triton/Dialect/Gluon/Transforms/Passes.h.inc" +} // namespace mlir::triton::gluon + +namespace { +struct SimplifyControlFlow + : public gluon::impl::GluonSimplifyControlFlowBase { + void runOnOperation() override; +}; +} // namespace + +void SimplifyControlFlow::runOnOperation() { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(&getContext()); + + // Populate `scf` and `cf` canonicalizers. + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect( + scf::SCFDialect::getDialectNamespace())) + op.getCanonicalizationPatterns(patterns, ctx); + for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect( + cf::ControlFlowDialect::getDialectNamespace())) + op.getCanonicalizationPatterns(patterns, ctx); + populateForOpDeadArgumentElimination(patterns); + + GreedyRewriteConfig config; + // This is intended to run before AutoLayouts are resolved, in which case + // CSEing constants can lead to additional layout conflicts. + config.enableConstantCSE(false); + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); +} diff --git a/third_party/mthreads/lib/Dialect/NVGPU/CMakeLists.txt b/third_party/mthreads/lib/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/lib/Dialect/NVGPU/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/NVGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..1fd118d2be --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(NVGPUIR + Dialect.cpp + + DEPENDS + NVGPUTableGen + NVGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect +) diff --git a/third_party/mthreads/lib/Dialect/NVGPU/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/NVGPU/IR/Dialect.cpp new file mode 100644 index 0000000000..14b1cd1e45 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVGPU/IR/Dialect.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/NVGPU/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::nvgpu; + +void mlir::triton::nvgpu::NVGPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/NVGPU/IR/Ops.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "triton/Dialect/NVGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/NVGPU/IR/OpsEnums.cpp.inc" diff --git a/third_party/mthreads/lib/Dialect/NVWS/CMakeLists.txt b/third_party/mthreads/lib/Dialect/NVWS/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/NVWS/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/NVWS/IR/CMakeLists.txt new file mode 100644 index 0000000000..0473fd1b0b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(NVWSIR + Dialect.cpp + Ops.cpp + + DEPENDS + NVWSTableGen + NVWSAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/third_party/mthreads/lib/Dialect/NVWS/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/NVWS/IR/Dialect.cpp new file mode 100644 index 0000000000..44db71b766 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/IR/Dialect.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/IR/Dialect.cpp.inc" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` +// clang-format on + +using namespace mlir; +using namespace mlir::triton::nvws; + +void mlir::triton::nvws::NVWSDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/NVWS/IR/NVWSAttrDefs.cpp.inc" + >(); + + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/NVWS/IR/Types.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/NVWS/IR/Ops.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/NVWS/IR/NVWSAttrDefs.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/NVWS/IR/Types.cpp.inc" diff --git a/third_party/mthreads/lib/Dialect/NVWS/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/NVWS/IR/Ops.cpp new file mode 100644 index 0000000000..9ee11425e5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/IR/Ops.cpp @@ -0,0 +1,185 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/TypeRange.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVectorExtras.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/NVWS/IR/NVWSAttrEnums.cpp.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/NVWS/IR/NVWSOpInterfaces.cpp.inc" +#include "triton/Dialect/NVWS/IR/Ops.cpp.inc" + +namespace mlir::triton::nvws { + +LogicalResult ArefCreateOp::verify() { + SmallVector dims; + for (auto operand : getOperands()) { + SmallVector users(operand.user_begin(), operand.user_end()); + if (!llvm::all_of(users, [](Operation *op) { + return isa(op); + })) + return emitError("Aref buffer is used elsewhere, Aref cannot guarantee " + "async safety"); + auto type = operand.getType(); + if (auto mType = dyn_cast(type)) { + dims.push_back(mType.getShape()[0]); + } else if (auto rType = dyn_cast(type)) { + dims.push_back(rType.getShape()[0]); + } else { + return emitError("Aref is sliced, but input type isn't supported."); + } + } + if (!llvm::all_equal(dims)) + return emitError("Leading dims of sliced aref inputs don't match."); + + return success(); +} + +template +static std::optional verifySlice(T &origType, T &newType) { + if (!origType || !newType) + return "MLIR Types don't match"; + if (isa( + origType.getEncoding())) { + if (origType.getElementType() != newType.getElementType() || + origType.getRank() != newType.getRank()) { + return "Ranks don't match for TensorMemoryScalesEncodingAttr"; + } + for (size_t i = 0, e = newType.getShape().size(); i < e; i++) { + if (origType.getShape()[i] != newType.getShape()[i]) + return "Dimensions don't match for TensorMemoryScalesEncodingAttr"; + } + } else { + if (origType.getElementType() != newType.getElementType() || + origType.getRank() - 1 != newType.getRank()) { + return "Ranks don't match"; + } + for (size_t i = 0, e = newType.getShape().size(); i < e; i++) { + if (origType.getShape()[i + 1] != newType.getShape()[i]) + return "Dimensions don't match"; + } + } + return std::nullopt; +} + +std::optional static arefEnterVerify( + ArefType aref, mlir::ValueTypeRange resultTypes) { + auto typeArray = aref.getBaseType(); + if (typeArray.size() != resultTypes.size()) + return "Aref has different number of arguments than enter"; + // This should probably rely on the memdescSubsliceOp verifier? + for (auto [orig, arg] : llvm::zip(typeArray, resultTypes)) { + if (auto origT = dyn_cast(orig)) { + auto argT = dyn_cast(arg); + if (auto result = verifySlice(origT, argT)) + return result; + } else if (auto origT = dyn_cast(orig)) { + auto argT = dyn_cast(arg); + if (auto result = verifySlice(origT, argT)) + return result; + } else { + return "Slicing not Implemented for this type"; + } + } + return std::nullopt; +} + +LogicalResult ArefPutEnterOp::verify() { + if (auto result = + arefEnterVerify(getAref().getType(), getBuffers().getType())) + return emitError(*result); + return success(); +} + +LogicalResult ArefGetEnterOp::verify() { + if (auto result = + arefEnterVerify(getAref().getType(), getBuffers().getType())) + return emitError(*result); + return success(); +} + +LogicalResult WarpGroupOp::verify() { + auto numWarps = getNumWarps(); + auto regions = getRegions(); + if (numWarps.size() != regions.size()) + return emitError("Must supply numWarps for each Warp Group."); + if (getResults().size() > 0) { + if (regions.size() == 0) { + return emitError("Must have at least one region when there are results."); + } + if (!isa( + regions.front()->front().getTerminator())) { + return emitError("When nvws.warp_group op has results, the first region " + "should be terminated by nvws.warp_group.yield op."); + } + auto yieldOp = + cast(regions.front()->front().getTerminator()); + if (getResults().size() != yieldOp.getNumOperands()) { + return emitError( + "Mismatch in the number of results returned by nvws.warp_group op " + "and the number of the operands of the corresponding " + "nvws.warp_group.yield op in the first region."); + } + } + return success(); +} + +ParseResult WarpGroupOp::parse(OpAsmParser &p, OperationState &result) { + auto ctx = p.getBuilder().getContext(); + + SMLoc operandLoc = p.getCurrentLocation(); + if (p.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + SmallVector partitionNumWarps; + while (succeeded(p.parseOptionalKeyword( + ("partition" + Twine(partitionNumWarps.size()).str())))) { + SMLoc regionLoc = p.getCurrentLocation(); + if (p.parseKeyword("num_warps") || p.parseLParen() || + p.parseInteger(partitionNumWarps.emplace_back()) || p.parseRParen() || + p.parseRegion(*result.addRegion())) + return failure(); + } + + result.addAttribute(getNumWarpsAttrName(result.name), + p.getBuilder().getDenseI32ArrayAttr(partitionNumWarps)); + + return success(); +} + +void WarpGroupOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(), + {getNumWarpsAttrName()}); + + for (auto [i, region, numWarps] : + llvm::enumerate(getPartitionRegions(), getNumWarps())) { + p.printNewline(); + p << "partition" << i; + p << " num_warps(" << numWarps << ") "; + p.printRegion(region, /*printEntryBlockArgs=*/false); + } +} + +void CreateTokenOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num, + TokenLoadType loadType) { + auto tokenType = TokenType::get(builder.getContext()); + auto resultType = RankedTensorType::get({num}, tokenType); + build(builder, state, resultType, num, loadType); +} + +void ArefPutEnterOp::setStage(Value stage) { getStageMutable().assign(stage); } +void ArefPutExitOp::setStage(Value stage) { getStageMutable().assign(stage); } +void ArefGetExitOp::setStage(Value stage) { getStageMutable().assign(stage); } +void ArefGetEnterOp::setStage(Value stage) { getStageMutable().assign(stage); } +void ArefBufferOp::setStage(Value stage) { getStageMutable().assign(stage); } + +} // namespace mlir::triton::nvws diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp new file mode 100644 index 0000000000..168229528a --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp @@ -0,0 +1,563 @@ +/* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "Utilities.h" +#include "mlir/Analysis/TopologicalSortUtils.h" + +#include "Utilities.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AttrTypeSubElements.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using namespace mlir::triton::nvidia_gpu; +using namespace mlir::triton::nvws; + +#define DEBUG_TYPE "nvws-lower-aref" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_NVWSASSIGNSTAGEPHASE +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" +namespace { +template struct AssignStagePhase { + struct StagePhase { + Value stage; + Value phase; + Value token; + }; + Value aref; + int partitionId; + DenseMap, int> tokToStagePosMap; + + AssignStagePhase(Value aref, int partitionId) + : aref(aref), partitionId(partitionId) {} + + T getTypedOp(Operation *op) { + if (auto opT = dyn_cast(op)) { + if (opT.getAref() == aref) { + if (!hasPartition(op) || + llvm::is_contained(getPartitionIds(op), partitionId)) + return opT; + } + } + return {}; + } + bool isBufferUsed(ArefBufferOp bufOp, Value token) { + if (!bufOp) + return false; + if (bufOp.getAref() != this->aref) + return false; + return token == bufOp.getToken(); + } + + bool analyzeArefUseInBlock(Block *block, Value token) { + for (auto &op : *block) { + if (getTypedOp(&op) || isBufferUsed(dyn_cast(op), token)) { + return true; + } else if (auto forOp = dyn_cast(op)) { + Value newTok; + if (auto pos = findValuePosInRange(forOp.getInitArgs(), token)) { + newTok = forOp.getRegionIterArgs()[*pos]; + } + if (analyzeArefUseInBlock(forOp.getBody(), newTok)) + return true; + } else if (auto ifOp = dyn_cast(op)) { + if (analyzeArefUseInBlock(ifOp.thenBlock(), token)) + return true; + if (ifOp.elseBlock() && analyzeArefUseInBlock(ifOp.elseBlock(), token)) + return true; + } + } + return false; + } + + void assignArefIndexInForOp(scf::ForOp forOp, StagePhase &index) { + + Value newTok; + if (auto pos = findValuePosInRange(forOp.getInitArgs(), index.token)) { + newTok = forOp.getRegionIterArgs()[*pos]; + } + // find uses of arefs in forOp body + if (!analyzeArefUseInBlock(forOp.getBody(), newTok)) + return; + + // add extra iterArgs to the forOp + SmallVector extraIterArgs{index.stage, index.phase}; + SmallVector arefIndexRefs{&index.stage, &index.phase}; + llvm::MapVector arefTokenRefs; + + if (auto pos = findValuePosInRange(forOp.getInitArgs(), index.token)) { + // keep reference of the token position to latest token value + // we will need it update with the value returned from forOp + arefTokenRefs[*pos] = &index.token; + // update token value with iter argument + index.token = forOp.getRegionIterArgs()[*pos]; + } + // create new forOp with extra iterArgs + OpBuilder builder(forOp); + size_t nArgs = forOp.getRegionIterArgs().size(); + + assert(hasPartition(forOp)); + auto forOpIds = getPartitionIds(forOp); + auto forOpOutputsIds = getPartitionOutputs(forOp); + forOp = addIterArgsToLoop(builder, forOp, extraIterArgs); + + // update arefIndex with iterArgs in the forOp body + for (size_t idx = nArgs; idx < forOp.getRegionIterArgs().size(); ++idx) + *arefIndexRefs[idx - nArgs] = forOp.getRegionIterArgs()[idx]; + + // assign arefIndex in the forOp body + auto indexInBlock = assignArefIndexInBlock(forOp.getBody(), index); + + // update yieldOp to return new indexes + SmallVector extraYieldArgs; + // associate token with stage positional argument in the iterArgs & yieldOp + // we will need this in propagateStage function that will assign stage + // to arefBuffer and arefExit ops + extraYieldArgs.push_back(indexInBlock.stage); + if (index.phase) + extraYieldArgs.push_back(indexInBlock.phase); + appendToForOpYield(forOp, extraYieldArgs); + tokToStagePosMap[{forOp, index.token}] = nArgs; + tokToStagePosMap[{forOp.getBody()->getTerminator(), indexInBlock.token}] = + nArgs; + + // update partitions of the forOp + for (auto arg : extraYieldArgs) { + SetVector argIds; + if (auto defOp = arg.getDefiningOp()) { + if (defOp->getNumRegions() == 0) { + // if there is defOp, use partitions of defOp + assert(hasPartition(defOp)); + argIds = getPartitionIds(defOp); + } else { + // if op has region, it returns result, get partition from result + auto pos = findValuePosInRange(defOp->getResults(), arg); + argIds = getPartitionOutputs(defOp)[*pos]; + } + } else { + // otherwise it is a block-arg, use partitions of users + for (auto user : arg.getUsers()) { + if (isa(user)) + continue; + assert(hasPartition(user)); + auto ids = getPartitionIds(user); + argIds.insert(ids.begin(), ids.end()); + } + } + forOpIds.insert(argIds.begin(), argIds.end()); + forOpOutputsIds.push_back(argIds); + } + setPartition(forOp, forOpIds); + setPartitionOutputs(forOp, forOpOutputsIds); + + // update arefIndex with results from newForOp + for (size_t idx = nArgs; idx < forOp.getRegionIterArgs().size(); ++idx) + *arefIndexRefs[idx - nArgs] = forOp.getResult(idx); + for (auto [idx, arefTokenRef] : arefTokenRefs) + *arefTokenRef = forOp.getResult(idx); + } + + void assignArefIndexInIfOp(scf::IfOp ifOp, StagePhase &index) { + + auto useInThenBlock = analyzeArefUseInBlock(ifOp.thenBlock(), index.token); + auto useInElseBlock = + ifOp.elseBlock() ? analyzeArefUseInBlock(ifOp.elseBlock(), index.token) + : false; + if (!useInThenBlock && !useInElseBlock) + return; + + // add extra results to the ifOp + SmallVector extraIfResults{index.stage.getType(), + index.phase.getType()}; + SmallVector arefIndexRefs{&index.stage, &index.phase}; + + // create new ifOp with extra results + OpBuilder builder(ifOp); + size_t nArgs = ifOp.getResults().size(); + auto newIfOp = replaceIfOpWithNewSignature(builder, ifOp, extraIfResults); + + // assign arefIndex in then-body + auto thenIndex = assignArefIndexInBlock(newIfOp.thenBlock(), index); + + // assign arefIndex in else-body + auto elseIndex = newIfOp.elseBlock() + ? assignArefIndexInBlock(newIfOp.elseBlock(), index) + : index; + + // update yieldOp to return new indexes + auto thenYieldOp = newIfOp.thenYield(); + auto elseYieldOp = newIfOp.elseYield(); + // insert new indexes to the yieldOp + llvm::MapVector arefTokenRefs; + + // find token pos in yieldOp and make a reference to arefIndexMap value + if (auto pos = + findValuePosInRange(thenYieldOp->getOperands(), index.token)) { + arefTokenRefs[*pos] = &index.token; + } + if (auto pos = + findValuePosInRange(elseYieldOp->getOperands(), index.token)) { + arefTokenRefs[*pos] = &index.token; + } + + tokToStagePosMap[{newIfOp.thenYield(), thenIndex.token}] = + thenYieldOp.getNumOperands(); + tokToStagePosMap[{newIfOp.elseYield(), elseIndex.token}] = + elseYieldOp.getNumOperands(); + thenYieldOp->insertOperands(thenYieldOp.getNumOperands(), thenIndex.stage); + elseYieldOp->insertOperands(elseYieldOp.getNumOperands(), elseIndex.stage); + thenYieldOp->insertOperands(thenYieldOp.getNumOperands(), thenIndex.phase); + elseYieldOp->insertOperands(elseYieldOp.getNumOperands(), elseIndex.phase); + + assert(hasPartition(ifOp)); + auto ifOpIds = getPartitionIds(ifOp); + auto ifOpOutputsIds = getPartitionOutputs(ifOp); + ifOp.erase(); + + SetVector stageIds; + // at least one of the then/else block must have producing op + for (auto arg : {thenIndex.stage, elseIndex.stage}) { + if (auto defOp = arg.getDefiningOp()) { + auto argIds = getPartitionIds(defOp); + stageIds.insert(argIds.begin(), argIds.end()); + } + } + SetVector phaseIds; + for (auto arg : {thenIndex.phase, elseIndex.phase}) { + if (auto defOp = arg.getDefiningOp()) { + auto argIds = getPartitionIds(defOp); + phaseIds.insert(argIds.begin(), argIds.end()); + } + } + ifOpOutputsIds.push_back(stageIds); + ifOpOutputsIds.push_back(phaseIds); + setPartition(newIfOp, ifOpIds); + setPartitionOutputs(newIfOp, ifOpOutputsIds); + + // update arefIndex with results from newIfOp + for (size_t idx = nArgs; idx < newIfOp.getResults().size(); ++idx) + *arefIndexRefs[idx - nArgs] = newIfOp.getResult(idx); + for (auto [idx, arefTokenRef] : arefTokenRefs) + *arefTokenRef = newIfOp.getResult(idx); + } + + StagePhase assignArefIndexInBlock(Block *block, StagePhase index) { + for (auto &op : llvm::make_early_inc_range(*block)) { + if (auto opT = getTypedOp(&op)) { + ImplicitLocOpBuilder builder(opT.getLoc(), opT); + std::optional> partitionIds; + if (hasPartition(&op)) + partitionIds = getPartitionIds(&op); + auto wsTag = getWarpSpecializeTag(&op); + auto stageCluster = getStageCluster(&op); + + auto createInto = [&](auto opTy, auto... args) { + using ty = decltype(opTy); + auto op = triton::gpu::createInto( + builder, builder.getLoc(), partitionIds, stageCluster, + std::forward(args)...); + if (wsTag) + setWarpSpecializeTag(op, *wsTag); + return op; + }; + + auto nextStage = createInto(arith::AddIOp{}, index.stage, + createInto(arith::ConstantIntOp{}, 1, 32)); + auto arefBuf = opT.getAref() + .template getDefiningOp() + .getOperand(0); + auto depth = getArefDepth(cast(arefBuf.getType())); + + auto cnd = + createInto(arith::CmpIOp{}, arith::CmpIPredicate::eq, nextStage, + createInto(arith::ConstantIntOp{}, depth, 32)); + auto zero = createInto(arith::ConstantIntOp{}, 0, 32); + index.stage = createInto(arith::SelectOp{}, cnd, zero, nextStage); + + auto nextPhase = createInto(arith::XOrIOp{}, index.phase, + createInto(arith::ConstantIntOp{}, 1, 32)); + index.phase = + createInto(arith::SelectOp{}, cnd, nextPhase, index.phase); + + index.token = opT.getToken(); + opT.getStageMutable().assign(index.stage); + opT.getPhaseMutable().assign(index.phase); + } else if (auto forOp = dyn_cast(op)) { + assignArefIndexInForOp(forOp, index); + } else if (auto ifOp = dyn_cast(op)) { + assignArefIndexInIfOp(ifOp, index); + } + } + + return index; + } + + void propagateStage(Value token, Value stage, + DenseSet &visited) { + for (auto &tokUse : token.getUses()) { + auto owner = tokUse.getOwner(); + if (visited.contains(owner)) + continue; + visited.insert(owner); + if (auto stageOp = dyn_cast(owner)) { + if (auto blk = dyn_cast(stage)) { + assert(hasPartition(stageOp)); + auto stageOpIds = getPartitionIds(stageOp); + auto forOp = cast(blk.getOwner()->getParentOp()); + auto pos = findValuePosInRange(forOp.getRegionIterArgs(), stage); + assert(pos); + + // update op partitions + assert(hasPartition(forOp)); + auto forOpIds = getPartitionIds(forOp); + forOpIds.insert(stageOpIds.begin(), stageOpIds.end()); + setPartition(forOp, forOpIds); + + auto forOpOutputsIds = getPartitionOutputs(forOp); + forOpOutputsIds[*pos + 0].insert(stageOpIds.begin(), + stageOpIds.end()); + forOpOutputsIds[*pos + 1].insert(stageOpIds.begin(), + stageOpIds.end()); + setPartitionOutputs(forOp, forOpOutputsIds); + } + stageOp.setStage(stage); + } else if (auto forOp = dyn_cast(owner)) { + auto tokPos = tokUse.getOperandNumber() - forOp.getNumControlOperands(); + auto iterTok = forOp.getRegionIterArg(tokPos); + auto stagePos = tokToStagePosMap.at({forOp, iterTok}); + propagateStage(iterTok, forOp.getRegionIterArgs()[stagePos], visited); + } else if (auto yieldOp = dyn_cast(owner)) { + auto tokPos = tokUse.getOperandNumber(); + auto stagePos = tokToStagePosMap.at({yieldOp, token}); + auto parentOp = yieldOp->getParentOp(); + propagateStage(parentOp->getResult(tokPos), + parentOp->getResult(stagePos), visited); + } + } + } + + static LogicalResult run(ArefCreateOp arefOp) { + std::set partitionIds; + for (auto user : arefOp->getUsers()) { + // Each partition requires its own stage/phase tracking for proper + // multi-user handling; collect partition IDs in which this aref is used + if (isa(user)) { + if (hasPartition(user)) { + auto ids = getPartitionIds(user); + partitionIds.insert(ids.begin(), ids.end()); + } + } + } + if (partitionIds.empty()) { + // if partitionIds is an empty set, it means aref ops used outside ttg.ws + // so we to insert a dummy partitionId for this aref, since we still need + // to assign correct phase + partitionIds.insert({0, 0}); + } + + // initialize indexes + StagePhase index; + ImplicitLocOpBuilder b(arefOp.getLoc(), arefOp); + b.setInsertionPointAfter(arefOp); + auto depth = + getArefDepth(cast(arefOp.getOperand(0).getType())); + index.stage = arith::ConstantIntOp::create(b, depth - 1, 32); + + static_assert(std::is_same_v || + std::is_same_v, + "ArefPutEnterOp or ArefGetEnterOp expected"); + auto initPhase = std::is_same_v ? 0 : 1; + index.phase = arith::ConstantIntOp::create(b, initPhase, 32); + + for (auto partitionId : partitionIds) { + // assign stage/phase to enter/exit Ops in each partition aref is used + AssignStagePhase arefIndex(arefOp.getResult(), partitionId); + + // assign stage/phase to enterOps + arefIndex.assignArefIndexInBlock(arefOp->getBlock(), index); + + // propagate stage to exitOps following enterOp token + for (auto user : arefOp->getUsers()) + if (auto enterOp = dyn_cast(user); + enterOp && (!hasPartition(enterOp) || + getPartitionIds(enterOp).front() == partitionId)) { + DenseSet visited; + arefIndex.propagateStage(enterOp.getToken(), enterOp.getStage(), + visited); + } + } + + return success(); + } +}; + +void updateOutputWithDefaultPartition(Operation *op, int pos) { + auto opIds = getPartitionIds(op); + opIds.insert(0); + setPartition(op, opIds); + + auto opOutputsIds = getPartitionOutputs(op); + opOutputsIds[pos].insert(0); + setPartitionOutputs(op, opOutputsIds); +} + +void visitBackwardSlice(scf::ForOp wsLoop, Value value, + std::function callback, + DenseSet &visited) { + if (!visited.insert(value).second) + return; + + if (auto blockArg = dyn_cast(value)) { + if (auto forOp = dyn_cast(blockArg.getOwner()->getParentOp())) { + if (forOp->hasAttr(kWarpSpecializeAttrName)) + return; + auto pos = findValuePosInRange(forOp.getRegionIterArgs(), value); + assert(pos); + visitBackwardSlice(wsLoop, forOp.getInitArgs()[*pos], callback, visited); + } + } else if (auto defOp = value.getDefiningOp(); + isa(defOp)) { + auto pos = findValuePosInRange(defOp->getResults(), value); + assert(pos); + updateOutputWithDefaultPartition(defOp, *pos); + if (auto ifOp = dyn_cast(defOp)) { + visitBackwardSlice(wsLoop, ifOp.thenYield()->getOperand(*pos), callback, + visited); + if (ifOp.elseBlock()) + visitBackwardSlice(wsLoop, ifOp.elseYield()->getOperand(*pos), callback, + visited); + visitBackwardSlice(wsLoop, ifOp.getCondition(), callback, visited); + } else { + auto forOp = cast(defOp); + visitBackwardSlice(wsLoop, + forOp.getBody()->getTerminator()->getOperand(*pos), + callback, visited); + // visit control operands of for-op + for (int idx = 0; idx < forOp.getNumControlOperands(); ++idx) { + auto control = forOp.getOperand(idx); + visitBackwardSlice(wsLoop, control, callback, visited); + } + } + } else if (wsLoop.getBody()->findAncestorOpInBlock(*defOp)) { + callback(defOp); + for (auto operand : defOp->getOperands()) { + visitBackwardSlice(wsLoop, operand, callback, visited); + } + } +} + +LogicalResult assignStagePhase(triton::FuncOp funcOp) { + SmallVector arefOps; + funcOp.walk([&](ArefCreateOp arefOp) { arefOps.push_back(arefOp); }); + for (auto arefOp : arefOps) { + if (failed(AssignStagePhase::run(arefOp))) + return failure(); + if (failed(AssignStagePhase::run(arefOp))) + return failure(); + } + + auto callback = [&](Operation *op) { + if (!isa(op)) { + assert(hasPartition(op)); + auto partitionIds = getPartitionIds(op); + partitionIds.insert(0); + setPartition(op, partitionIds); + } + }; + + funcOp.walk([&](scf::ForOp forOp) { + DenseSet visited; + if (forOp->hasAttr(kWarpSpecializeAttrName)) { + for (auto result : forOp.getResults()) { + // if result is of scalar type and is used outside of for-op, visit + // all dependencies and assign default partition to them + if (isa(result.getType()) && + !result.use_empty()) { + auto arg = forOp.getBody()->getTerminator()->getOperand( + result.getResultNumber()); + // Check if any users of this scalar result lack ttg.partition, or if + // it is used in another warp-specialized loop. If so, the scalar is + // consumed by the root partition outside the warp-specialized loop, + // requiring us to assign the default partition to all operations that + // compute this result. + bool assignDefaultPartition = + llvm::any_of(result.getUsers(), [&](Operation *user) { + return !hasPartition(user) || + (isa(user) && hasWarpSpecializeTag(user)); + }); + if (assignDefaultPartition) { + updateOutputWithDefaultPartition(forOp, result.getResultNumber()); + visitBackwardSlice(forOp, arg, callback, visited); + } + } + } + } + }); + return success(); +} + +// ---------------------------------------------------------------------------- + +} // anonymous namespace + +class NVWSAssignStagePhase + : public impl::NVWSAssignStagePhaseBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::ModuleOp m = getOperation(); + + m.walk([&](triton::FuncOp funcOp) { + if (failed(assignStagePhase(funcOp))) + signalPassFailure(); + }); + } +}; // namespace triton + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/NVWS/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..3b285c4efc --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(NVWSTransforms + LowerAref.cpp + LowerWarpGroup.cpp + InsertAref.cpp + Utilities.cpp + AssignStagePhase.cpp + InsertTmemAref.cpp + HoistTmemStore.cpp + + DEPENDS + NVWSTransformsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + TritonNvidiaGPUIR + NVWSIR + MLIRTransformUtils +) diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp new file mode 100644 index 0000000000..e645157196 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/HoistTmemStore.cpp @@ -0,0 +1,362 @@ +/* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include +#include + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +namespace ttng = triton::nvidia_gpu; + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_NVWSHOISTTMEMSTORE +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" +namespace { + +bool underWSLoop(Operation *op) { + scf::ForOp topLevelFor = op->getParentOfType(); + if (!topLevelFor) { + return false; + } + + if (topLevelFor->hasAttr(kWarpSpecializeAttrName)) { + return true; + } else { + while (auto outer = topLevelFor->getParentOfType()) { + topLevelFor = outer; + if (outer->hasAttr(kWarpSpecializeAttrName)) { + return true; + } + } + } + + return false; +} + +class FoldTmemStoreIntoAlloc : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc, + PatternRewriter &rewriter) const override { + if (alloc.getSrc() || !underWSLoop(alloc)) { + return failure(); + } + + for (auto user : alloc->getUsers()) { + if (auto store = dyn_cast(user)) { + auto storeSrc = store.getSrc(); + if (auto storeSrcDef = storeSrc.getDefiningOp()) { + DominanceInfo dom(storeSrcDef); + if (dom.dominates(storeSrcDef, alloc)) { + auto newAlloc = ttng::TMEMAllocOp::create( + rewriter, alloc.getLoc(), alloc.getResultTypes()[0], + rewriter.getType(), storeSrc); + + if (auto allocTok = alloc.getToken()) { + allocTok.replaceAllUsesWith(newAlloc.getToken()); + } + if (auto storeTok = store.getToken()) { + storeTok.replaceAllUsesWith(newAlloc.getToken()); + } + if (hasPartition(store)) { + // The alloc op can have multiple partitions at this point. But + // aref-tmem-insert requires a single owner, which should be the + // partiton that tmem_store belongs to. + setPartition(newAlloc, getPartitionIds(store)); + } + rewriter.eraseOp(store); + rewriter.replaceOp(alloc, newAlloc); + return success(); + } + } + } + } + + return failure(); + } +}; + +std::optional> +getUniqueUserLoopAndMMA(ttng::TMEMAllocOp tmemAlloc) { + auto tok = tmemAlloc.getToken(); + if (!tok || !tok.hasOneUse()) + return std::nullopt; + auto loop = dyn_cast(*tok.getUsers().begin()); + if (!loop) + return std::nullopt; + auto loopTok = loop.getBody()->getArgument( + tok.getUses().begin()->getOperandNumber() - 2); + if (!loopTok.hasOneUse()) + return std::nullopt; + auto mma = dyn_cast(*loopTok.getUsers().begin()); + if (mma) + return std::make_pair(loop, mma); + return std::nullopt; +} + +// Check if this alloc is used by an MMA op with useD initialized to false +bool canRemoveTmemStore(ttng::TMEMAllocOp tmemAlloc) { + auto opt = getUniqueUserLoopAndMMA(tmemAlloc); + if (!opt) + return false; + auto [loop, mma] = *opt; + auto useD = dyn_cast(mma.useAccumulator()); + if (!useD) + return false; + auto parent = useD.getParentBlock()->getParentOp(); + if (parent != loop) + return false; + auto loopInit = loop.getInitArgs()[useD.getArgNumber() - 1]; + auto val = getBoolFromConstant(loopInit); + return val && val.value() == false; +} + +bool canProveExecuteOnce(scf::ForOp forOp) { + auto getAssumedBound = [&](Value v) -> std::optional { + mlir::ForwardSliceOptions opt; + SetVector slice; + (void)getForwardSlice(v, &slice, opt); + + // For simplicity, we only handle an assume op directly operating on v. It's + // possible to support more general cases, but they require a range + // analysis. + for (auto op : slice) { + if (auto assumeOp = dyn_cast(op)) { + auto cond = assumeOp.getCond(); + if (auto cmpOp = cond.getDefiningOp(); + cmpOp && (cmpOp.getLhs() == v || cmpOp.getRhs() == v)) { + if (auto bound = getBoundFromCmpOp(cmpOp, v)) { + return *bound; + } + } + } + } + return std::nullopt; + }; + + auto getConstIntBound = [&](Value v) { + unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(v.getType()); + if (auto cst = getConstantIntValue(getAsOpFoldResult(v))) { + APInt apVal = {bitWidth, static_cast(*cst), /*signed*/ true}; + return mlir::ConstantIntRanges::constant(apVal); + } else if (auto assumedBound = getAssumedBound(v)) { + return *assumedBound; + } else { + APInt min = APInt::getSignedMinValue(bitWidth); + APInt max = APInt::getSignedMaxValue(bitWidth); + return mlir::ConstantIntRanges::range(min, max, true); + } + }; + + auto lbBound = getConstIntBound(forOp.getLowerBound()); + auto ubBound = getConstIntBound(forOp.getUpperBound()); + return mlir::intrange::evaluatePred(mlir::intrange::CmpPredicate::slt, + lbBound, ubBound) + .value_or(false); +} + +bool hoistTmemAlloc(ttng::TMEMAllocOp allocToHoist) { + // extra loop nest + SmallVector loopNest; + auto currentForOp = allocToHoist->getParentOfType(); + while (currentForOp && !currentForOp->hasAttr(kWarpSpecializeAttrName)) { + loopNest.push_back(currentForOp); + currentForOp = currentForOp->getParentOfType(); + } + + if (!currentForOp) { + return false; + } + + loopNest.push_back(currentForOp); + + { + // Check if hoisting across all loop nests is valid. Hoisting is invalid + // when the inner loop that does MMA executes variable number of times + // depending on the outer loop variables, and some instances of the inner + // loops never execute while others do. So we hoist across loop nests only + // in the following cases: + // 1. The loop iteration counts for all loops do not depend on their outer + // loop variables. + // 2. If there is a loop whose iteration count depends on outer loop + // varaibles, there is an llvm.intr.assume op from which we can prove that + // the number of iteration is greater than zero. + auto opt = getUniqueUserLoopAndMMA(allocToHoist); + if (!opt) { + return false; + } + + SmallVector innerLoopNest{opt->first}; + innerLoopNest.insert(innerLoopNest.begin(), loopNest.begin(), + loopNest.end() - 1); + + // Does the expression x depend on y? + auto dependOn = [](Value x, Value y) { + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + SetVector slice; + (void)getBackwardSlice(x, &slice, opt); + for (auto user : y.getUsers()) { + if (x.getDefiningOp() == user || slice.count(user)) { + return true; + } + } + return false; + }; + + for (auto [i, innerFor] : llvm::enumerate(innerLoopNest)) { + for (int j = i; j < loopNest.size(); ++j) { + auto outerForIter = loopNest[j].getInductionVar(); + if ((dependOn(innerFor.getLowerBound(), outerForIter) || + dependOn(innerFor.getUpperBound(), outerForIter)) && + !canProveExecuteOnce(innerFor)) { + // Cannot hoist this tmem alloc across the outer loop loopNest[j] + return false; + } + } + } + } + + // hoist to outside tt.warp_specialized loop + allocToHoist->moveBefore(currentForOp); + allocToHoist->removeAttr(kPartitionAttrName); + + Value token = allocToHoist.getToken(); + assert(token.hasOneUse()); + auto &tokenUse = *token.getUses().begin(); + auto tokenPos = + tokenUse.getOperandNumber() - currentForOp.getNumControlOperands(); + auto tokenPartition = getPartitionOutputs(tokenUse.getOwner())[tokenPos]; + + // thread token to for-op init/iter args from outer-to inner + std::reverse(loopNest.begin(), loopNest.end()); + for (auto &forOp : loopNest) { + OpBuilder b(forOp); + int nArgs = forOp.getRegionIterArgs().size(); + forOp = addIterArgsToLoop(b, forOp, {token}); + + // update partitions for the forOp + if (forOp->hasAttr(kPartitionOutputsAttrName)) { + auto partitionOuputs = getPartitionOutputs(forOp); + partitionOuputs.push_back(tokenPartition); + setPartitionOutputs(forOp, partitionOuputs); + } else { + setPartitionOutputs(forOp, {tokenPartition}); + } + auto partitions = getPartitionIds(forOp); + partitions.insert(tokenPartition.begin(), tokenPartition.end()); + setPartition(forOp, partitions); + + token = forOp.getRegionIterArg(nArgs); + } + + // set inner loop init_args with updated token + tokenUse.set(token); + + // get last produced token, the one w/o use + token = tokenUse.getOwner()->getResult(tokenPos); + while (!token.use_empty()) { + assert(token.hasOneUse()); + auto tokenUser = *token.getUsers().begin(); + if (auto load = dyn_cast(tokenUser)) { + token = load.getToken(); + } else if (auto store = dyn_cast(tokenUser)) { + token = store.getToken(); + } else { + auto mma = cast(tokenUser); + token = mma.getToken(); + } + } + + // append token to yield, from inner to outer loop + std::reverse(loopNest.begin(), loopNest.end()); + for (auto forOp : loopNest) { + appendToForOpYield(forOp, {token}); + setPartition(forOp.getBody()->getTerminator(), getPartitionIds(forOp)); + token = forOp->getResults().back(); + } + + return true; +} + +} // namespace + +class NVWSHoistTmemStore + : public impl::NVWSHoistTmemStoreBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::ModuleOp m = getOperation(); + + OpPassManager pm; + mlir::RewritePatternSet patterns(context); + patterns.add(context); + ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + signalPassFailure(); + + m.walk([&](scf::ForOp loop) { + if (loop->hasAttr(kWarpSpecializeAttrName)) { + SmallVector tmemAllocToHoist; + loop.walk([&](ttng::TMEMAllocOp tmemAlloc) { + if (tmemAlloc.getSrc() && canRemoveTmemStore(tmemAlloc)) { + tmemAllocToHoist.push_back(tmemAlloc); + } + }); + + for (auto alloc : tmemAllocToHoist) { + if (!hoistTmemAlloc(alloc)) { + SetVector mmaPartition; + mmaPartition.insert(1); + // tmem store remaining in the outer loop must belong to the MMA + // partition. This is required by aref-tmem-insert for correctly + // double buffering this accumulator. + setPartition(alloc, mmaPartition); + } + } + } + }); + } +}; // namespace triton + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/InsertAref.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/InsertAref.cpp new file mode 100644 index 0000000000..78982353cf --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/InsertAref.cpp @@ -0,0 +1,641 @@ +#include "Utilities.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_NVWSINSERTAREF +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +using namespace triton::gpu; +using namespace triton::nvidia_gpu; +using namespace triton::nvws; + +struct ProducedValueInfo { + SetVector partitions; + Value result; +}; + +SmallVector getProducedValues(Operation *op, + Block *loopBody) { + SmallVector producedValues; + + if (!hasPartition(op)) + return {}; + + // For ops without regions, all results share the same partition IDs + auto partitionOutputs = op->getNumRegions() == 0 + ? SmallVector, 4>( + op->getNumResults(), getPartitionIds(op)) + : getPartitionOutputs(op); + + for (auto result : op->getResults()) { + if (isa(result.getType())) + continue; + producedValues.push_back( + {partitionOutputs[result.getResultNumber()], result}); + } + + return producedValues; +}; + +template +std::optional> isLoadAndAlloc(Value result) { + auto alloc = result.getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return std::nullopt; + if (auto load = alloc.getSrc().template getDefiningOp(); + load && getPartitionIds(alloc) == getPartitionIds(load)) { + // if alloc and load are in different partitions, they are treated as two + // different producer operations. + return std::make_pair(alloc, load); + } + return std::nullopt; +} + +// if result is defined by descriptor_load followed by alloc, return the alloc +// and the load ops as a pair. +template auto isDescLoadAndAlloc(Value result) { + return isLoadAndAlloc(result); +} + +template auto isGlobalLoadAndAlloc(Value result) { + return isLoadAndAlloc(result); +} + +RankedTensorType getTensorTypeFromScalar(OpBuilder &builder, Value scalar) { + auto mod = scalar.getParentRegion()->getParentOfType(); + auto nWarps = lookupNumWarps(mod); + auto threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int CTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + Attribute encoding = getDefaultBlockedEncoding(builder.getContext(), {1}, + nWarps, threadsPerWarp, CTAs); + return RankedTensorType::get({1}, scalar.getType(), encoding); +} + +ArefCreateOp createAref(OpBuilder &builder, ProducedValueInfo &producedValue) { + auto result = producedValue.result; + + auto getSmemDescType = [](RankedTensorType tensorType, Value tensorResult) { + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(tensorType.getContext()); + Attribute encoding = tensorResult && tensorResult.getDefiningOp() + ? getSharedEncoding(tensorResult.getDefiningOp()) + : getSharedEncoding(tensorType); + auto memDescType = + MemDescType::get(tensorType.getShape(), tensorType.getElementType(), + encoding, SharedMemorySpace); + return memDescType; + }; + + MemDescType memDescType; + if (result.getDefiningOp()) { + memDescType = dyn_cast(result.getType()); + } else if (auto tensorType = dyn_cast(result.getType())) { + memDescType = getSmemDescType(tensorType, result); + } else if (isa(result.getType())) { + auto tensorType = getTensorTypeFromScalar(builder, result); + memDescType = getSmemDescType(tensorType, Value()); + } else { + std::string msg = "createAref: unsupported produced value type: " + + mlir::debugString(result.getType()); + llvm::report_fatal_error(msg.c_str()); + } + + MemDescType arefBufType = getMultiBufferedType(memDescType, 1); + assert(isa(arefBufType.getMemorySpace())); + auto loc = result.getLoc(); + auto alloc = triton::nvws::createAlloc(builder, loc, arefBufType, Value()); + return createArefCreateOp(builder, {arefBufType}, {alloc->getResult(0)}, loc); +} + +int getTxCount(Operation *descOp) { + auto getTensorTypeAndDesc = + [](Operation *op) -> std::pair { + if (auto loadOp = dyn_cast(op)) { + return {loadOp.getType(), loadOp.getDesc()}; + } else if (auto gatherOp = dyn_cast(op)) { + return {gatherOp.getType(), gatherOp.getDesc()}; + } else { + llvm_unreachable("Unsupported operation type"); + } + }; + auto [tensorType, desc] = getTensorTypeAndDesc(descOp); + auto encoding = getEncodingFromDescriptor(descOp, tensorType, desc); + auto shapePerCTA = getShapePerCTA(encoding, tensorType.getShape()); + return product(shapePerCTA) * + getIntOrFloatOrPtrBitWidth(tensorType.getElementType()) / 8; +} + +void createNVWSDescriptorLoadOp(OpBuilder &builder, Operation *ttDescLoadOp, + Value dataBuf, + SetVector const &producerPartitions, + Location loc) { + auto txCount = getTxCount(ttDescLoadOp); + if (auto descLoad = dyn_cast(ttDescLoadOp)) { + auto newDescLoad = triton::nvws::DescriptorLoadOp::create( + builder, loc, descLoad.getDesc(), descLoad.getIndices(), txCount, + dataBuf, descLoad.getCache(), descLoad.getEvict()); + newDescLoad->setAttrs(descLoad->getAttrs()); + setPartition(newDescLoad, producerPartitions); + } else if (auto descGather = + dyn_cast(ttDescLoadOp)) { + auto newDescGather = triton::nvws::DescriptorGatherOp::create( + builder, loc, descGather.getDesc(), descGather.getXOffsets(), + descGather.getYOffset(), txCount, dataBuf); + newDescGather->setAttrs(descGather->getAttrs()); + setPartition(newDescGather, producerPartitions); + } else { + llvm_unreachable("unknown descriptor op."); + } +} + +StageCluster getStageClusterForProducer(Value producedValue) { + if (auto arg = dyn_cast(producedValue)) { + Value prevProducedValue; + do { + prevProducedValue = producedValue; + auto terminator = arg.getOwner()->getTerminator(); + if (!isa(terminator)) { + return {}; + } + producedValue = terminator->getOperand(arg.getArgNumber() - 1); + arg = dyn_cast(producedValue); + } while (arg && prevProducedValue != producedValue); + } + + if (auto opt = isDescLoadAndAlloc(producedValue)) { + return getStageCluster(opt->second); + } else if (auto opt = isGlobalLoadAndAlloc(producedValue)) { + return getStageCluster(opt->second); + } else if (auto op = producedValue.getDefiningOp()) { + return getStageCluster(op); + } else { + return {}; + } +} + +SmallVector createArefPut(OpBuilder &builder, ArefCreateOp aref, + ProducedValueInfo producedValue) { + auto loc = producedValue.result.getLoc(); + auto arefBufType = cast(aref.getBuffers()[0].getType()); + Value result = producedValue.result; + Type dataBufType = getBufferViewType(arefBufType, /*mutable*/ true); + StageCluster stageCluster = getStageClusterForProducer(result); + + // elect a partition to put result into aref-buffer + SetVector producerPartitions; + producerPartitions.insert(producedValue.partitions.front()); + + Type token{builder.getType()}; + auto putEnterOp = triton::gpu::createInto( + builder, loc, producerPartitions, stageCluster, aref, + TypeRange{dataBufType}, token); + auto dataBuf = putEnterOp.getBuffers()[0]; + + auto producerKind = AsyncOp::NONE; + SmallVector staleOps; + if (auto opt = isDescLoadAndAlloc(result)) { + auto [alloc, descOp] = *opt; + createNVWSDescriptorLoadOp(builder, descOp, dataBuf, producerPartitions, + loc); + producerKind = AsyncOp::TMALoad; + staleOps.push_back(alloc); + staleOps.push_back(descOp); + } else if (isGlobalLoadAndAlloc(result)) { + llvm_unreachable("cpasync not supported yet"); + } else if (auto alloc = result.getDefiningOp()) { + triton::gpu::createInto(builder, loc, producerPartitions, + stageCluster, alloc.getSrc(), + dataBuf); + staleOps.push_back(alloc); + } else if (auto tensorType = dyn_cast(result.getType())) { + if (auto descOp = result.getDefiningOp()) { + createNVWSDescriptorLoadOp(builder, descOp, dataBuf, producerPartitions, + loc); + producerKind = AsyncOp::TMALoad; + staleOps.push_back(descOp); + } else if (auto loadOp = result.getDefiningOp()) { + llvm_unreachable("cpasync not supported yet"); + } else { + triton::gpu::createInto(builder, loc, producerPartitions, + stageCluster, result, dataBuf); + producerKind = AsyncOp::NONE; + } + } else if (isa(result.getType())) { + auto tensorType = getTensorTypeFromScalar(builder, result); + auto splatOp = triton::gpu::createInto( + builder, loc, producerPartitions, stageCluster, tensorType, result); + triton::gpu::createInto(builder, loc, producerPartitions, + stageCluster, splatOp, dataBuf); + producerKind = AsyncOp::NONE; + } else { + std::string msg = "createArefPut: unsupported produced value type: " + + mlir::debugString(result.getType()); + llvm::report_fatal_error(msg.c_str()); + } + + triton::gpu::createInto( + builder, loc, producerPartitions, stageCluster, aref, + putEnterOp.getToken(), + builder.getArrayAttr(SmallVector{ + AsyncOpAttr::get(aref.getContext(), producerKind)})); + + return staleOps; +}; + +SetVector +getTransitiveConsumers(Operation *op, + SetVector const &consumerPartitions) { + SetVector opConsumers; + auto isMemDesc = [](auto res) { return isa(res.getType()); }; + for (auto &use : op->getUses()) { + if (llvm::count_if(use.getOwner()->getResults(), isMemDesc) > 0) { + // Recurse into consumers of memdesc ops, since the liveness of the + // produced value extends beyond such ops. + auto consumers = + getTransitiveConsumers(use.getOwner(), consumerPartitions); + opConsumers.insert(consumers.begin(), consumers.end()); + } else { + if (getPartitionIds(&use) == consumerPartitions) { + opConsumers.insert(use.getOwner()); + // If an op is defined before an inner loop and used inside, the loop + // itself should be considered as an additional consumer. This is + // necessary for persistent attention, where the load of Q is done + // before the inner loop. + opConsumers.insert( + op->getBlock()->findAncestorOpInBlock(*use.getOwner())); + } + } + } + return opConsumers; +} + +SmallVector +getTransitiveConsumers(const SetVector &results, + SetVector const &consumerPartitions) { + SetVector opSet; + for (auto result : results) { + if (isa(result)) { + for (auto &use : result.getUses()) { + if (getPartitionIds(&use) == consumerPartitions) { + opSet.insert(use.getOwner()); + } + } + } else { + auto consumers = + getTransitiveConsumers(result.getDefiningOp(), consumerPartitions); + opSet.insert(consumers.begin(), consumers.end()); + } + } + return SmallVector{opSet.begin(), opSet.end()}; +} + +SmallVector getConsumerAsyncOpKinds(ArrayRef consumers, + MLIRContext *ctx) { + SetVector kindSet; + for (auto consumer : consumers) { + if (isa(consumer) && consumers.size() > 1) { + // In this case, a getExit is placed after the consumer loop. The + // corresponding async kind attributes should be determined from other + // consumer ops in the loop. + continue; + } + if (isa(consumer)) { + kindSet.insert(AsyncOp::WGMMA); + } else if (isa(consumer)) { + kindSet.insert(AsyncOp::TC5MMA); + } else { + kindSet.insert(AsyncOp::NONE); + } + } + + SmallVector kindAttrs; + for (auto kind : kindSet) { + kindAttrs.push_back(AsyncOpAttr::get(ctx, kind)); + } + + return kindAttrs; +} + +std::pair +getEnterAndExitStageClustersOfUses(const SetVector &producedResults, + std::function filterUse, + scf::ForOp forOp) { + CoarseSchedule coarseSchedule; + if (!forOp || failed(coarseSchedule.deSerialize(forOp)) || + producedResults.empty()) { + return std::make_pair(std::nullopt, std::nullopt); + } + + SmallVector ops; + for (auto res : producedResults) { + if (auto blockArg = dyn_cast(res)) { + // If the producer is a block argument, this means we need to communicate + // iteration arguments from the producer partition in the previous + // iteration to the consumer partition in the current iteration. There + // must be only one produced result in this case. + assert(producedResults.size() == 1); + auto block = blockArg.getOwner(); + auto forOp = cast(block->getParentOp()); + auto opnd = forOp.getYieldedValues()[blockArg.getArgNumber() - 1]; + auto op = opnd.getDefiningOp(); + auto stageCluster = getStageCluster(op); + return std::make_pair(stageCluster, stageCluster); + } + auto op = res.getDefiningOp(); + ops.push_back(op); + } + + auto firstOp = + triton::getFirstUseOfPipelinedOp(ops, forOp, coarseSchedule, filterUse); + auto lastOp = + triton::getLastUseOfPipelinedOp(ops, forOp, coarseSchedule, filterUse); + assert(firstOp && lastOp); + + return std::make_pair(getStageCluster(firstOp), getStageCluster(lastOp)); +} + +void createArefGet(OpBuilder &builder, scf::ForOp loop, ArefCreateOp aref, + const SetVector &results, int consumerPartition, + SmallVector &uses) { + OpBuilder::InsertionGuard g(builder); + // The vector "results" contains either + // 1. One of local_load(desc_load()) or desc_load() + // 2. Both of them + // In the second case, we only need to emit one enter / exit since we know + // that the two results are used by consumers in the same partition. + assert(results.size() == 1 || results.size() == 2); + auto loc = results[0].getLoc(); + + scf::ForOp scheduledLoop; + loop->walk([&](scf::ForOp op) { + if (op->hasAttr(mlir::triton::kScheduledMaxStageAttrName)) { + scheduledLoop = op; + } + }); + + auto filterUse = [&](Operation *user) { + if (hasPartition(user)) { + return llvm::is_contained(getPartitionIds(user), consumerPartition); + } else { + return false; + } + }; + + // Filter results to include only those defined inside the scheduled loop + // (if any). This is done because otherwise the result might not have its + // last use (in either direction) inside the scheduled loop and we will not be + // able to get `stageClusterEnter` and/or `stageClusterExit`. + SetVector resultsInScheduledLoop; + for (Value v : results) { + if (Operation *defOp = v.getDefiningOp()) { + if (scheduledLoop && scheduledLoop->isAncestor(defOp)) + resultsInScheduledLoop.insert(v); + } + } + + auto [stageClusterEnter, stageClusterExit] = + getEnterAndExitStageClustersOfUses(resultsInScheduledLoop, filterUse, + scheduledLoop); + + SetVector consumerPartitions; + consumerPartitions.insert(consumerPartition); + auto arefBufType = cast(aref.getOperand(0).getType()); + Type bufferType = getBufferViewType(arefBufType, /*mutable*/ false); + Type tokenType = builder.getType(); + auto getEnterOp = triton::gpu::createInto( + builder, loc, consumerPartitions, stageClusterEnter, aref, + TypeRange{bufferType}, tokenType); + + auto consumers = getTransitiveConsumers(results, consumerPartitions); + assert(consumers.size() > 0); + auto asyncKinds = getConsumerAsyncOpKinds(consumers, aref.getContext()); + Value dataBuf = getEnterOp.getBuffers()[0]; + Value token = getEnterOp.getToken(); + + Operation *exitInsertPointAfter = nullptr; + + auto replaceUsesWithLocalLoad = [&](Value result, StageCluster stageCluster) { + auto localLoadOp = triton::gpu::createInto( + builder, loc, consumerPartitions, stageCluster, result.getType(), + dataBuf); + + for (auto use : uses) { + if (use->get() == result) { + use->set(localLoadOp.getResult()); + } + } + if (dataBuf.hasOneUse()) { + // If there is only one consumer for dataBuf, it is localLoadOp created + // above, and we hit this code path, the empty barrier can be released + // after local load. + exitInsertPointAfter = localLoadOp; + } + }; + + for (auto result : results) { + if (auto localAlloc = result.getDefiningOp()) { + auto callback = [&](Operation *oldOp, Operation *newOp) { + assert(llvm::is_contained(getPartitionIds(oldOp), consumerPartition)); + setPartition(newOp, consumerPartitions); + }; + replaceUsesAndPropagateType(builder, localAlloc, dataBuf, callback); + } else if (isa(result.getType())) { + replaceUsesWithLocalLoad(result, stageClusterEnter); + } else if (isa(result.getType())) { + auto tensorType = getTensorTypeFromScalar(builder, result); + auto localLoadOp = triton::gpu::createInto( + builder, loc, consumerPartitions, stageClusterEnter, tensorType, + dataBuf); + auto scalar = triton::gpu::createInto( + builder, loc, consumerPartitions, stageClusterEnter, localLoadOp); + for (auto use : uses) { + use->set(scalar); + } + exitInsertPointAfter = localLoadOp; + } else { + std::string msg = "createArefGet: unsupported produced value type: " + + mlir::debugString(result.getType()); + llvm::report_fatal_error(msg.c_str()); + } + } + + if (exitInsertPointAfter == nullptr) { + PostDominanceInfo dom(loop); + exitInsertPointAfter = findNearestCommonPostDominator(consumers, dom); + } + + builder.setInsertionPointAfter(exitInsertPointAfter); + + triton::gpu::createInto(builder, loc, consumerPartitions, + stageClusterExit, aref, token, + builder.getArrayAttr(asyncKinds)); +}; + +Operation *getEarliestUserInBlock(Block *block, ArrayRef uses) { + OpOperand *use = + *llvm::min_element(uses, [block](OpOperand *lhs, OpOperand *rhs) { + auto lhsOwner = block->findAncestorOpInBlock(*lhs->getOwner()); + auto rhsOwner = block->findAncestorOpInBlock(*rhs->getOwner()); + return lhsOwner->isBeforeInBlock(rhsOwner); + }); + return block->findAncestorOpInBlock(*use->getOwner()); +} + +bool insertArefs(OpBuilder &builder, scf::ForOp loop, Block *block, + ProducedValueInfo producedValue) { + // Collect uses of local_alloc(desc_load()) or desc_load() results by each + // partition + DenseMap> resultsPerPartition; + DenseMap> usesPerPartition; + auto processResultUses = [&](Value result) { + for (auto &use : result.getUses()) { + auto user = use.getOwner(); + // if use is outside ttg.ws, it may not have partition ids, skip it + if (!hasPartition(user)) + continue; + auto userPartitions = getPartitionIds(&use); + for (auto id : producedValue.partitions) { + userPartitions.remove(id); + } + for (auto id : userPartitions) { + resultsPerPartition[id].insert(result); + usesPerPartition[id].push_back(&use); + } + } + }; + + processResultUses(producedValue.result); + + if (auto opt = isDescLoadAndAlloc(producedValue.result)) { + // Process the register use as well + auto alloc = opt->first; + processResultUses(alloc.getSrc()); + } + + if (resultsPerPartition.empty()) { + return false; + } + + ArefCreateOp aref; + { + OpBuilder::InsertionGuard g(builder); + auto wsLoop = getOuterWSLoop(loop); + builder.setInsertionPoint(wsLoop); + aref = createAref(builder, producedValue); + } + + auto staleOps = createArefPut(builder, aref, producedValue); + + for (auto [consumerPartition, results] : resultsPerPartition) { + OpBuilder::InsertionGuard g(builder); + auto earliestUser = + getEarliestUserInBlock(block, usesPerPartition[consumerPartition]); + builder.setInsertionPoint(earliestUser); + createArefGet(builder, loop, aref, results, consumerPartition, + usesPerPartition[consumerPartition]); + } + + for (auto op : staleOps) { + op->erase(); + } + + return true; +} + +} // namespace + +class NVWSArefInsertion + : public triton::impl::NVWSInsertArefBase { +public: + void runOnFunction(triton::FuncOp func) { + SmallVector loops; + func.walk([&](scf::ForOp loop) { + auto func = loop->getParentOfType(); + if (loop->hasAttr(triton::kWarpSpecializeAttrName) && hasPartition(loop)) + loops.push_back(loop); + }); + + for (scf::ForOp loop : loops) { + loop.walk([&](scf::ForOp forOp) { + // Communicate tensor arguments in iter_args from producer partition in + // current iteration to consumer partition in previous iteration or + // initial value + for (auto arg : forOp.getRegionIterArgs()) { + if (isa(arg.getType())) { + auto producerPartition = + getPartitionOutputs(forOp)[arg.getArgNumber() - 1]; + ProducedValueInfo producedValue{producerPartition, arg}; + OpBuilder builder(forOp); + builder.setInsertionPointToStart(forOp.getBody()); + insertArefs(builder, loop, forOp.getBody(), producedValue); + } + } + }); + + // To handle cases where desc_load result in registers is used as is in + // addition to being consumed by local_alloc op, we process + // local_alloc(desc_load()) first, followed by remaining register uses of + // desc_load results. + SmallVector memoryOps; + loop.walk([&](Operation *op) { + if (op->getNumResults() > 0 && + (isDescLoadAndAlloc(op->getResult(0)) || + isa(op))) { + memoryOps.push_back(op); + } + }); + + for (auto op : memoryOps) { + auto producedValues = getProducedValues(op, loop.getBody()); + for (auto producedValue : producedValues) { + OpBuilder builder(op); + insertArefs(builder, loop, op->getBlock(), producedValue); + } + } + + // handle non-tmem ops in the loop, including uses of desc_load results. + loop.walk([&](Operation *op) { + if (op == loop || isa(op)) { + return WalkResult::advance(); + } + auto producedValues = getProducedValues(op, loop.getBody()); + for (auto producedValue : producedValues) { + OpBuilder builder(op); + builder.setInsertionPointAfter(op); + insertArefs(builder, loop, op->getBlock(), producedValue); + } + return WalkResult::advance(); + }); + } + } + + void runOnOperation() override { + getOperation().walk([&](triton::FuncOp func) { runOnFunction(func); }); + } +}; + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp new file mode 100644 index 0000000000..64053c0ac1 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp @@ -0,0 +1,930 @@ +#include "Utilities.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_NVWSINSERTTMEMAREF +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "nvws-insert-tmem-aref" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +using namespace mlir; +using namespace triton::gpu; +using namespace triton::nvidia_gpu; +using namespace triton::nvws; + +int getWsTag(Operation *op) { + while (op && !hasWarpSpecializeTag(op)) { + op = op->getParentOfType(); + } + assert(op); + return *getWarpSpecializeTag(op); +} + +using PartitionId = std::pair; +std::optional getPartitionId(Operation *op, int pos = 0) { + if (!hasPartition(op)) + return std::nullopt; + auto partitionIds = getPartitionIds(op); + if (op->getNumRegions() > 0) { + partitionIds = getPartitionOutputs(op)[pos]; + } + assert(partitionIds.size() == 1); + return std::make_pair(*partitionIds.begin(), getWsTag(op)); +} + +struct TmemAccessDag { + struct Node { + // For now we assume there is only one use of generated async tmem token + std::unique_ptr user; + SmallVector> subDags; + Node(Operation *op, OpOperand *tokOperand, + std::optional partitionId, Node *parent) + : op(op), tokOperand(tokOperand), partitionId(partitionId), + parent(parent), parentDag(nullptr) {} + + // ------------------------------------------------------------------------ + + Operation *op; + OpOperand *tokOperand; + Node *parent; + Node *parentDag; + std::optional tokPos; + std::optional partitionId; + }; + + TmemAccessDag(std::unique_ptr dag) : dag(std::move(dag)) {} + + Node *getRootNode() { return dag.get(); } + TMEMAllocOp getAllocOp() { return cast(dag->op); } + + Value addIfOp(Value tok, Node *node) { + SmallVector uses; + for (auto &use : tok.getUses()) + uses.push_back(&use); + assert(uses.size() == 2 && "expecting two uses of a token"); + auto useThen = uses[0]; + auto useElse = uses[1]; + + auto ifOp = cast(useThen->getOwner()->getParentOp()); + node->user.reset(new Node(ifOp, nullptr, {}, node)); + auto ifOpNode = node->user.get(); + + if (ifOp.thenBlock() != useThen->getOwner()->getBlock()) + std::swap(useThen, useElse); + assert(ifOp.thenBlock() == useThen->getOwner()->getBlock()); + assert(ifOp.elseBlock() == useElse->getOwner()->getBlock()); + + // Create access DAGs for then/else blocks. + auto thenDag = + std::make_unique(nullptr, nullptr, std::nullopt, nullptr); + auto elseDag = + std::make_unique(nullptr, nullptr, std::nullopt, nullptr); + auto thenTok = addOp(*useThen, thenDag.get()); + auto elseTok = addOp(*useElse, elseDag.get()); + + auto tokPos = + *findValuePosInRange(ifOp.thenYield()->getOperands(), thenTok); + ifOpNode->partitionId = getPartitionId(ifOp, tokPos); + + // find final node in then-branch and assign yieldOp as its user + // XXX: improve representation later, but for now the user's parentDag + // points to the first op in the branch, because we will need to get + // stageCluser information later in aref insertion as ifOps don't carry + // partition assignment to their results like nvws-branch + Node *finalThenNode = thenDag.get(); + while (finalThenNode->user) + finalThenNode = finalThenNode->user.get(); + auto thenYieldOp = ifOp.thenYield(); + finalThenNode->user = + std::make_unique(thenYieldOp, &thenYieldOp->getOpOperand(tokPos), + ifOpNode->partitionId, finalThenNode); + finalThenNode->user->parentDag = thenDag->user.get(); + + // do the same with else-branch + Node *finalElseNode = elseDag.get(); + while (finalElseNode->user) + finalElseNode = finalElseNode->user.get(); + auto elseYieldOp = ifOp.elseYield(); + finalElseNode->user = + std::make_unique(elseYieldOp, &elseYieldOp->getOpOperand(tokPos), + ifOpNode->partitionId, finalElseNode); + finalElseNode->user->parentDag = elseDag->user.get(); + + // the parent of the first op in the branch is null, but parent dag points + // to original ifOp + thenDag->user->parent = nullptr; + elseDag->user->parent = nullptr; + thenDag->user->parentDag = ifOpNode; + elseDag->user->parentDag = ifOpNode; + + ifOpNode->subDags.push_back(std::move(thenDag->user)); + ifOpNode->subDags.push_back(std::move(elseDag->user)); + + ifOpNode->tokPos = tokPos; + + auto newTok = ifOp.getResult(tokPos); + assert(newTok.hasOneUse()); + return addOp(*newTok.getUses().begin(), ifOpNode); + } + + Value addForOp(OpOperand &tokOperand, Node *forOpNode) { + auto forOp = cast(tokOperand.getOwner()); + auto tokPos = tokOperand.getOperandNumber() - 3; + auto tokDefOp = forOp.getYieldedValues()[tokPos].getDefiningOp(); + assert(tokDefOp && "expecting a token definition op"); + + // Create access node for the for-loop body. The first op is nullptr, + // but it has partitionIdx, indicating which partition owns the Tmem when + // entering the region + auto subDag = + std::make_unique(nullptr, nullptr, std::nullopt, nullptr); + auto tokArg = forOp.getRegionIterArg(tokPos); + assert(tokArg.hasOneUse()); + addOp(*tokArg.getUses().begin(), subDag.get()); + forOpNode->partitionId = getPartitionId(forOp, tokPos); + + // finalNode keep track of partition ownership transfer ownership when + // before exiting the loop-body or re-entering loop body + // same as in IfOp then/else branches + Node *finalNode = subDag->user.get(); + while (finalNode->user) + finalNode = finalNode->user.get(); + auto yieldOp = forOp.getBody()->getTerminator(); + finalNode->user = + std::make_unique(yieldOp, &yieldOp->getOpOperand(tokPos), + forOpNode->partitionId, finalNode); + finalNode->user->parentDag = subDag->user.get(); + forOpNode->tokPos = tokPos; + + // subDag->user->parentDag = subDag->user.get(); + subDag->user->parent = nullptr; + subDag->user->parentDag = forOpNode; + + forOpNode->subDags.push_back(std::move(subDag->user)); + return forOp.getResult(tokPos); + } + + Value addOp(OpOperand &tokOperand, Node *node) { + if (isa(tokOperand.getOwner())) + return tokOperand.get(); // return token back to the caller + + auto op = tokOperand.getOwner(); + std::optional partitionId; + // tmem owning partition for if & for ops are inferred from their regions + if (op->getNumRegions() == 0) + partitionId = getPartitionId(op); + node->user.reset(new Node(op, &tokOperand, partitionId, node)); + auto newNode = node->user.get(); + Value newTok; + + if (auto tmemLoad = dyn_cast(op)) { + newTok = tmemLoad.getToken(); + } else if (auto tmemStore = dyn_cast(op)) { + newTok = tmemStore.getToken(); + } else if (auto mmav5 = dyn_cast(op)) { + newTok = mmav5.getToken(); + } else if (auto forOp = dyn_cast(op)) { + newTok = addForOp(tokOperand, newNode); + } else { + llvm_unreachable("unsupported user"); + } + + if (newTok.use_empty()) + return newTok; + + if (newTok.hasOneUse()) { + auto &use = *newTok.getUses().begin(); + return addOp(use, newNode); + } + + // Multiple uses of token are expected only in IfOp: one in then and one in + // else branches. + return addIfOp(newTok, newNode); + } + + static TmemAccessDag build(TMEMAllocOp allocOp) { + std::optional partitionId; + if (allocOp.getSrc()) { + partitionId = getPartitionId(allocOp); + } + TmemAccessDag accessDag( + std::make_unique(allocOp, nullptr, partitionId, nullptr)); + + if (allocOp.getSrc() && !allocOp.getToken()) { + // Handle tmem_alloc with src operand specially. When a src operand is + // present, no async tokens are generated, we can't traverse IR, + // and we directly add the single user operation to the access DAG. + assert(allocOp->hasOneUse()); + auto user = *allocOp->getUsers().begin(); + accessDag.getRootNode()->user.reset(new Node{ + user, nullptr, getPartitionId(user), accessDag.getRootNode()}); + } else { + auto tok = allocOp.getToken(); + assert(tok && tok.hasOneUse()); + auto &tokUse = *tok.getUses().begin(); + accessDag.addOp(tokUse, accessDag.getRootNode()); + } + return accessDag; + } + + void collectPartitions( + Node *node, bool &hasRootPartition, + SmallVector> &partitions) { + if (node->partitionId) { + partitions.push_back(std::make_pair(*node->partitionId, node->op)); + } else { + // root partition is considered a real owner only if there are already + // other partitions owning tmem + hasRootPartition = !partitions.empty(); + } + for (auto &subDag : node->subDags) { + if (subDag) { + collectPartitions(subDag.get(), hasRootPartition, partitions); + } + } + if (node->user) { + collectPartitions(node->user.get(), hasRootPartition, partitions); + } + }; + + std::pair>> + collectPartitionsVec() { + SmallVector> partitions; + bool hasRootPartition = false; + auto node = getRootNode(); + auto allocOp = getAllocOp(); + if (allocOp.getSrc() && node->partitionId) + partitions.push_back(std::make_pair(*node->partitionId, node->op)); + collectPartitions(getRootNode()->user.get(), hasRootPartition, partitions); + return {hasRootPartition, partitions}; + } + + std::pair> collectPartitionsSet() { + auto [hasRootPartition, partitions] = collectPartitionsVec(); + std::set partitionSet; + for (auto [partition, _] : partitions) { + partitionSet.insert(partition); + } + return {hasRootPartition, partitionSet}; + } + + void printNode(Node *node, int indent, llvm::raw_ostream &os) { + if (!node) + return; + for (int i = 0; i < indent; i++) { + os << " "; + } + std::set partitions; + os << "|- [" << node->op << "]"; + bool hasRootPartition = false; + if (node->partitionId) + partitions.insert(*node->partitionId); + else + hasRootPartition = true; + if (node->op) { + os << node->op->getName().getStringRef() << " "; + if (auto tmemAlloc = dyn_cast(node->op)) { + if (tmemAlloc.getSrc()) { + os << " %src "; + } else { + std::tie(hasRootPartition, partitions) = collectPartitionsSet(); + } + } + os << " "; + } + os << "[" << (hasRootPartition ? "root" : ""); + for (auto partition : partitions) { + auto [id, tag] = partition; + os << " @" << tag << "." << id << " "; + } + os << "]"; + os << " prev[" << (node->parent ? node->parent->op : nullptr) << "]"; + os << "\n"; + for (auto &subDag : node->subDags) { + for (int i = 0; i < indent + 4; i++) + os << " "; + os << "|- subDag\n"; + if (subDag) + printNode(subDag.get(), indent + 8, os); + } + if (node->user) { + printNode(node->user.get(), indent, os); + } + }; + void printDag(llvm::raw_ostream &os) { + os << "TMEMDAG\n"; + printNode(dag.get(), 2, os); + os << "\n"; + } + + // -------------------------------------------------------------------------- + + std::unique_ptr dag; +}; + +void assignStage(OpBuilder &b, Operation *op, StageCluster stageCluster) { + if (stageCluster) { + op->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(stageCluster->first)); + op->setAttr(kLoopClusterAttrName, + b.getI32IntegerAttr(stageCluster->second)); + } +} + +template +OpT createInto( + OpBuilder &b, Location loc, + std::pair, StageCluster> partitionIdStageCluster, + Args &&...args) { + std::optional> partitionIds = SetVector(); + std::optional wsTag; + if (partitionIdStageCluster.first) { + auto [id, tag] = *partitionIdStageCluster.first; + wsTag = tag; + partitionIds->insert(id); + } else { + partitionIds = std::nullopt; + } + auto op = triton::gpu::createInto(b, loc, partitionIds, + partitionIdStageCluster.second, + std::forward(args)...); + if (wsTag) { + auto forOp = op->template getParentOfType(); + while (forOp && !hasWarpSpecializeTag(forOp)) { + forOp = forOp->template getParentOfType(); + } + // only set wsTag if op is outside tt.ws loop + if (!forOp) { + setWarpSpecializeTag(op, *wsTag); + } + } + return op; +} + +struct TMEMAref { + enum Kind { PUT, GET }; + + TMEMAref(Value aref, Value origBuffer, Value replToken) + : aref(aref), origBuffer(origBuffer), replToken(replToken), kind(PUT) {} + + void acquire(OpBuilder &b, Location loc, + std::pair, StageCluster> + paritionIdStageCluster) { + auto arefBufType = + cast(aref.getDefiningOp()->getOperand(0).getType()); + Type dataBufType = getArefViewBufferType(arefBufType); + SmallVector buffers{dataBufType}; + SmallVector tokens{b.getType()}; + if (kind == PUT) { + auto op = + createInto(b, loc, paritionIdStageCluster, aref, + buffers, b.getType()); + token = op.getToken(); + } else { + auto op = + createInto(b, loc, paritionIdStageCluster, aref, + buffers, b.getType()); + token = op.getToken(); + } + partitionId = paritionIdStageCluster.first; + if (partitionId) + stageClusters[*partitionId] = paritionIdStageCluster.second; + buffer = {}; + } + void release(OpBuilder &b, Location loc) { + assert(asyncOp[partitionId]); + StageCluster stageCluster; + if (partitionId) + stageCluster = stageClusters[*partitionId]; + if (kind == PUT) { + createInto( + b, loc, {partitionId, stageCluster}, aref, token, + b.getArrayAttr(SmallVector{ + AsyncOpAttr::get(b.getContext(), *asyncOp[partitionId])})); + kind = GET; + } else { + createInto( + b, loc, {partitionId, stageCluster}, aref, token, + b.getArrayAttr(SmallVector{ + AsyncOpAttr::get(b.getContext(), *asyncOp[partitionId])})); + kind = PUT; + } + } + Value getBuffer(OpBuilder &b, std::optional partitionId, + Operation *op) { + if (!buffer) { + auto stageCluster = getStageCluster(op); + auto arefBufType = + cast(aref.getDefiningOp()->getOperand(0).getType()); + Type dataBufType = getArefViewBufferType(arefBufType); + SmallVector buffers{dataBufType}; + auto bufferOp = createInto( + b, op->getLoc(), {partitionId, stageCluster}, aref, buffers, token); + + buffer = bufferOp.getBuffers()[0]; + } + return buffer; + } + + // -------------------------------------------------------------------------- + + Value origBuffer; + Value aref; + Value replToken; + + Value buffer; + Value token; + Kind kind; + std::optional partitionId; + llvm::MapVector, std::optional> asyncOp; + DenseMap stageClusters; +}; + +TmemAccessDag::Node * +insertTmemArefImpl(TmemAccessDag::Node *node, + std::optional curPartitionId, TMEMAref &state) { + // When entering a warp-specialized loop, curPartitionId is std::nullopt. + // We skip ownership changes here since there's an implicit synchronization + // barrier when entering the ws-loop that handles the transition safely. + if (curPartitionId && node->partitionId != curPartitionId) { + OpBuilder b(node->op); + Operation *prevOp = nullptr; + if (node->parent) { + // release right after the last op which owns the tmem + prevOp = node->parent->op; + b.setInsertionPointAfter(prevOp); + } else { + // if we are inside if-stmt or for-stmt subdag and need to change + // ownerhip, release at the top of the block + // the parentDag op would be if-stmt or for-stmt + prevOp = node->parentDag->op; + b.setInsertionPointToStart(node->op->getBlock()); + } + state.release(b, prevOp->getLoc()); + + // acquire right before op that acquires ownership of tmem + auto curOp = node->op; + auto partitionId = node->partitionId; + b.setInsertionPoint(curOp); + + if (isa(curOp)) { + // in yieldOp we overload parentDag as the first op in the current subDag + // so we use its stageCluster to insert acquire + curOp = node->parentDag->op; + } + auto stageCluster = getStageCluster(curOp); + // if stage-cluster is empty, use the stage-cluster used from the last op + // that acquired ownership of tmem in a partition + if (!stageCluster && partitionId) + stageCluster = state.stageClusters[*partitionId]; + state.acquire(b, curOp->getLoc(), {partitionId, stageCluster}); + } + + for (auto &subDag : node->subDags) { + auto subdagState = state; + if (auto forOp = dyn_cast(node->op)) { + // forOp may have token operand, if so, we need to update the token and + // and reset buffer + if (node->tokOperand) { + subdagState.token = + forOp.getRegionIterArg(node->tokOperand->getOperandNumber() - 3); + subdagState.buffer = {}; + } + } + insertTmemArefImpl(subDag.get(), node->partitionId, subdagState); + + // subDag may change asyncOp value, update it after inserting arefs + state.asyncOp = subdagState.asyncOp; + // store subdag state partitoinId + state.partitionId = subdagState.partitionId; + } + + if (isa(node->op)) { + state.asyncOp[node->partitionId] = AsyncOp::TC5MMA; + } else if (isa(node->op)) { + state.asyncOp[node->partitionId] = AsyncOp::NONE; + } + + OpBuilder b(node->op); + if (auto tmemLoadOp = dyn_cast(node->op)) { + if (auto id = node->partitionId) + state.stageClusters[*id] = getStageCluster(node->op); + tmemLoadOp.getSrcMutable().assign( + state.getBuffer(b, node->partitionId, node->op)); + tmemLoadOp.getDepMutable().clear(); + tmemLoadOp.getToken().replaceAllUsesWith(state.replToken); + } else if (auto tmemStoreOp = dyn_cast(node->op)) { + if (auto id = node->partitionId) + state.stageClusters[*id] = getStageCluster(node->op); + tmemStoreOp.getDstMutable().assign( + state.getBuffer(b, node->partitionId, node->op)); + tmemStoreOp.getDepMutable().clear(); + tmemStoreOp.getToken().replaceAllUsesWith(state.replToken); + } else if (auto mmaOp = dyn_cast(node->op)) { + if (auto id = node->partitionId) + state.stageClusters[*id] = getStageCluster(node->op); + if (mmaOp.getAccumulator() == state.origBuffer) { + mmaOp.getAccDepMutable().clear(); + mmaOp.getToken().replaceAllUsesWith(state.replToken); + } + for (auto &opnd : mmaOp->getOpOperands()) { + if (opnd.get() == state.origBuffer) + opnd.set(state.getBuffer(b, node->partitionId, node->op)); + } + } else if (auto yieldOp = dyn_cast(node->op)) { + yieldOp.setOperand(node->tokOperand->getOperandNumber(), state.token); + } else if (isa(node->op)) { + if (node->tokPos) { + // forOp/if may return token, if so, update state token, and reset buffer + if (isa(node->op)) + node->op->setOperand(node->tokOperand->getOperandNumber(), state.token); + state.token = node->op->getResult(*node->tokPos); + state.buffer = {}; + } + } else { + llvm_unreachable("unsupported tmem op"); + } + + if (node->user) + return insertTmemArefImpl(node->user.get(), node->partitionId, state); + return node; +} + +bool canDoubleBufferAcc(MMAv5OpInterface mmaOp, int numTmemBlocks) { + auto tmemDesc = mmaOp.getAccumulator().getType(); + auto blockM = tmemDesc.getShape()[0]; + auto blockN = tmemDesc.getShape()[1]; + constexpr int numTMEMColumns = 512; + constexpr int numTMEMRows = 128; + if (numTmemBlocks + (blockM * blockN * 2) > numTMEMRows * numTMEMColumns) { + return false; + } + if (isa(mmaOp) && blockN == 256) { + return false; + } + return true; +}; + +bool hasProducerConsumerPartitioning(TmemAccessDag &accessDag) { + // TMEM partitioning follows a producer-consumer pattern if it has this + // structure: + // + // |alloc + // |-- ops + // loop (tt.ws) + // |---- producer @A + // |---- consumer @B + // |---- producer @A + // + // We have root operations, then enter a warp-specialized loop where: + // - First, partition A owns TMEM and performs producer operations + // - Then, partition B owns TMEM and performs consumer operations + // - Possibly, partition A owns TMEM and performs producer operations + // - Loop repeats with partition A yielding + // + // Here is an example where the producer-consumer pattern is not present: + // |alloc + // |store + // |for (tt.ws) + // | |store @A + // | |for + // | | mma @B + // | |load @A + // The partitions @A & @B are both producers. + // + // Compare to the following, where we change ownership of TMEM where partition + // B is the producer and partition A is the consumer: + // |alloc + // |store + // |for (tt.ws) + // | |store @B + // | |for + // | | mma @B + // | |load @A + // Here, we may double-buffer the accumulator. + // + // This is a necessary (but not sufficient) condition for enabling TMEM + // multi-buffering with arefs. Additional validation will verify sufficient + // conditions for multi-buffering. + + auto [hasRootPartition, partitions] = accessDag.collectPartitionsVec(); + bool expectProducer = true; + int changeGroup = 0; + bool valid = true; + + // Count partition transitions: producer-consumer pattern has exactly two + // transitions (A->B followed by B->A), where 'A' is producer and 'B' is + // consumer. More than two transitions (e.g., A-A-B-B-A-A-B-B-A-A) indicate a + // more complex pattern that doesn't fit the producer-consumer model. + for (size_t i = 0; i < partitions.size() - 1; ++i) { + auto op = partitions[i].second; + if (isa(op)) { + valid = valid && (expectProducer ? isa(op) + : isa(op)); + } + if (partitions[i].first != partitions[i + 1].first) { + expectProducer = !expectProducer; + ++changeGroup; + } + } + valid = valid && changeGroup == 2; + + return valid; +} + +int insertTmemAref(TmemAccessDag &accessDag, int numTmemBlocks) { + auto rootNode = accessDag.getRootNode(); + auto allocOp = cast(rootNode->op); + + auto isMultiStaged = hasProducerConsumerPartitioning(accessDag); + int numTmemBlock = 0; + if (isMultiStaged) { + for (auto user : allocOp.getResult().getUsers()) { + if (auto mmaOp = dyn_cast(user)) { + if (auto loop = dyn_cast(user->getParentOp())) { + auto wsLoop = getOuterWSLoop(loop); + // Determine if the MMA accumulator can be multibuffered. + bool accIsMultiBuffered = + // MMAs in subsequent iterations can be overlapped. + !nvidia_gpu::hasAccReadModifyWrite(mmaOp, loop) && + // The accumulator is reset at some point, thus allowing + // multibuffering. + isAccMultibufferingPossible(mmaOp, loop) && + // The user didn't disable it with a flag. + !getDisallowAccMultiBuffer(wsLoop) && + canDoubleBufferAcc(mmaOp, numTmemBlocks); + isMultiStaged = isMultiStaged && accIsMultiBuffered; + } + } + } + } + auto numStages = 1 + isMultiStaged; + + // update numTmemBlocks for the number of TMEM blocks used by the aref buffer + auto allocShape = allocOp.getType().getShape(); + numTmemBlocks += allocShape[0] * allocShape[1] * numStages; + auto arefBufType = + getArefMultiBufferedType(allocOp.getResult().getType(), numStages); + OpBuilder b(allocOp); + + // alloc can be inside ws-loop, we need to find the entry point for ws-loop + auto outerWsLoop = allocOp->getParentOfType(); + while (outerWsLoop && !outerWsLoop->hasAttr(triton::kWarpSpecializeAttrName)) + outerWsLoop = outerWsLoop->getParentOfType(); + if (outerWsLoop) + b.setInsertionPoint(outerWsLoop); + auto arefAlloc = + cast(createAlloc(b, allocOp.getLoc(), arefBufType, Value())); + auto arefOp = createArefCreateOp(b, {arefBufType}, {arefAlloc->getResult(0)}, + allocOp.getLoc()); + + auto stageCluster = getStageCluster(allocOp); + auto partitionId = accessDag.getRootNode()->partitionId; + if (!allocOp.getSrc() && outerWsLoop) { + // if tmem_alloc inside ws-loop, the first owner is that of the first user + partitionId = accessDag.getRootNode()->user->partitionId; + } + + TMEMAref state( + arefOp, allocOp.getResult(), + ub::PoisonOp::create(b, allocOp.getLoc(), b.getType())); + b.setInsertionPoint(allocOp); + state.acquire(b, allocOp.getLoc(), {partitionId, stageCluster}); + + // If initial acquire is in root partition (no partition annotation), the + // release must be in the partition of the first owner that has a partition + // annotation. Find that partition and update state.partitionId accordingly. + if (!state.partitionId) { + auto node = rootNode->user.get(); + do { + state.partitionId = node->partitionId; + node = node->user.get(); + } while (node && !state.partitionId); + } + + if (auto src = allocOp.getSrc()) { + auto buffer = state.getBuffer(b, partitionId, allocOp); + state.asyncOp[partitionId] = AsyncOp::NONE; + auto vTrue = createInto( + b, allocOp.getLoc(), {partitionId, stageCluster}, true, 1); + createInto(b, allocOp.getLoc(), {partitionId, stageCluster}, + Type(), buffer, Value(), src, vTrue); + } else { + // allocOp w/o src, assume the ownership of tmem belongs to first user + // partitionId = accessDag.getRootNode()->user->partitionId; + } + + auto node = insertTmemArefImpl(rootNode->user.get(), partitionId, state); + + if (outerWsLoop) { + // aref is only used inside ws-loop, so we use the last op to insert + // matching exit + b.setInsertionPointAfter(node->op); + } else { + // aref is used outside ws-loop, find the last point in the same block as + // create op to have matching exit + auto op1 = arefOp->getBlock()->findAncestorOpInBlock(*node->op); + if (auto id = node->partitionId) + state.stageClusters[*id] = {}; + b.setInsertionPointAfter(op1); + } + state.release(b, node->op->getLoc()); + + if (state.kind == TMEMAref::GET) { + // When the state ends up in a GET operation, we need to acquire and release + // the corresponding partition to prevent deadlocks. This is necessary + // because if we're inside an outer loop, re-entering the loop without + // posting a matching GET operation for the PUT would cause the dead-lock. + auto [hasRootPartition, partitions] = accessDag.collectPartitionsSet(); + std::optional otherPartitionId; + // since we only have two partition, we just pick the other partition for + // get + for (auto partitionId : partitions) { + if (partitionId != state.partitionId) { + otherPartitionId = partitionId; + break; + } + } + state.acquire(b, node->op->getLoc(), {otherPartitionId, {}}); + state.release(b, node->op->getLoc()); + } + + return numTmemBlocks; +} + +void workaroundForLoopScheduler(triton::FuncOp funcOp) { + SmallVector ifs; + funcOp.walk([&](scf::IfOp ifOp) { + auto firstOp = &*ifOp.thenBlock()->begin(); + auto lastOp = ifOp.thenBlock()->getTerminator()->getPrevNode(); + if (isa(firstOp) && isa(lastOp)) { + ifs.push_back(ifOp); + } + }); + + // Transform if-statements that contain aref put.exit/put.enter pairs to work + // around loop scheduler limitations. The transformation splits a single if-op + // with token-producing operations into three separate if-ops to ensure proper + // scheduling and token handling. + // + // Original pattern: + // %results, %token, %more = scf.if %condition { + // aref.put.exit // Release tensor memory + // // User computation + // %new_token = aref.put.enter // Acquire tensor memory + // scf.yield %values, %new_token, %other_values + // } else { + // scf.yield %alt_values, %old_token, %alt_other_values + // } + // ... use %token + // + // Transformed pattern: + // scf.if %condition { + // aref.put.exit // Separate exit operation + // } { .. loop.stage = 1, ttg.partition = {1}, ttg.partition.outputs = [] } + // %results, %poison_tok, %more = scf.if %condition { + // // Main computation without token ops + // scf.yield %values, %poison_tok, %other_values + // } else { + // scf.yield %alt_values, %poison_tok, %alt_other_values + // } {.. ttg.partition = {0}, ttg.partition.outputs = [{0}, {0}, {0}, ..]} + // %token = scf.if %condition { + // %new_token = aref.put.enter // Separate enter operation + // scf.yield %new_token + // } else { + // scf.yield %old_token + // } { .. loop.stage = 1, ttg.partition = {1}, ttg.partition.outputs = + // [{1}]} + // ... use %token + + for (auto ifOp : ifs) { + ImplicitLocOpBuilder b(ifOp.getLoc(), ifOp); + + // move putExitOp + b.setInsertionPoint(ifOp); + auto exitIf = + scf::IfOp::create(b, SmallVector{}, ifOp.getCondition(), false); + auto putExitOp = cast(*ifOp.thenBlock()->begin()); + putExitOp->moveBefore(exitIf.thenBlock(), exitIf.thenBlock()->begin()); + + // move putEnterOp + b.setInsertionPointAfter(ifOp); + auto enterIf = + scf::IfOp::create(b, SmallVector{b.getType()}, + ifOp.getCondition(), true); + auto putEnterOp = + cast(ifOp.thenBlock()->getTerminator()->getPrevNode()); + putEnterOp->moveBefore(enterIf.thenBlock(), enterIf.thenBlock()->begin()); + + // replace token uses + auto tok = putEnterOp.getToken(); + auto pos = *findValuePosInRange(ifOp.thenYield()->getOperands(), tok); + ifOp.getResult(pos).replaceAllUsesWith(enterIf.getResult(0)); + + // insert yield-ops inside enterIf + b.setInsertionPointToEnd(enterIf.thenBlock()); + scf::YieldOp::create(b, tok); + b.setInsertionPointToEnd(enterIf.elseBlock()); + scf::YieldOp::create(b, ifOp.elseYield().getOperand(pos)); + + // invalid tokens in main ifOp + b.setInsertionPoint(ifOp); + auto poisonToken = ub::PoisonOp::create(b, b.getType()); + ifOp.thenYield().setOperand(pos, poisonToken); + ifOp.elseYield().setOperand(pos, poisonToken); + + // patch loop.stage=1 + enterIf->setAttrs(ifOp->getAttrs()); + exitIf->setAttrs(ifOp->getAttrs()); + assignStage(b, enterIf, getStageCluster(putEnterOp)); + assignStage(b, exitIf, getStageCluster(putExitOp)); + + SetVector enterExitIds, middleIds; + enterExitIds.insert(1); + middleIds.insert(0); + setPartition(enterIf, enterExitIds); + setPartition(exitIf, enterExitIds); + setPartition(ifOp, middleIds); + + SetVector p0array, p1array; + p0array.insert(0); + p1array.insert(1); + setPartitionOutputs(exitIf, {}); + setPartitionOutputs(enterIf, {p1array}); + SmallVector> outputs(ifOp->getNumResults(), p0array); + setPartitionOutputs(ifOp, outputs); + } +} + +LogicalResult runOnFunction(triton::FuncOp funcOp) { + // Skip this function if there is no warp specialized loop. + auto walkResult = funcOp.walk([&](scf::ForOp forOp) { + if (forOp->hasAttr(kWarpSpecializeAttrName)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (!walkResult.wasInterrupted()) + return success(); + + SmallVector tmemDags; + funcOp.walk([&](TMEMAllocOp allocOp) { + tmemDags.push_back(TmemAccessDag::build(allocOp)); + }); + + int numTmemBlocks = 0; + for (auto &accessDag : tmemDags) { + LLVM_DEBUG({ accessDag.printDag(llvm::dbgs()); }); + auto [hasRootPartition, partitions] = accessDag.collectPartitionsSet(); + assert(partitions.size() <= 2 && "expecting at most 2 partitions"); + auto totalOwners = hasRootPartition + partitions.size(); + if (totalOwners > 1) { + numTmemBlocks = insertTmemAref(accessDag, numTmemBlocks); + } + } + + workaroundForLoopScheduler(funcOp); + + return success(); +} + +} // namespace + +class NVWSTmemArefInsertion + : public triton::impl::NVWSInsertTmemArefBase { +public: + void runOnOperation() override { + getOperation().walk([&](triton::FuncOp funcOp) { + if (failed(runOnFunction(funcOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + } +}; + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/LowerAref.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/LowerAref.cpp new file mode 100644 index 0000000000..c3f50caf5e --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/LowerAref.cpp @@ -0,0 +1,961 @@ +/* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "Utilities.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using namespace mlir::triton::nvidia_gpu; +using namespace mlir::triton::nvws; + +#define DEBUG_TYPE "nvws-lower-aref" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_NVWSLOWERAREF +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" + +namespace { + +// ---------------------------------------------------------------------------- + +struct PartitionWsTagIds { + std::optional wsTag; + SetVector partitionIds; +}; +std::optional getPartitionWsTagIds(Operation *op) { + std::optional partitionWsTagIds; + if (hasPartition(op)) { + partitionWsTagIds = + PartitionWsTagIds{std::nullopt, triton::gpu::getPartitionIds(op)}; + if (auto wsTag = getWarpSpecializeTag(op)) { + partitionWsTagIds->wsTag = *wsTag; + } + } + return partitionWsTagIds; +} + +using PartitionSet = SetVector; +void assignStageCluster(Operation *op, + std::optional partitionWsTagIds, + StageCluster stageCluster, OpBuilder &builder) { + if (partitionWsTagIds) { + setPartition(op, partitionWsTagIds->partitionIds); + if (auto wsTag = partitionWsTagIds->wsTag) { + setWarpSpecializeTag(op, *wsTag); + } + setStageCluster(builder, op, stageCluster); + } +} + +bool isOperandPipelineable(Value v, scf::ForOp forOp) { + auto isPipelineable = [](Operation *op) { + return isa(op); + }; + + Operation *foundDef = nullptr; + return triton::nvidia_gpu::isOperandPipelineableBase(v, forOp, foundDef, + isPipelineable); +} + +void setIsAsync(triton::nvidia_gpu::MMAv5OpInterface mmaOp, + unsigned defaultNumStages) { + bool isAsync = true; + auto forOp = mmaOp->getParentOfType(); + if (!forOp) + return; + + unsigned numStages = getNumStagesOrDefault(forOp, defaultNumStages); + if (numStages <= 1) + return; + + if (auto scaledOp = dyn_cast( + mmaOp.getOperation())) { + if (!triton::nvidia_gpu::areScalesPipelineable(scaledOp, forOp)) { + isAsync = false; + } + if (!isOperandPipelineable(scaledOp.getAScale(), forOp) || + !isOperandPipelineable(scaledOp.getBScale(), forOp)) { + isAsync = false; + } + } + mmaOp.setIsAsync(isAsync); +} + +struct ArefValue { + Value emptyMbars; + Value fullMbars; + int depth; + SmallVector buffers; +}; + +Value getEmptyBarrier(PatternRewriter &rewriter, Location loc, ArefValue aref, + Value stage, + std::optional partitionWsTagIds, + StageCluster stageCluster) { + auto barrier = createSingleBufferView(rewriter, aref.emptyMbars, stage); + assignStageCluster(barrier.getDefiningOp(), partitionWsTagIds, stageCluster, + rewriter); + return barrier; +} + +Value getFullBarrier(PatternRewriter &rewriter, Location loc, ArefValue aref, + Value stage, + std::optional partitionWsTagIds, + StageCluster stageCluster) { + auto barrier = createSingleBufferView(rewriter, aref.fullMbars, stage); + assignStageCluster(barrier.getDefiningOp(), partitionWsTagIds, stageCluster, + rewriter); + return barrier; +} + +struct BarrierCount { + int producerPendingCount{0}; + int consumerPendingCount{0}; +}; + +SmallVector castAsyncOpAttrs(ArrayAttr opAttrs) { + SmallVector kinds; + for (auto asyncKind : opAttrs) { + kinds.push_back(cast(asyncKind).getValue()); + } + return kinds; +} + +BarrierCount getArrivalCount(ArefCreateOp op) { + SetVector producerGroups, consumerGroups; + BarrierCount count; + + for (auto user : op->getUsers()) { + if (!hasPartition(user)) + continue; + auto partitionIds = getPartitionIds(user); + + assert(partitionIds.size() == 1); + + if (auto putExitOp = dyn_cast(user)) { + if (producerGroups.count(partitionIds.front())) { + continue; + } + producerGroups.insert(partitionIds.front()); + for (auto kind : castAsyncOpAttrs(putExitOp.getAsyncOps())) { + switch (kind) { + case AsyncOp::TC5MMA: + case AsyncOp::TMALoad: + case AsyncOp::NONE: + count.producerPendingCount += 1; + break; + default: + llvm_unreachable("unsupported producer kind"); + } + } + } else if (auto getExitOp = dyn_cast(user)) { + if (consumerGroups.count(partitionIds.front())) { + continue; + } + consumerGroups.insert(partitionIds.front()); + for (auto kind : castAsyncOpAttrs(getExitOp.getAsyncOps())) { + switch (kind) { + case AsyncOp::TC5MMA: + case AsyncOp::WGMMA: + case AsyncOp::NONE: + count.consumerPendingCount += 1; + break; + default: + llvm_unreachable("unsupported consumer kind"); + } + } + } + } + // If the aref is not used within a warp-specialized loop, the pending counts + // will be equal 0. Set them to 1. + if (count.consumerPendingCount == 0) + count.consumerPendingCount = 1; + if (count.producerPendingCount == 0) + count.producerPendingCount = 1; + + return count; +} + +Value createBarriers(ImplicitLocOpBuilder &b1, ImplicitLocOpBuilder &b2, + int numBarriers, int arrivalCount) { + Value barrierAlloc = createScalarAlloc(b1, b1.getI64Type(), numBarriers); + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(b1, barrierAlloc, i); + InitBarrierOp::create(b1, barrierView, arrivalCount); + } + // Invalidate and deallocate the barriers. + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(b2, barrierAlloc, i); + InvalBarrierOp::create(b2, barrierView); + } + LocalDeallocOp::create(b2, barrierAlloc); + return barrierAlloc; +} + +ArefValue createAndInitMbar(ArefCreateOp op, PatternRewriter &rewriter) { + BarrierCount count = getArrivalCount(op); + + auto arefTy = op.getType(); + auto arefBufTypes = llvm::to_vector(llvm::map_range( + arefTy.getBaseType(), [](Type type) { return cast(type); })); + auto depth = getArefDepth(arefBufTypes[0]); + + SetVector arefUsers; + for (auto user : op->getUsers()) + arefUsers.insert(user); + auto sorted = topologicalSort(arefUsers); + + ImplicitLocOpBuilder b1(op->getLoc(), op), b2(op->getLoc(), op); + auto op1 = op->getBlock()->findAncestorOpInBlock(*sorted.back()); + b2.setInsertionPointAfter(op1); + + auto emptyMbars = createBarriers(b1, b2, depth, count.consumerPendingCount); + auto fullMbars = createBarriers(b1, b2, depth, count.producerPendingCount); + + return ArefValue{emptyMbars, fullMbars, static_cast(depth), + op.getOperands()}; +} + +SmallVector +getSubViews(ArefValue arefVal, Value stage, Location loc, OpBuilder &rewriter, + std::optional partitionWsTagIds, + StageCluster stageCluster) { + SmallVector views; + for (auto buffer : arefVal.buffers) { + auto memDescType = cast(buffer.getType()); + if (isa( + memDescType.getEncoding())) { + // tmem scales encoding doesn't support multi-buffering, use buffer as-is + views.push_back(buffer); + } else { + auto shape = memDescType.getShape(); + SmallVector tensorShape(shape.begin() + 1, shape.end()); + auto memDescTypeNew = MemDescType::get( + tensorShape, memDescType.getElementType(), memDescType.getEncoding(), + memDescType.getMemorySpace(), true); + auto singleBuffer = + MemDescIndexOp::create(rewriter, loc, memDescTypeNew, buffer, stage); + assignStageCluster(singleBuffer, partitionWsTagIds, stageCluster, + rewriter); + views.push_back(singleBuffer); + } + } + + return views; +} + +void createTMALoad(triton::nvws::DescriptorLoadOp op, PatternRewriter &rewriter, + Value barrierAlloc, Value pred) { + auto newLoadOp = triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( + rewriter, op.getLoc(), op.getDesc(), op.getIndices(), barrierAlloc, + op.getResult(), pred); + assignStageCluster(newLoadOp, getPartitionWsTagIds(op), getStageCluster(op), + rewriter); +}; + +void createTMAGather(triton::nvws::DescriptorGatherOp op, + PatternRewriter &rewriter, Value barrierAlloc, + Value pred) { + auto newGatherOp = triton::nvidia_gpu::AsyncTMAGatherOp::create( + rewriter, op.getLoc(), op.getDesc(), op.getXOffsets(), op.getYOffset(), + barrierAlloc, op.getResult(), pred); + assignStageCluster(newGatherOp, getPartitionWsTagIds(op), getStageCluster(op), + rewriter); +} + +void lowerTMALoad(ArefPutEnterOp op, Value fullBarrier, + PatternRewriter &rewriter, ArefValue arefVal) { + auto loc = op.getLoc(); + int txCount = 0; + // for now handle TMA loads in PutEnterOp + SmallVector loadOps; + for (auto buffer : op.getBuffers()) { + for (auto user : buffer.getUsers()) { + if (auto loadOp = + dyn_cast(user)) { + loadOps.push_back(loadOp); + txCount += loadOp.getTxCount(); + } + } + } + assert(loadOps.size() <= op.getBuffers().size()); + if (loadOps.empty()) + return; + + auto pred = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + assignStageCluster(pred, getPartitionWsTagIds(op), getStageCluster(op), + rewriter); + auto expectOp = triton::nvidia_gpu::BarrierExpectOp::create( + rewriter, loc, fullBarrier, txCount, pred); + assignStageCluster(expectOp, getPartitionWsTagIds(op), getStageCluster(op), + rewriter); + + for (auto loadOp : loadOps) { + rewriter.setInsertionPoint(loadOp); + if (auto descLoad = dyn_cast(loadOp)) { + createTMALoad(descLoad, rewriter, fullBarrier, pred); + } else if (auto descGather = + dyn_cast(loadOp)) { + createTMAGather(descGather, rewriter, fullBarrier, pred); + } else { + llvm_unreachable("Unknown load op"); + } + loadOp->erase(); + } +} + +void insertWaitOp(PatternRewriter &rewriter, Operation *op, Value barrier, + Value phase, Value stage) { + auto waitOp = WaitBarrierOp::create(rewriter, op->getLoc(), barrier, phase); + assignStageCluster(waitOp, getPartitionWsTagIds(op), getStageCluster(op), + rewriter); +} + +void rewritePutEnterOp(ArefPutEnterOp op, PatternRewriter &rewriter, + ArefValue arefVal, + const DenseSet &mmav5Ops, + unsigned defaultNumStages) { + auto loc = op.getLoc(); + rewriter.setInsertionPointAfter(op); + + // get empty barrier at a given stage + Value emptyBarrier = + getEmptyBarrier(rewriter, loc, arefVal, op.getStage(), + getPartitionWsTagIds(op), getStageCluster(op)); + + insertWaitOp(rewriter, op, emptyBarrier, op.getPhase(), op.getStage()); + auto views = getSubViews(arefVal, op.getStage(), loc, rewriter, + getPartitionWsTagIds(op), getStageCluster(op)); + assert(views.size() == op.getBuffers().size()); + + // Use the token to find the matching enter / exit pair + // %bufs:n, %token = aref_put.enter %aref[%enter_idx] + // tma_load %bufs[0] + // .. + // tma_load %bufs[n-1] + // aref_put.exit %aref[%exit_idx], %token + ArefPutExitOp exitOp; + for (auto user : op.getToken().getUsers()) { + if (auto op = dyn_cast(user)) { + exitOp = op; + break; + } + } + if (!exitOp) + return; + assert(exitOp.getAref() == op.getAref() && + "Expecting matching Aref on the ArefPutExitOp"); + + auto asyncKinds = castAsyncOpAttrs(exitOp.getAsyncOps()); + auto hasAsyncLoad = [](AsyncOp kind) { + return kind == AsyncOp::TMALoad || kind == AsyncOp::CpAsync; + }; + auto hasTMA = [](AsyncOp kind) { return kind == AsyncOp::TMALoad; }; + + if (llvm::any_of(asyncKinds, hasTMA)) { + Value fullBarrier = + getFullBarrier(rewriter, loc, arefVal, op.getStage(), + getPartitionWsTagIds(op), getStageCluster(op)); + lowerTMALoad(op, fullBarrier, rewriter, arefVal); + } + + if (llvm::any_of(asyncKinds, hasAsyncLoad)) { + for (auto mmav5 : mmav5Ops) { + setIsAsync(mmav5, defaultNumStages); + } + } + + for (auto [oldBuffer, view] : llvm::zip(op.getBuffers(), views)) { + oldBuffer.replaceAllUsesWith(view); + } +} + +static MemDescType getAsMutable(MemDescType type) { + return MemDescType::get(type.getShape(), type.getElementType(), + type.getEncoding(), type.getMemorySpace(), + /*mutableMemory=*/true); +} + +static void propagateMutability(Value value) { + for (Operation *user : value.getUsers()) { + if (user->hasTrait()) { + user->getResult(0).setType( + getAsMutable(cast(user->getResult(0).getType()))); + propagateMutability(user->getResult(0)); + } + } +} + +void rewriteGetEnterOp(ArefGetEnterOp op, PatternRewriter &rewriter, + ArefValue arefVal) { + auto loc = op.getLoc(); + rewriter.setInsertionPointAfter(op); + + Value fullBarrier = + getFullBarrier(rewriter, loc, arefVal, op.getStage(), + getPartitionWsTagIds(op), getStageCluster(op)); + insertWaitOp(rewriter, op, fullBarrier, op.getPhase(), op.getStage()); + auto views = getSubViews(arefVal, op.getStage(), loc, rewriter, + getPartitionWsTagIds(op), getStageCluster(op)); + assert(views.size() == op.getBuffers().size()); + + for (auto [oldBuffer, view] : llvm::zip(op.getBuffers(), views)) { + oldBuffer.replaceAllUsesWith(view); + // Before aref lowering, memdesc_trans consumes an immutable buffer from + // a get enter op. After lowering, all buffers are mutable. + propagateMutability(view); + } +} + +void rewriteArefBufferOp(ArefBufferOp op, PatternRewriter &rewriter, + ArefValue arefVal) { + auto loc = op->getLoc(); + rewriter.setInsertionPointAfter(op); + auto views = getSubViews(arefVal, op.getStage(), loc, rewriter, + getPartitionWsTagIds(op), getStageCluster(op)); + assert(views.size() == op.getBuffers().size()); + for (int i = 0; i < op.getBuffers().size(); ++i) + op.getBuffers()[i].replaceAllUsesWith(views[i]); +} + +void insertArriveBarrier(Location loc, ArrayRef asyncOps, + PatternRewriter &rewriter, Value mbar, + std::optional partitionWsTagIds, + StageCluster stageCluster) { + for (auto asyncOpEnum : asyncOps) { + Operation *arriveOp = {}; + switch (asyncOpEnum) { + case AsyncOp::NONE: + case AsyncOp::WGMMA: + arriveOp = nvidia_gpu::ArriveBarrierOp::create(rewriter, loc, mbar, 1); + break; + case AsyncOp::TC5MMA: + case AsyncOp::TMEMCopy: + arriveOp = nvidia_gpu::TCGen5CommitOp::create(rewriter, loc, mbar, + Value(), ValueRange{}); + break; + case AsyncOp::TMALoad: + // nothing to do, the arrive is done by HW + break; + case AsyncOp::CpAsync: + default: + llvm_unreachable("unknown async op"); + } + if (arriveOp) + assignStageCluster(arriveOp, partitionWsTagIds, stageCluster, rewriter); + } +} + +void rewritePutExitOp(ArefPutExitOp op, PatternRewriter &rewriter, + ArefValue arefVal) { + auto loc = op->getLoc(); + auto stageCluster = getStageCluster(op); + auto asyncKinds = castAsyncOpAttrs(op.getAsyncOps()); + rewriter.setInsertionPointAfter(op); + + bool needFence = [&]() { + bool isGenericProxy = llvm::any_of( + asyncKinds, [](AsyncOp kind) { return kind == AsyncOp::NONE; }); + if (!isGenericProxy) { + return false; + } + auto tmem = TensorMemorySpaceAttr::get(op.getContext()); + auto arefType = cast(op.getAref().getType()); + // Currently we assume that an aref does not contain both SMEM and TMEM. + // So checking only the first buffer is fine. + auto arefBufType = cast(arefType.getBaseType()[0]); + if (arefBufType.getMemorySpace() == tmem) { + return false; + } + for (auto arefUser : op.getAref().getUsers()) { + if (auto getExit = dyn_cast(arefUser)) { + bool isConsumerMMAv5 = + llvm::any_of(castAsyncOpAttrs(getExit.getAsyncOps()), + [](AsyncOp kind) { return kind == AsyncOp::TC5MMA; }); + if (isConsumerMMAv5) { + return true; + } + } + } + return false; + }(); + + if (needFence) { + auto fence = FenceAsyncSharedOp::create(rewriter, loc, /*bCluster=*/false); + assignStageCluster(fence, getPartitionWsTagIds(op), stageCluster, rewriter); + } + + Value fullBarrier = + getFullBarrier(rewriter, loc, arefVal, op.getStage(), + getPartitionWsTagIds(op), getStageCluster(op)); + insertArriveBarrier(loc, castAsyncOpAttrs(op.getAsyncOps()), rewriter, + fullBarrier, getPartitionWsTagIds(op), + getStageCluster(op)); +} + +void rewriteGetExitOp(ArefGetExitOp op, PatternRewriter &rewriter, + ArefValue arefVal) { + auto loc = op->getLoc(); + auto stageCluster = getStageCluster(op); + auto asyncKinds = castAsyncOpAttrs(op.getAsyncOps()); + rewriter.setInsertionPointAfter(op); + + bool needFence = [&]() { + bool isGenericProxy = llvm::any_of( + asyncKinds, [](AsyncOp kind) { return kind == AsyncOp::NONE; }); + if (!isGenericProxy) { + return false; + } + for (auto arefUser : op.getAref().getUsers()) { + if (auto putExit = dyn_cast(arefUser)) { + bool isProducerTMA = + llvm::any_of(castAsyncOpAttrs(putExit.getAsyncOps()), + [](AsyncOp kind) { return kind == AsyncOp::TMALoad; }); + if (isProducerTMA) { + return true; + } + } + } + return false; + }(); + + if (needFence) { + auto fence = FenceAsyncSharedOp::create(rewriter, loc, /*bCluster=*/false); + assignStageCluster(fence, getPartitionWsTagIds(op), stageCluster, rewriter); + } + + Value emptyBarrier = + getEmptyBarrier(rewriter, loc, arefVal, op.getStage(), + getPartitionWsTagIds(op), getStageCluster(op)); + insertArriveBarrier(loc, asyncKinds, rewriter, emptyBarrier, + getPartitionWsTagIds(op), stageCluster); +} + +DenseSet getAsyncMMAv5Consumers(Value aref) { + DenseSet mmav5Ops; + for (auto arefUser : aref.getUsers()) { + if (auto getEnter = dyn_cast(arefUser)) { + if (hasPartition(getEnter) && getPartitionIds(getEnter).front() == 0) { + // Ignore mmav5 ops in the default partition. They are not warp + // specialized. + continue; + } + + for (auto consumer : getEnter->getUsers()) { + if (auto mmav5 = dyn_cast(consumer)) { + mmav5Ops.insert(mmav5); + } else if (auto forOp = consumer->getParentOfType()) { + auto users = + getTopLevelUsersInLoop(consumer, forOp, [](Operation *user) { + return isa(user); + }); + for (auto user : users) { + mmav5Ops.insert(cast(user)); + } + } + } + } + } + return mmav5Ops; +} + +class LowerArefCreate : public OpRewritePattern { +public: + LowerArefCreate(MLIRContext *ctx, unsigned defaultNumStages) + : OpRewritePattern(ctx), defaultNumStages(defaultNumStages) {} + + LogicalResult matchAndRewrite(ArefCreateOp op, + PatternRewriter &rewriter) const override { + auto aref = createAndInitMbar(op, rewriter); + SetVector opToDelete; + opToDelete.insert(op.getOperation()); + + // setIsAsync(true) will be invoked on these mmav5 ops during + // rewritePutEnterOp when the producer is async loads. Since collecting + // consumer mmav5 ops requires the corresponding get enter op to be still + // used in the IR, collect them here. + auto mmav5Ops = getAsyncMMAv5Consumers(op.getResult()); + + for (auto userOp : op->getUsers()) { + opToDelete.insert(userOp); + if (auto user = dyn_cast(userOp)) { + rewritePutEnterOp(user, rewriter, aref, mmav5Ops, defaultNumStages); + } else if (auto user = dyn_cast(userOp)) { + rewriteGetEnterOp(user, rewriter, aref); + } else if (auto user = dyn_cast(userOp)) { + rewritePutExitOp(user, rewriter, aref); + } else if (auto user = dyn_cast(userOp)) { + rewriteGetExitOp(user, rewriter, aref); + } else if (auto user = dyn_cast(userOp)) { + rewriteArefBufferOp(user, rewriter, aref); + } else { + llvm_unreachable("users of aref can only be ArefPut or ArefGet"); + } + } + + auto sorted = topologicalSort(opToDelete); + OpBuilder b(op); + auto replToken = + ub::PoisonOp::create(b, op.getLoc(), b.getType()); + for (auto op : sorted) { + if (auto enterOp = dyn_cast(op)) + enterOp.getToken().replaceAllUsesWith(replToken); + else if (auto enterOp = dyn_cast(op)) + enterOp.getToken().replaceAllUsesWith(replToken); + } + for (auto it = sorted.rbegin(); it != sorted.rend(); ++it) + rewriter.eraseOp(*it); + + return success(); + } + +private: + unsigned defaultNumStages; +}; + +bool isProducerLoad(ArefCreateOp arefOp) { + for (auto user : arefOp.getResult().getUsers()) { + if (auto putOp = dyn_cast(user)) { + if (llvm::any_of(putOp->getUsers(), [](auto user) { + return isa(user); + })) { + return true; + } + } + } + return false; +} + +void multiBufferAref(const SmallVector &arefOps, int numStages) { + SmallVector allocsToErase; + for (auto arefOp : arefOps) { + SmallVector allocOps; + SmallVector arefTypes; + + bool eligible = true; + for (auto opnd : arefOp.getOperands()) { + if (!opnd.getDefiningOp() || isa(opnd.getDefiningOp())) { + eligible = false; + } + } + + if (!eligible) { + continue; + } + + OpBuilder builder(arefOp); + for (auto opnd : arefOp.getOperands()) { + auto oldAlloc = opnd.getDefiningOp(); + auto arefBufType = cast(opnd.getType()); + arefBufType = + getMultiBufferedType(getBufferViewType(arefBufType, true), numStages); + Operation *newAlloc = triton::nvws::createAlloc( + builder, oldAlloc->getLoc(), arefBufType, Value()); + allocOps.push_back(newAlloc->getResult(0)); + arefTypes.push_back(arefBufType); + oldAlloc->replaceAllUsesWith(newAlloc); + allocsToErase.push_back(oldAlloc); + } + + auto newAref = + createArefCreateOp(builder, arefTypes, allocOps, arefOp.getLoc()); + + arefOp.getResult().replaceAllUsesWith(newAref.getResult()); + arefOp.erase(); + } + + for (auto alloc : allocsToErase) { + alloc->erase(); + } +} + +template +ExitOp createCombinedArefOps(SmallVector &enterOps, + SmallVector &exitOps, ArefCreateOp aref, + OpBuilder &builder, + Operation *combinedEnterInsertPoint = nullptr) { + auto firstEnter = *llvm::min_element(enterOps, [](EnterOp a, EnterOp b) { + assert(a->getBlock() == b->getBlock()); + return a->isBeforeInBlock(b); + }); + + auto lastExit = *llvm::max_element(exitOps, [](ExitOp a, ExitOp b) { + assert(a->getBlock() == b->getBlock()); + return a->isBeforeInBlock(b); + }); + + SmallVector arefEnterBuffers; + for (auto enterOp : enterOps) { + arefEnterBuffers.push_back(enterOp.getResult(0).getType()); + } + + llvm::SmallSetVector opAttrsSet; + for (ExitOp exitOp : exitOps) { + opAttrsSet.insert(exitOp.getAsyncOps().begin(), exitOp.getAsyncOps().end()); + } + + builder.setInsertionPointAfter(aref); + auto zero = arith::ConstantIntOp::create(builder, aref.getLoc(), 0, 32); + assignStageCluster(zero, getPartitionWsTagIds(firstEnter), + getStageCluster(firstEnter), builder); + + if (combinedEnterInsertPoint) { + // Combined get enter must be placed after combined put enter + builder.setInsertionPointAfter(combinedEnterInsertPoint); + } else { + builder.setInsertionPoint(firstEnter); + } + auto combinedEnter = + EnterOp::create(builder, firstEnter.getLoc(), arefEnterBuffers, + builder.getType(), aref, zero, zero); + assignStageCluster(combinedEnter, getPartitionWsTagIds(firstEnter), + getStageCluster(firstEnter), builder); + + builder.setInsertionPoint(lastExit); + llvm::SmallVector AsyncOpAttrs(opAttrsSet.begin(), + opAttrsSet.end()); + auto combinedExit = ExitOp::create(builder, firstEnter.getLoc(), aref, + combinedEnter.getToken(), zero, + builder.getArrayAttr(AsyncOpAttrs)); + assignStageCluster(combinedExit, getPartitionWsTagIds(lastExit), + getStageCluster(lastExit), builder); + + std::function moveUserAfter = + [&](Operation *op, Operation *target) { + auto curBlock = target->getBlock(); + for (auto user : op->getUsers()) { + auto userOp = curBlock->findAncestorOpInBlock(*user); + if (userOp->isBeforeInBlock(target)) { + userOp->moveAfter(target); + moveUserAfter(userOp, userOp); + } + } + }; + + for (auto [idx, enterOp] : llvm::enumerate(enterOps)) { + moveUserAfter(enterOp, combinedEnter); + enterOp.getBuffers()[0].replaceAllUsesWith(combinedEnter.getBuffers()[idx]); + } + + return combinedExit; +} + +SmallVector findSharedMemorySinkOps(Value value) { + SmallVector sinkOps; + for (Operation *user : value.getUsers()) { + if (isa(user)) { + sinkOps.push_back(user); + } else if (user->hasTrait()) { + auto rec = findSharedMemorySinkOps(user->getResult(0)); + sinkOps.insert(sinkOps.end(), rec.begin(), rec.end()); + } + } + return sinkOps; +} + +Operation *getDominantConsumer(ArefGetEnterOp getEnterOp, Block &container, + DominanceInfo &domInfo) { + assert(getEnterOp->getNumResults() && "Expect a single-result ArefGenterOp"); + auto buf = getEnterOp->getResult(0); + SmallVector sinkOps = findSharedMemorySinkOps(buf); + if (sinkOps.empty()) { + return nullptr; + } + Operation *liveBeforeOp = findNearestCommonDominator(sinkOps, domInfo); + return container.findAncestorOpInBlock(*liveBeforeOp); +} + +// This is an optimization to combine arefs for TMA load into one, so that +// barrier arrive and wait are coalesced. +void combineArefs(scf::ForOp loop) { + // We combine getEnterOps in the same loop body, not across a loop. + auto getEnterOps = loop.getOps(); + + // Arefs whose get-enter ops share the same dominant consumer can be combined + DominanceInfo domInfo(loop); + llvm::DenseMap, SmallVector> + liveBeforeGroups; + for (auto getEnterOp : getEnterOps) { + if (auto liveBeforeOp = + getDominantConsumer(getEnterOp, *loop.getBody(), domInfo)) { + assert(hasPartition(getEnterOp)); + auto partitionIds = getPartitionIds(getEnterOp); + assert(partitionIds.size() == 1); + liveBeforeGroups[{liveBeforeOp, partitionIds.front()}].push_back( + getEnterOp); + } + } + + for (auto getEnterOps : llvm::make_second_range(liveBeforeGroups)) { + if (getEnterOps.size() == 1) { + continue; + } + + SmallVector arefs; + for (auto getEnterOp : getEnterOps) { + arefs.push_back(cast(getEnterOp.getAref().getDefiningOp())); + } + + SmallVector putEnterOps; + SmallVector putExitOps; + SmallVector getExitOps; + SmallVector producerGroupIds; + for (auto aref : arefs) { + for (auto user : aref->getUsers()) { + if (auto putEnterOp = dyn_cast(user)) { + putEnterOps.push_back(putEnterOp); + producerGroupIds.push_back(getPartitionIds(putEnterOp).front()); + } else if (auto putExitOp = dyn_cast(user)) { + putExitOps.push_back(putExitOp); + } else if (auto getExitOp = dyn_cast(user)) { + getExitOps.push_back(getExitOp); + } + } + } + + // Producer arefs must be in the same partition. + if (llvm::any_of(producerGroupIds, + [&](auto id) { return id != producerGroupIds[0]; })) { + continue; + } + + SmallVector arefBufTypes; + SmallVector arefBufs; + for (auto aref : arefs) { + arefBufTypes.push_back(aref.getOperands()[0].getType()); + arefBufs.push_back(aref.getOperands()[0]); + } + + // set insertion point at the last aref_create + auto lastAref = *llvm::max_element(arefs, [](auto a, auto b) { + assert(a->getBlock() == b->getBlock()); + return a->isBeforeInBlock(b); + }); + + OpBuilder builder(lastAref); + auto aref = + createArefCreateOp(builder, arefBufTypes, arefBufs, lastAref->getLoc()); + + auto combinedPutExit = + createCombinedArefOps(putEnterOps, putExitOps, aref, builder); + createCombinedArefOps(getEnterOps, getExitOps, aref, builder, + combinedPutExit); + + for (auto putExitOp : putExitOps) + putExitOp->erase(); + for (auto putEnterOp : putEnterOps) + putEnterOp->erase(); + for (auto getExitOp : getExitOps) + getExitOp->erase(); + for (auto getEnterOp : getEnterOps) + getEnterOp->erase(); + for (auto aref : arefs) + aref->erase(); + } +} + +void hoistPoissonOps(triton::FuncOp funcOp) { + SmallVector poisonOps; + auto block = &funcOp.getBody().front(); + funcOp.walk([&](ub::PoisonOp op) { op->moveBefore(&block->front()); }); +} +} // anonymous namespace + +class NVWSLowerAref : public impl::NVWSLowerArefBase { + using impl::NVWSLowerArefBase::NVWSLowerArefBase; + +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::ModuleOp m = getOperation(); + + SmallVector loops; + m.walk([&](scf::ForOp loop) { + if (loop->hasAttr(triton::kWarpSpecializeAttrName)) { + loop->walk([&](scf::ForOp op) { loops.push_back(op); }); + } + }); + + for (scf::ForOp loop : loops) { + combineArefs(loop); + } + + SmallVector arefOps; + m.walk([&](ArefCreateOp arefOp) { + // Only handles arefs whose producer (a partition with PutEnter / Exit) + // does load from global to shared memory. + if (isProducerLoad(arefOp)) { + arefOps.push_back(arefOp); + } + }); + multiBufferAref(arefOps, numStages); + + OpPassManager pm; + pm.addPass(createNVWSAssignStagePhase()); + if (failed(runPipeline(pm, m))) + return signalPassFailure(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context, numStages); + GreedyRewriteConfig config; + config.enableConstantCSE(false); + config.enableFolding(false); + if (applyPatternsGreedily(m, std::move(patterns), config).failed()) + signalPassFailure(); + + // Hoist all poison ops to the top of function from nvws.wg regions. + // They are unannotated and will trip subsequent passes, same to hoist. + m.walk([&](triton::FuncOp funcOp) { hoistPoissonOps(funcOp); }); + } +}; // namespace triton + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/LowerWarpGroup.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/LowerWarpGroup.cpp new file mode 100644 index 0000000000..3bb26165e6 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/LowerWarpGroup.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AttrTypeSubElements.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/LogicalResult.h" +#include +#include + +using namespace mlir::triton; +using namespace mlir::triton::nvws; +using namespace mlir::triton::gpu; + +#define DEBUG_TYPE "nvws-lower-warp-group" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_NVWSLOWERWARPGROUP +#include "triton/Dialect/NVWS/Transforms/Passes.h.inc" + +namespace { + +class LowerWarpGroup : public OpRewritePattern { + + void populateRegion(PatternRewriter &rewriter, Region *inputRegion, + Region *outputRegion, SmallVector &inputs, + IRMapping &mapping) const { + Block *output_block = &outputRegion->emplaceBlock(); + DenseMap valueMap; + rewriter.setInsertionPointToEnd(output_block); + + for (auto &value : inputs) { + auto new_value = + output_block->addArgument(value.getType(), value.getLoc()); + valueMap[value] = new_value; + } + auto retOp = triton::gpu::WarpReturnOp::create( + rewriter, inputRegion->getLoc(), ArrayRef(), ArrayRef()); + + for (auto &op : llvm::make_early_inc_range( + inputRegion->getBlocks().front().without_terminator())) { + op.moveBefore(retOp); + } + + for (auto pair : valueMap) + replaceAllUsesInRegionWith(pair.first, pair.second, *outputRegion); + } + + LogicalResult createWarpSpecializeOp(Location loc, WarpGroupOp warpGroupOp, + PatternRewriter &rewriter, + Region *defaultPartition, + RegionRange partitions, + ArrayRef numWarps, + ArrayRef warpGroupStartIds) const { + if (partitions.size() != numWarps.size()) + return failure("mismatched number of warp groups and number of warps per " + "warp group"); + if (partitions.size() != warpGroupStartIds.size()) + return failure( + "mismatched number of warp groups and number of warp start ids"); + + SetVector captures; + for (auto partition : partitions) + mlir::getUsedValuesDefinedAbove(*partition, captures); + + SmallVector inputs; + SmallVector mappings(partitions.size()); + SmallVector builders; + for (auto region : partitions) { + builders.push_back(OpBuilder::atBlockBegin(®ion->front())); + } + + SetVector opsToClone; + std::queue que; + for (auto capture : captures) { + que.push(capture); + } + + while (!que.empty()) { + Value capture = que.front(); + // Rematerialize constants and also pure tensor ops to get around the + // restriction below on capturing tensors. + Operation *defOp = capture.getDefiningOp(); + if (!isa(capture) && defOp && isPure(defOp) && + (defOp->hasTrait() || + isa(capture.getType()))) { + for (auto operand : defOp->getOperands()) { + que.push(operand); + } + opsToClone.insert(defOp); + } else if (auto tensorTy = + dyn_cast(capture.getType())) { + SharedEncodingTrait sharedEnc = getSharedEncoding(tensorTy); + auto memdescTy = MemDescType::get( + tensorTy.getShape(), tensorTy.getElementType(), sharedEnc, + SharedMemorySpaceAttr::get(tensorTy.getContext())); + auto alloc = LocalAllocOp::create(rewriter, loc, memdescTy, capture); + for (auto [i, region] : llvm::enumerate(partitions)) { + Value value = LocalLoadOp::create(builders[i], capture.getLoc(), + tensorTy, alloc); + replaceAllUsesInRegionWith(capture, value, *region); + mappings[i].map(capture, value); + } + inputs.push_back(alloc); + } else { + inputs.push_back(capture); + } + que.pop(); + } + + opsToClone = topologicalSort(opsToClone); + + for (auto [region, b, mapping] : + llvm::zip(partitions, builders, mappings)) { + for (Operation *op : opsToClone) { + auto copy = b.clone(*op, mapping)->getResult(0); + mapping.map(op->getResult(0), copy); + replaceAllUsesInRegionWith(op->getResult(0), copy, *region); + } + } + + auto wsOp = WarpSpecializeOp::create( + rewriter, loc, warpGroupOp.getResultTypes(), numWarps); + + auto &defaultBlock = wsOp.getDefaultRegion().emplaceBlock(); + rewriter.setInsertionPointToEnd(&defaultBlock); + + if (defaultPartition) { + auto yieldOp = defaultPartition->front().getTerminator(); + auto newYieldOp = WarpYieldOp::create( + rewriter, loc, yieldOp->getResultTypes(), yieldOp->getOperands()); + + for (auto &op : llvm::make_early_inc_range( + defaultPartition->getBlocks().front().without_terminator())) { + op.moveBefore(newYieldOp); + } + } else { + WarpYieldOp::create(rewriter, loc, TypeRange(), ArrayRef()); + } + + auto &block = wsOp.getPartitionOpHolder().emplaceBlock(); + rewriter.setInsertionPointToStart(&block); + auto wspOp = WarpSpecializePartitionsOp::create(rewriter, loc, inputs, + partitions.size()); + auto regions = wspOp.getPartitionRegions(); + + for (auto [in, out, mapping] : zip(partitions, regions, mappings)) + populateRegion(rewriter, in, &out, inputs, mapping); + + warpGroupOp.replaceAllUsesWith(wsOp); + + return success(); + } + +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WarpGroupOp warpGroupOp, + PatternRewriter &rewriter) const override { + auto loc = warpGroupOp.getLoc(); + rewriter.setInsertionPointAfter(warpGroupOp); + auto mod = warpGroupOp->getParentOfType(); + int32_t globalNumWarps = + mlir::cast(mod->getAttr("ttg.num-warps")).getInt(); + + auto regions = warpGroupOp.getRegions(); + Region *defaultRegion = nullptr; + int startWarp = 0; + auto numWarps = warpGroupOp.getNumWarps(); + + if (numWarps[0] == globalNumWarps) { + defaultRegion = regions.front(); + regions = regions.drop_front(); + startWarp = globalNumWarps; + numWarps = numWarps.drop_front(); + } else if (warpGroupOp.getNumResults() > 0) { + return failure("The first warp group does not use the default number of " + "warps. The default partition cannot be created. When " + "nvws.warp_group op returns results, there must be a " + "default region."); + } + + auto result = createWarpSpecializeOp( + loc, warpGroupOp, rewriter, defaultRegion, regions, numWarps, + llvm::map_to_vector(numWarps, [&](int numWarps) { + int result = startWarp; + startWarp += numWarps; + return result; + })); + + if (result.succeeded()) + rewriter.eraseOp(warpGroupOp); + + return result; + } +}; + +} // namespace + +class NVWSLowerWarpGroup + : public impl::NVWSLowerWarpGroupBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + GreedyRewriteConfig config; + + if (applyPatternsGreedily(m, std::move(patterns), config).failed()) + signalPassFailure(); + + if (failed(m.verify())) + assert(false); + } +}; + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/Utilities.cpp b/third_party/mthreads/lib/Dialect/NVWS/Transforms/Utilities.cpp new file mode 100644 index 0000000000..cbca317deb --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/Utilities.cpp @@ -0,0 +1,65 @@ +#include "Utilities.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using namespace mlir::triton::nvidia_gpu; + +namespace mlir::triton::nvws { + +Operation *createAlloc(OpBuilder &builder, Location loc, + MemDescType memDescType, Value src) { + if (isa(memDescType.getMemorySpace())) { + return LocalAllocOp::create(builder, loc, memDescType, src); + } else { + assert(isa(memDescType.getMemorySpace())); + return TMEMAllocOp::create(builder, loc, memDescType, src); + } +} + +ArefCreateOp createArefCreateOp(OpBuilder &builder, ArrayRef arefTypes, + ValueRange allocOps, Location loc) { + auto ctx = builder.getContext(); + auto arefTy = ArefType::get(ctx, TypeArrayAttr::get(ctx, arefTypes)); + return ArefCreateOp::create(builder, loc, arefTy, allocOps); +} + +int getArefDepth(MemDescType bufTy) { + auto shape = bufTy.getShape(); + return isa(bufTy.getEncoding()) + ? 1 + : shape[0]; +} + +MemDescType getArefViewBufferType(MemDescType bufTy) { + auto isScalesEnc = + isa(bufTy.getEncoding()); + auto shape = bufTy.getShape(); + return gpu::MemDescType::get(isScalesEnc ? shape : shape.drop_front(), + bufTy.getElementType(), bufTy.getEncoding(), + bufTy.getMemorySpace(), + /*mutableMemory*/ true, + /*allocShape=*/bufTy.getAllocShape()); +} + +MemDescType getArefMultiBufferedType(MemDescType bufTy, int depth) { + auto shape = bufTy.getShape(); + SmallVector bufferShape(shape.begin(), shape.end()); + if (!isa(bufTy.getEncoding())) + bufferShape.insert(bufferShape.begin(), depth); + return gpu::MemDescType::get(bufferShape, bufTy.getElementType(), + bufTy.getEncoding(), bufTy.getMemorySpace(), + /*mutableMemory*/ true); +} + +scf::ForOp getOuterWSLoop(scf::ForOp innerFor) { + auto wsLoop = innerFor; + while (wsLoop && !wsLoop->hasAttr(triton::kWarpSpecializeAttrName)) { + wsLoop = wsLoop->getParentOfType(); + } + return wsLoop; +} + +} // namespace mlir::triton::nvws diff --git a/third_party/mthreads/lib/Dialect/NVWS/Transforms/Utilities.h b/third_party/mthreads/lib/Dialect/NVWS/Transforms/Utilities.h new file mode 100644 index 0000000000..50293d2497 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/NVWS/Transforms/Utilities.h @@ -0,0 +1,43 @@ +#ifndef NVIDIA_NVWS_TRANSFORMS_UTILITY_H_ +#define NVIDIA_NVWS_TRANSFORMS_UTILITY_H_ + +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::nvws { + +Operation *createAlloc(OpBuilder &builder, Location loc, + gpu::MemDescType memDescType, Value src); + +ArefCreateOp createArefCreateOp(OpBuilder &builder, ArrayRef arefTypes, + ValueRange allocOps, Location loc); + +template +inline std::optional findValuePosInRange(const Range &range, + mlir::Value v) { + for (auto [pos, arg] : llvm::enumerate(range)) { + if (arg == v) + return pos; + } + return {}; +} + +#if 0 +struct PartitionId : std::pair { + PartitionId(int index, int tag) : std::pair(index, tag) {} + int &index() { return first; } + int &tag() { return second; } +}; + +std::optional getPartitionId(Operation *op); +#endif + +gpu::MemDescType getArefViewBufferType(gpu::MemDescType arefBufType); +gpu::MemDescType getArefMultiBufferedType(gpu::MemDescType arefBufType, + int depth); +int getArefDepth(gpu::MemDescType bufTy); + +scf::ForOp getOuterWSLoop(scf::ForOp innerFor); +} // namespace mlir::triton::nvws + +#endif // NVIDIA_NVWS_TRANSFORMS_UTILITY_H_ diff --git a/third_party/mthreads/lib/Dialect/Triton/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 0000000000..1662b94968 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,23 @@ +set(LLVM_TARGET_DEFINITIONS Canonicalize.td) +mlir_tablegen(TritonCanonicalize.inc -gen-rewriters) +add_public_tablegen_target(TritonCanonicalizeIncGen) + +add_triton_library(TritonIR + Dialect.cpp + DiscardableAttributes.cpp + Ops.cpp + Traits.cpp + Types.cpp + OpInterfaces.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonCanonicalizeIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Canonicalize.td b/third_party/mthreads/lib/Dialect/Triton/IR/Canonicalize.td new file mode 100644 index 0000000000..dc37710333 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Canonicalize.td @@ -0,0 +1,17 @@ +#ifndef TT_PATTERNS +#define TT_PATTERNS + +include "mlir/IR/PatternBase.td" +include "triton/Dialect/Triton/IR/TritonOps.td" + +// broadcast(splat(x)) -> splat(x) +def BroadcastSplatPattern : + Pat<(TT_BroadcastOp (TT_SplatOp $x)), + (TT_SplatOp $x)>; + +// broadcast(broadcast(x)) -> broadcast(x) +def BroadcastBroadcastPattern : + Pat<(TT_BroadcastOp (TT_BroadcastOp $x)), + (TT_BroadcastOp $x)>; + +#endif diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 0000000000..9073f423f9 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,77 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +bool TritonInlinerInterface::isLegalToInline(Operation *call, + Operation *callable, + bool wouldBeCloned) const { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void TritonInlinerInterface::handleTerminator(Operation *op, + Block *newDest) const { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + mlir::cf::BranchOp::create(builder, op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void TritonInlinerInterface::handleTerminator(Operation *op, + ValueRange valuesToRepl) const { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); +} + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/DiscardableAttributes.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/DiscardableAttributes.cpp new file mode 100644 index 0000000000..8f4d80ea8a --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/DiscardableAttributes.cpp @@ -0,0 +1,17 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton { + +SmallVector +filterDiscardableAttrs(Operation *op, ArrayRef allowList) { + SmallVector propagatedAttrs; + for (auto attrName : allowList) { + Attribute attr = op->getDiscardableAttr(attrName); + if (attr) + propagatedAttrs.emplace_back(attrName, attr); + } + return propagatedAttrs; +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/OpInterfaces.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/OpInterfaces.cpp new file mode 100644 index 0000000000..7bebffe61b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/OpInterfaces.cpp @@ -0,0 +1,77 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace triton { +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op) { + TransposeOpInterface transposeOp = cast(op); + auto rank = cast(transposeOp.getSrc().getType()).getRank(); + auto order = transposeOp.getOrder(); + if (static_cast(rank) != order.size()) { + return op->emitError( + "order must have the same size as the rank of the operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return op->emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +// A DotOpInterface operation should have at least three operands. +// The first two operands should share a common dimension, and the result +// should have the dimensions of the two operands that are not shared. +// A DotOpInterface operation can be either 2d or 3d. +// In the 3d case, the first dimension of operands is the batch dimension. +LogicalResult verifyDotOpInterface(Operation *op) { + DotOpInterface dotOp = cast(op); + + if (dotOp->getNumOperands() < 3) + return dotOp->emitOpError("expected at least 3 operands"); + auto aTy = cast(dotOp->getOperand(0).getType()); + auto bTy = cast(dotOp->getOperand(1).getType()); + auto cTy = cast(dotOp->getOperand(2).getType()); + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + auto cShape = cTy.getShape(); + // Check if all 3d or all 2d + if (aShape.size() != 2 && aShape.size() != 3) + return dotOp->emitOpError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size()) + return dotOp->emitOpError("expected all operands to have the same rank"); + + // Check for valid A, B input shapes for dot + if (!dotOp.verifyDims()) + return dotOp->emitOpError( + "expected the last dimension of the first operand " + "to be equal to the second-to-last dimension of " + "the second operand"); + + // Check the batch dimension + if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0])) + return dotOp->emitOpError("expected the first dimension of the first " + "operand to be equal to the first dimension of " + "the result"); + // Check the output shape + if (!dotOp.verifyOutputDims()) + return dotOp->emitOpError( + "expected the output shape to be the concatenation of the last " + "dimension of the first operand and the last dimension of the " + "second "); + return success(); +} + +} // namespace impl +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 0000000000..85e366c146 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,1490 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +#include "TritonCanonicalize.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + // If the source and result types are the same, we can return the source + // If their layout is different (even if structurally equivalent), we need + // to insert a convert_layout in between as otherwise ::fold complains + // We do this in CanonicalizeConvertFromTranspose + if (getSrc().getType() == getType()) { + return getSrc(); + } + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + // Eliminate splat constant transpose ops. + if (auto attr = + llvm::dyn_cast_if_present(adaptor.getSrc())) + return attr.reshape(getType()); + + return {}; +} + +LogicalResult TransOp::verify() { + auto order = getOrder(); + auto srcTy = cast(getSrc().getType()); + if (order.size() != srcTy.getShape().size()) { + return emitError("order must have the same size as the source tensor"); + } + if (!isPermutationOfIota(order)) { + return emitError("order must be a permutation of 0..n-1"); + } + SmallVector retShape = applyPermutation(srcTy.getShape(), order); + if (retShape != getType().getShape()) { + return emitError( + "result shape must match the permutation of the source shape"); + } + return success(); +} + +LogicalResult +TransOp::inferReturnTypes(MLIRContext *context, std::optional loc, + TransOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + + // type is the same as the input + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferTransOpEncoding( + argEncoding, shape, order, retEncoding, loc))) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc && retEnc); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + auto accTy = getC().getType(); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +bool DotOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +//-- DotScaledOp -- +bool DotScaledOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + auto aKdim = aShape[aShape.size() - 1]; + auto bKdim = bShape[aShape.size() - 2]; + if (this->getAElemType() == ScaleDotElemType::E2M1) { + if (this->getLhsKPack()) + aKdim *= 2; + } + if (this->getBElemType() == ScaleDotElemType::E2M1) { + if (this->getRhsKPack()) + bKdim *= 2; + } + + return aKdim == bKdim; +} + +bool DotScaledOp::verifyOutputDims() { + auto cShape = this->getC().getType().getShape(); + auto oMdim = cShape[cShape.size() - 2]; + auto oNdim = cShape[cShape.size() - 1]; + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + auto adim = aShape[aShape.size() - 2]; + auto bdim = bShape[bShape.size() - 1]; + if (this->getAElemType() == ScaleDotElemType::E2M1) { + if (!this->getLhsKPack()) + adim *= 2; + } + if (this->getBElemType() == ScaleDotElemType::E2M1) { + if (!this->getRhsKPack()) + bdim *= 2; + } + if (adim != oMdim || bdim != oNdim) + return false; + return true; +} + +LogicalResult DotScaledOp::verify() { + auto aShape = this->getA().getType().getShape(); + int64_t rank = aShape.size(); + if (rank < 2) + return this->emitError("operands must be at least 2D"); + + auto k = aShape[rank - 1]; + if (this->getAElemType() == ScaleDotElemType::E2M1) { + if (this->getLhsKPack()) + k *= 2; + } + auto cShape = this->getC().getType().getShape(); + int64_t mDim = cShape[cShape.size() - 2]; + int64_t nDim = cShape[cShape.size() - 1]; + + if (getAScale()) { + auto aScaleShape = getAScale().getType().getShape(); + if (aScaleShape[rank - 2] != mDim) + return this->emitError( + "scales M dimension must match the operand M dimension"); + int scale_factor = + isa(getAScale().getType().getElementType()) ? 16 : 32; + if (aScaleShape[rank - 1] != k / scale_factor) + return this->emitError("scales K dimension must match the operand K " + "divided by the scale factor"); + } + if (getBScale()) { + auto bScaleShape = getBScale().getType().getShape(); + if (bScaleShape[rank - 2] != nDim) + return this->emitError( + "scales N dimension must match the operand N dimension"); + int scale_factor = + isa(getBScale().getType().getElementType()) ? 16 : 32; + if (bScaleShape[rank - 1] != k / scale_factor) + return this->emitError("scales K dimension must match the operand K " + "divided by the scale factor"); + } + return success(); +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start >= end) { + return this->emitOpError() << "start must be less than end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(std::optional loc, RankedTensorType argTy, + Type retEltTy, int axis, + SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferReduceOpEncoding( + argEncoding, axis, retEncoding, loc))) { + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +LogicalResult +ReduceOp::inferReturnTypes(MLIRContext *context, std::optional loc, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (failed(inferReduceReturnShape(loc, argTy, retEltTy, axis, + inferredReturnTypes))) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + auto axis = op.getAxis(); + auto firstRank = 0; + for (auto tensorTy : op.getInputTypes()) { + int64_t rank = tensorTy.getRank(); + if (axis < 0 || axis >= rank) + return op.emitOpError() << "axis out of bounds for operand rank " << rank; + if (firstRank == 0) + firstRank = rank; + else if (rank != firstRank) + return op.emitOpError() + << "all operands must have the same rank, but got ranks " + << firstRank << " and " << rank; + } + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementTypeOrSelf(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +template +static llvm::SmallVector getElementTypesImpl(const ValueRange &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +::mlir::Operation *ReduceOp::getSingleCombiner() { + if (getNumOperands() != 1 || getNumResults() != 1) + return nullptr; + Block *block = &(*getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return nullptr; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return nullptr; + + return reduceOp; +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ScanOp::build(builder, state, inferredReturnTypes, operands, axis, reverse); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- MapElementwiseOp +LogicalResult MapElementwiseOp::verify() { + if (getOperands().empty()) { + return emitOpError() << "MapElementwiseOp must have at least 1 operand"; + } + if (!llvm::isPowerOf2_32(getPack())) { + return emitOpError() << "Pack must be a power of 2"; + } + return success(); +} + +template +SmallVector repeatInterleave(const SmallVectorImpl &vs, int nRepeat) { + SmallVector result; + result.reserve(vs.size() * nRepeat); + for (auto v : vs) + for (auto _ : llvm::seq(nRepeat)) + result.push_back(v); + return result; +} + +LogicalResult MapElementwiseOp::verifyRegions() { + // Verify signature + auto *firstBlock = &getRegion().getBlocks().front(); + if (firstBlock->getNumArguments() != getNumOperands() * getPack()) { + return emitOpError() << "region has wrong number of arguments"; + } + + auto expectedArgTypes = + repeatInterleave(getElementTypesImpl(getOperands()), getPack()); + if (firstBlock->getArgumentTypes() != expectedArgTypes) { + return emitError() << "argument types did not match"; + } + auto expectedReturnTypes = + repeatInterleave(getElementTypesImpl(getResults()), getPack()); + auto walkRes = getRegion().walk([&](Operation *op) -> WalkResult { + auto memEffects = dyn_cast(op); + // Ban stores as we won't get the redundant masking correct by treating it + // as a scalar. + if (memEffects && memEffects.hasEffect()) { + return op->emitOpError() + << "Stores are not supported inside map_elementwise"; + } + if (isa(op) && + op->getOperandTypes() != expectedReturnTypes) { + return op->emitError() + << "region return does not match map_elementwise result"; + } + return WalkResult::advance(); + }); + return success(!walkRes.wasInterrupted()); +} + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + if (!isa(value)) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- UnsplatOp -- +LogicalResult UnsplatOp::verify() { + auto srcShape = getSrc().getType().getShape(); + if (product(srcShape) != 1) { + return emitError("source tensor must have exactly one element"); + } + return success(); +} + +LogicalResult UnsplatOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + auto dstTy = cast(operands[0].getType()).getElementType(); + inferredReturnTypes.push_back(dstTy); + return success(); +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferExpandDimsOpEncoding( + argEncoding, axis, retEncoding, loc))) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + Dialect &dialect = srcEnc.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferExpandDimsOpEncoding( + srcEnc, op.getAxis(), newExpandEnc, op.getLoc()))) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = ExpandDimsOp::create(rewriter, op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = BroadcastOp::create( + rewriter, broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- + +void ReshapeOp::build(OpBuilder &builder, OperationState &state, + ArrayRef shape, Value src, bool allowReorder) { + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + Attribute dstEnc; + if (srcEnc) { + auto result = cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, + dstEnc, state.location); + assert(succeeded(result)); + } + auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc); + build(builder, state, dstTy, src, allowReorder); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (op.getEfficientLayout()) + return failure(); + + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // reshape(reshape) -> reshape + if (auto parentReshape = dyn_cast(definingOp)) { + // Allow reorder if either reshape allowed it + const bool allowReorder = + (op.getAllowReorder() || parentReshape.getAllowReorder()); + rewriter.replaceOpWithNewOp(op, op.getType(), + parentReshape.getSrc(), allowReorder, + op.getEfficientLayout()); + return success(); + } + + // reshape(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType() && !getAllowReorder()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither does."); + } + + if (!srcEnc || getAllowReorder()) { + return success(); + } + + // Check that we can infer the dst encoding from the src encoding + // and that the inferred dst encoding is the same as the given dst encoding + Attribute inferredDstEnc; + auto layoutInterface = + cast(&srcEnc.getDialect()); + auto result = layoutInterface->inferReshapeOpEncoding( + srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc()); + if (failed(result)) + return failure(); + return layoutInterface->verifyLayoutsAreEqual( + dstTy.getShape(), inferredDstEnc, dstEnc, getLoc()); +} + +//-- FpToFpOp -- + +// Fold FpToFpOp when the input operand is a constant zero. +OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) { + auto srcVal = getSrc(); + auto dstTy = getType(); + // Fold trivial cast + if (srcVal.getType() == dstTy) { + return srcVal; + } + + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &semantic = resElemType.getFloatSemantics(); + + if (matchPattern(srcVal, m_PosZeroFloat())) { + llvm::APFloat posZero = + llvm::APFloat::getZero(semantic, /*negative=*/false); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, posZero); + return Builder(getContext()).getFloatAttr(resElemType, posZero); + } + + if (matchPattern(srcVal, m_NegZeroFloat())) { + llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, negZero); + return Builder(getContext()).getFloatAttr(resElemType, negZero); + } + + return {}; +} + +LogicalResult FpToFpOp::verify() { + auto dstType = getType(); + auto srcType = getSrc().getType(); + if (auto dstTensorType = dyn_cast(dstType)) + dstType = dstTensorType.getElementType(); + if (auto srcTensorType = dyn_cast(srcType)) + srcType = srcTensorType.getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BitcastOp -- +LogicalResult BitcastOp::verify() { + // Bitcast only allows conversion between types with the same bit width. + Type dstType = getType(); + Type srcType = getSrc().getType(); + // Strip tensor shapes; SameOperandsAndResultShape guarantees shapes match. + if (auto dstTensorType = dyn_cast(dstType)) + dstType = dstTensorType.getElementType(); + if (auto srcTensorType = dyn_cast(srcType)) + srcType = srcTensorType.getElementType(); + bool dstIsPtr = isa(dstType); + bool srcIsPtr = isa(srcType); + if (dstIsPtr || srcIsPtr) { + // Bitcast supports pointer-to-pointer conversions but not + // pointer-to-scalar. + if (dstIsPtr && srcIsPtr) { + if (triton::getAddressSpace(dstType) != triton::getAddressSpace(srcType)) + return emitError( + "Cannot bitcast pointer between different address spaces"); + return success(); + } + return emitError("Cannot bitcast pointer to non-pointer type"); + } + unsigned dstBits = dstType.getIntOrFloatBitWidth(); + unsigned srcBits = srcType.getIntOrFloatBitWidth(); + if (dstBits != srcBits) { + return emitError("Cannot bitcast data-type of size ") + << srcBits << " to data-type of size " << dstBits; + } + return success(); +} + +//-- BroadcastOp -- +void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +LogicalResult BroadcastOp::verify() { + auto src = getSrc(); + auto srcTensorType = cast(src.getType()); + auto srcShape = srcTensorType.getShape(); + auto result = getResult(); + auto resultTensorType = cast(result.getType()); + auto resultShape = resultTensorType.getShape(); + if (srcShape.size() != resultShape.size()) { + return emitError("rank of source must be same as rank of result"); + } + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { + return emitError("Different dimensions at index ") + << i << " between source and result. " + << "Broadcast requires the source dimension to be 1."; + } + } + return success(); +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, pointerType.getAddressSpace()); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +//-- AddPtrOp -- +OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) { + // addptr(ptr, 0) -> ptr + if (matchPattern(adaptor.getOffset(), m_Zero())) { + return getPtr(); + } + return {}; +} + +//-- AdvanceOp -- +OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { + // advance(ptr, 0, 0) -> ptr + SmallVector rawOffsets = getOffsets(); + auto offsets = getConstantIntValues(rawOffsets); + if (!offsets.has_value()) + return {}; + for (int64_t offset : offsets.value()) + if (offset != 0) + return {}; + return getPtr(); +} + +//-- MakeTensorDescOp -- +void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ArrayRef blockShape, bool isSignedInteger, + triton::PaddingOption padding) { + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = + TensorDescType::get(builder.getContext(), blockTy, isSignedInteger); + auto paddingAttr = PaddingOptionAttr::get(builder.getContext(), padding); + return build(builder, state, descTy, base, shape, strides, paddingAttr); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + call_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/{}, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- + +void JoinOp::build(OpBuilder &builder, OperationState &state, Value lhs, + Value rhs) { + auto lhsTy = cast(lhs.getType()); + SmallVector retShape(lhsTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = lhsTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (failed(cast(&srcEnc.getDialect()) + ->inferDefaultJoinOpEncoding( + srcEnc, retEnc, lhsTy.getShape(), state.location))) { + llvm_unreachable("failed to infer join encoding"); + } + } + auto retTy = RankedTensorType::get(retShape, lhsTy.getElementType(), retEnc); + JoinOp::build(builder, state, retTy, lhs, rhs); +} + +LogicalResult JoinOp::verify() { + RankedTensorType srcTy = getLhs().getType(); + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + RankedTensorType retTy = getType(); + if (SmallVector(retTy.getShape()) != retShape) { + return emitOpError("result shape must be (") + << retShape << "), but got " << retTy.getShape(); + } + if (retTy.getElementType() != srcTy.getElementType()) { + return emitOpError("result element type must match the input element type"); + } + Attribute retEnc = retTy.getEncoding(); + if (!retEnc) { + if (srcTy.getEncoding()) { + return emitOpError("result encoding must be specified"); + } + return success(); + } + // There are multiple correct destination layout for a given source layout but + // there is only one correct source layout for a given destination layout. So + // we verify that the source layout match the destination layout. + Attribute srcEnc; + Location location = getLoc(); + if (cast(&retEnc.getDialect()) + ->inferSplitOpEncoding(retEnc, srcEnc, retShape, location) + .failed()) { + return failure(); + } + + if (cast(&srcEnc.getDialect()) + ->verifyLayoutsAreEqual(srcTy.getShape(), srcEnc, srcTy.getEncoding(), + {}) + .failed()) { + return emitOpError("incompatible join layout"); + } + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, + SplitOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + auto srcTy = cast(adaptor.getSrc().getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, srcTy.getShape(), location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); +} + +Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); +} + +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (uint32_t dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. + inferredReturnTypes.push_back(indicesType.clone(srcType.getElementType())); + return success(); +} + +// -- DescriptorGatherOp +static LogicalResult verifyGatherScatterResultType(Operation *op, + ShapedType resultType, + ShapedType indicesType) { + if (indicesType.getRank() != 1) + return op->emitOpError("x offsets must be a 1D tensor, but got ") + << indicesType; + if (resultType.getRank() != 2) + return op->emitOpError("result must be a 2D tensor, but got ") + << resultType; + + // The swizzling of TMA accesses matches that of the MMAv3 shared memory + // layouts. However, these have minimum size requirements. + // TODO: We can support smaller gather sizes by padding the `local_alloc` this + // lowers to to the nearest minimum tile size. + if (unsigned rows = resultType.getShape()[0]; rows < 8) { + return op->emitOpError("must have at least 8 rows, but got ") << rows; + } + + Type dtype = resultType.getElementType(); + if (dtype.getIntOrFloatBitWidth() > 32) + return op->emitOpError("TMA dtype cannot be greater than 32 bits"); + + unsigned minCols = 32 / dtype.getIntOrFloatBitWidth() * 8; + if (unsigned cols = resultType.getShape()[1]; cols < minCols) { + return op->emitOpError("must have at least ") + << minCols << " columns for " << dtype << ", but got " << cols; + } + + if (resultType.getShape()[0] != indicesType.getShape()[0]) { + return op->emitOpError("result tensor must have as many rows as indices (") + << indicesType.getShape()[0] << "), but got " << resultType; + } + + return success(); +} + +LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType, + ShapedType resultType, + ShapedType indicesType) { + // Gather from `!tt.tensordesc>`. + if (blockType.getRank() != 2) { + return op->emitOpError("descriptor block must be a 2D tensor, but got ") + << blockType; + } + if (blockType.getShape()[0] != 1) { + return op->emitOpError("descriptor block must have exactly 1 row, but got ") + << blockType; + } + + // With x offsets `tensor` into `tensor`. + if (failed(verifyGatherScatterResultType(op, resultType, indicesType))) + return failure(); + + if (resultType.getShape()[1] != blockType.getShape()[1]) { + return op->emitOpError("result tensor number of columns must match block (") + << blockType.getShape()[1] << "), but got " << resultType; + } + if (resultType.getElementType() != blockType.getElementType()) { + return op->emitOpError("result tensor element type must match block (") + << blockType.getElementType() << "), but got " << resultType; + } + + return success(); +} + +LogicalResult DescriptorGatherOp::verify() { + return verifyGatherScatterOp(*this, + getDesc().getType().getSignlessBlockType(), + getResult().getType(), getXOffsets().getType()); +} + +// -- DescriptorScatterOp -- +LogicalResult DescriptorScatterOp::verify() { + return verifyGatherScatterOp(*this, + getDesc().getType().getSignlessBlockType(), + getSrc().getType(), getXOffsets().getType()); +} + +// -- DescriptorLoadOp -- +LogicalResult verifyDescriptorLoadStoreOp(Operation *op, + TensorDescInterface desc, + ShapedType tensor) { + RankedTensorType block = desc.getSignlessBlockType(); + if (block.getElementType() != tensor.getElementType()) { + return op->emitOpError("descriptor block and tensor element types must " + "match, but got descriptor element type ") + << block.getElementType() << " and tensor element type " + << tensor.getElementType(); + } + + ArrayRef blockShape = block.getShape(); + ArrayRef tensorShape = tensor.getShape(); + unsigned blockNumels = product(blockShape); + unsigned tensorNumels = product(tensorShape); + if (blockNumels != tensorNumels) { + return op->emitOpError("descriptor block and tensor must have the same " + "number of elements, but got descriptor block " + "with ") + << blockNumels << " elements tensor with " << tensorNumels + << " elements"; + } + return success(); +} + +LogicalResult DescriptorLoadOp::verify() { + return verifyDescriptorLoadStoreOp(*this, getDesc().getType(), getType()); +} + +// -- DescriptorStoreOp -- +LogicalResult DescriptorStoreOp::verify() { + return verifyDescriptorLoadStoreOp(*this, getDesc().getType(), + getSrc().getType()); +} + +// -- DescriptorReduceOp -- +LogicalResult DescriptorReduceOp::verify() { + return verifyDescriptorLoadStoreOp(*this, getDesc().getType(), + getSrc().getType()); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Traits.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 0000000000..857ff4aad7 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,265 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) { + auto memdescA = dyn_cast(typeA); + auto memdescB = dyn_cast(typeB); + if (memdescA || memdescB) { + if (!memdescA || !memdescB) + return failure(); + if (memdescA.getShape() != memdescB.getShape()) + return failure(); + if (memdescA.getAllocShape() != memdescB.getAllocShape()) + return failure(); + if (memdescA.getElementType() != memdescB.getElementType()) + return failure(); + if (memdescA.getMemorySpace() != memdescB.getMemorySpace()) + return failure(); + if (memdescA.getMutableMemory() != memdescB.getMutableMemory()) + return failure(); + + Attribute encodingA = memdescA.getEncoding(); + Attribute encodingB = memdescB.getEncoding(); + if (encodingA == encodingB) + return success(); + if (static_cast(encodingA) != static_cast(encodingB)) + return failure(); + + auto layoutInterface = + cast(&encodingA.getDialect()); + return layoutInterface->verifyLayoutsAreEqual(memdescA.getShape(), + encodingA, encodingB, {}); + } + auto tensorTypeA = dyn_cast(typeA); + auto tensorTypeB = dyn_cast(typeB); + if (!(bool(tensorTypeA) && bool(tensorTypeB))) + return typeA == typeB ? success() : failure(); + auto encodingA = tensorTypeA.getEncoding(); + auto encodingB = tensorTypeB.getEncoding(); + auto shapeA = tensorTypeA.getShape(); + auto shapeB = tensorTypeB.getShape(); + if (shapeA != shapeB) + return failure(); + if (tensorTypeA.getElementType() != tensorTypeB.getElementType()) + return failure(); + // If there's no encoding or the encodings are the same + if (encodingA == encodingB) + return success(); + if (bool(encodingA) != bool(encodingB)) + return failure(); + + return cast(&encodingA.getDialect()) + ->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {}); +} + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (rankedTy) { + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + Dialect &dialect = layout.getDialect(); + auto verifyLayoutInterface = + dyn_cast(&dialect); + if (verifyLayoutInterface) { + return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op, + makeErr); + } + return success(); + } + + auto memDescTy = dyn_cast(val.getType()); + if (!memDescTy) + return success(); + + mlir::Attribute layout = memDescTy.getEncoding(); + if (!layout) + return success(); + + Dialect &dialect = layout.getDialect(); + auto verifyLayoutInterface = + dyn_cast(&dialect); + if (verifyLayoutInterface) { + return verifyLayoutInterface->verifyMemDescLayout(layout, memDescTy, op, + makeErr); + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! + operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Types.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 0000000000..179fc5ae9b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,140 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#include "triton/Dialect/Triton/IR/TypeInterfaces.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return tensorTy.clone(i1Type); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return tensorTy.clone(pointeeType); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return tensorTy.clone(i32Type); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + PointerType ptrType = PointerType::get(elementType, 1); + return tensorTy.clone(ptrType); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + +// upstream Triton only uses address space 1 for Pointer Type +Type getPointerType(Type type, int addressSpace) { + return PointerType::get(type, addressSpace); +} + +int getAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getAddressSpace(); + return 1; +} + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Utility.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Utility.cpp new file mode 100644 index 0000000000..5e07d5fb81 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Utility.cpp @@ -0,0 +1,204 @@ +#include "triton/Dialect/Triton/IR/Utility.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; + +Value tt::getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask, + Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) { + mask = tt::SplatOp::create(rewriter, loc, maskType, pred); + } + if (currentMask) { + mask = arith::AndIOp::create(rewriter, loc, mask, currentMask); + } + return mask; +} + +static tt::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return tt::getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return tt::getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +tt::MakeTensorPtrOp tt::getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) + return getMakeTensorPtrOpImpl(definingOp, v); + + // If there is no defining op, v must be a BlockArgument. + BlockArgument arg = cast(v); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) + return tt::getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) + return tt::getMakeTensorPtrOp(br.getDestOperands()[argNum]); + if (auto condBr = dyn_cast(op)) + return tt::getMakeTensorPtrOp( + tOrF ? condBr.getTrueDestOperands()[argNum] + : condBr.getFalseDestOperands()[argNum]); + return tt::getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +Value tt::getLastInductionValue(OpBuilder &b, scf::ForOp loop) { + Location loc = loop.getLoc(); + // (ub - lb -1) // step * step + lb + Value diff = + arith::SubIOp::create(b, loc, loop.getUpperBound(), loop.getLowerBound()); + diff = arith::SubIOp::create( + b, loc, diff, + arith::ConstantOp::create(b, loc, b.getIntegerAttr(diff.getType(), 1))); + Value ceilStep = arith::MulIOp::create( + b, loc, arith::DivSIOp::create(b, loc, diff, loop.getStep()), + loop.getStep()); + return arith::AddIOp::create(b, loc, ceilStep, loop.getLowerBound()); +} + +bool tt::isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +bool tt::isHostSideDescriptor(Value v) { + auto arg = dyn_cast(v); + if (!arg) + return false; + auto funcOp = dyn_cast(arg.getOwner()->getParentOp()); + if (!funcOp) + return false; + return tt::isKernel(funcOp); +} + +unsigned tt::getBitwidth(RankedTensorType ty) { + auto isPtr = isa(ty.getElementType()); + return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u); +} + +std::optional tt::getBoundFromCmpOp(arith::CmpIOp cmpOp, + Value anchor) { + bool isSigned = true; + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + isSigned = false; + default: + break; + } + + bool anchorIsLhs = cmpOp.getLhs() == anchor; + auto maybeConstantIntValue = getConstantIntValue( + getAsOpFoldResult(anchorIsLhs ? cmpOp.getRhs() : cmpOp.getLhs())); + if (auto constValue = maybeConstantIntValue) { + unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType()); + assert(bitWidth > 0 && "expected non-zero bitwdith"); + APInt apVal = {bitWidth, static_cast(*constValue), isSigned}; + APInt min, max; + if (isSigned) { + min = APInt::getSignedMinValue(bitWidth); + if (llvm::isa_and_nonnull( + anchor.getDefiningOp())) { + min = APInt::getZero(bitWidth); + } else + min = APInt::getSignedMinValue(bitWidth); + max = APInt::getSignedMaxValue(bitWidth); + } else { + min = APInt::getMinValue(bitWidth); + max = APInt::getMaxValue(bitWidth); + } + + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::eq: + return mlir::ConstantIntRanges::constant(apVal); + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::sge: { + // K >= apVal implies K ∈ [apVal, max] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(apVal, max, isSigned); + // apVal >= K implies K ∈ [min, apVal] + return mlir::ConstantIntRanges::range(min, apVal, isSigned); + } + case arith::CmpIPredicate::ugt: + case arith::CmpIPredicate::sgt: { + // K > apVal implies K >= apVal + 1 implies K ∈ [apVal + 1, max] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(apVal + 1, max, isSigned); + // apVal > K implies apVal - 1 >= K implies K ∈ [min, apVal - 1] + return mlir::ConstantIntRanges::range(min, apVal - 1, isSigned); + } + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::sle: { + // K <= apVal implies K ∈ [min, apVal] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(min, apVal, isSigned); + // apVal <= K implies K ∈ [apVal, max] + return mlir::ConstantIntRanges::range(apVal, max, isSigned); + } + case arith::CmpIPredicate::ult: + case arith::CmpIPredicate::slt: { + // K < apVal implies K <= apVal -1 implies K ∈ [min, apVal - 1] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(min, apVal - 1, isSigned); + // apVal < K implies apVal + 1 <= K implies K ∈ [apVal + 1, max] + return mlir::ConstantIntRanges::range(apVal + 1, max, isSigned); + } + default: + emitRemark(cmpOp.getLoc(), "unsupported cmp predicate for assumption"); + return {}; + } + } + return {}; +} diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp new file mode 100644 index 0000000000..3928119409 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp @@ -0,0 +1,51 @@ +#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace { + +struct RewriteArithSelectOp : mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // Note we're replacing the select op with an if op because we are + // converting one value into many values. + auto newIf = mlir::scf::IfOp::create( + rewriter, op.getLoc(), mlir::TypeRange(adaptor.getTrueValue()), + op.getCondition(), true); + // We set the attributes from the op in case the op has any additional + // attributes + newIf->setAttrs(op->getAttrs()); + + { + mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(newIf.thenBlock()); + mlir::scf::YieldOp::create(rewriter, op->getLoc(), + adaptor.getTrueValue()); + rewriter.setInsertionPointToStart(newIf.elseBlock()); + mlir::scf::YieldOp::create(rewriter, op->getLoc(), + adaptor.getFalseValue()); + } + + // Replace the old operation results + rewriter.replaceOpWithMultiple(op, {newIf->getResults()}); + + return mlir::success(); + } +}; + +} // namespace +namespace mlir::triton { + +void populateArithTypeConversions(const TypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add(converter, patterns.getContext()); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..8be846f589 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,27 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + +add_triton_library(TritonTransforms + Combine.cpp + LoopAwareCSE.cpp + LoopInvariantCodeMotion.cpp + LoopPeeling.cpp + LoopUnroll.cpp + ReorderBroadcast.cpp + RewriteTensorPointer.cpp + RewriteTensorDescriptorToPointer.cpp + ArithTypeConversion.cpp + FunctionTypeConversion.cpp + + DEPENDS + TritonTransformsIncGen + TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + MLIRTransforms + MLIRSCFToControlFlow + TritonIR +) diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 0000000000..1e9830c066 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,298 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/DiscardableAttributes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONCOMBINEOPS +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +bool isZero(Value val) { + return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())); +} + +bool isAddPtrOffsetCombinable(Value first, Value second) { + auto GetConstantIntValue = [](Value val) -> std::optional { + DenseElementsAttr constAttr; + auto defOp = val.getDefiningOp(); + if (defOp) { + if (auto splatOp = llvm::dyn_cast(defOp)) + val = splatOp.getSrc(); + else if (matchPattern(defOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto attr = constAttr.getSplatValue(); + // Check IntegerAttr + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + } + + // Check constant value. + llvm::APInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal; + + return std::nullopt; + }; + + if (first.getType() == second.getType()) { + // Whether bitwidth of element type is equal to pointer + if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64) + return true; + + // first + second does not overflow + auto firstVal = GetConstantIntValue(first); + auto secondVal = GetConstantIntValue(second); + if (firstVal && secondVal) { + bool overflow = false; + auto resVal = firstVal->sadd_ov(*secondVal, overflow); + return !overflow; + } + } + return false; +} + +// TODO(csigg): remove after next LLVM integrate. +using FastMathFlags = arith::FastMathFlags; + +#include "TritonCombine.inc" + +// select(cond, load(ptrs, splat(cond), ???), other) +// => load(ptrs, splat(cond), other) +class CombineSelectMaskedLoadPattern : public RewritePattern { +public: + CombineSelectMaskedLoadPattern(MLIRContext *context) + : RewritePattern(arith::SelectOp::getOperationName(), 3, context, + {LoadOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return failure(); + + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + Value condSelect = selectOp.getCondition(); + + auto loadOp = trueValue.getDefiningOp(); + if (!loadOp) + return failure(); + + Value mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto splatOp = mask.getDefiningOp(); + if (!splatOp) + return failure(); + + auto splatCond = splatOp.getSrc(); + if (splatCond != condSelect) + return failure(); + + rewriter.replaceOpWithNewOp( + op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue, + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + return success(); + } +}; + +// sum(x[:, :, None] * y[None, :, :], 1) +// -> dot(x, y) +class CombineBroadcastMulReducePattern : public RewritePattern { +private: + static bool isAddF32(const Operation *op) { + if (auto addf = dyn_cast_or_null(op)) + return addf.getType().getIntOrFloatBitWidth() <= 32; + return false; + } + +public: + CombineBroadcastMulReducePattern(MLIRContext *context) + : RewritePattern(ReduceOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto reduceOp = llvm::dyn_cast(op); + if (!reduceOp) + return failure(); + // only support reduce with simple addition + Region &combineOp = reduceOp.getCombineOp(); + bool isReduceAdd = combineOp.hasOneBlock() && + combineOp.front().getOperations().size() == 2 && + isAddF32(&*combineOp.front().getOperations().begin()); + if (!isReduceAdd) + return failure(); + // operand of reduce has to be mul + auto mulOp = reduceOp.getOperand(0).getDefiningOp(); + if (!mulOp) + return failure(); + // mul operand has to be broadcast + auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp(); + if (!broadcastLhsOp) + return failure(); + auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp(); + if (!broadcastRhsOp) + return failure(); + // broadcast operand is expand dims + auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp(); + if (!expandLhsOp) + return failure(); + auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp(); + if (!expandRhsOp) + return failure(); + // get not-broadcast dimensions + int expandLhsAxis = expandLhsOp.getAxis(); + int expandRhsAxis = expandRhsOp.getAxis(); + if (expandLhsAxis != 2 || expandRhsAxis != 0) + return failure(); + auto broadcastLhsShape = + cast(broadcastLhsOp.getType()).getShape(); + auto broadcastRhsShape = + cast(broadcastLhsOp.getType()).getShape(); + if (broadcastLhsShape[2] < 16 || broadcastRhsShape[0] < 16) + return failure(); + Type newAccType = RankedTensorType::get( + {broadcastLhsShape[0], broadcastRhsShape[2]}, + cast(broadcastLhsOp.getSrc().getType()).getElementType()); + rewriter.setInsertionPoint(op); + auto newAcc = + SplatOp::create(rewriter, op->getLoc(), newAccType, + arith::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getF32FloatAttr(0))); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), + expandRhsOp.getSrc(), newAcc, + InputPrecision::TF32, 0); + return success(); + } +}; + +// When reducing a 1D tensor the order of elements of the tensor doesn't matter. +// Therefore we can relax the reshape to allow it to re-order elements. +class CombineReshapeReducePatterns : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp reshapeOp, + mlir::PatternRewriter &rewriter) const override { + if (reshapeOp.getAllowReorder()) + return failure(); + if (reshapeOp.getType().getRank() != 1) + return failure(); + for (Operation *user : reshapeOp->getUsers()) { + if (!isa(user)) + return failure(); + } + rewriter.modifyOpInPlace(reshapeOp, + [&]() { reshapeOp.setAllowReorder(true); }); + return success(); + } +}; + +class RankedReduceDescriptorLoads : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp reshapeOp, + mlir::PatternRewriter &rewriter) const override { + auto loadDef = reshapeOp.getSrc().getDefiningOp(); + if (!loadDef || !loadDef->hasOneUse()) + return failure(); + int loadRank = loadDef.getType().getRank(); + int reshapeRank = reshapeOp.getType().getRank(); + if (!(reshapeRank < loadRank)) + return failure(); + ArrayRef loadShape = loadDef.getType().getShape(); + ArrayRef reshapeShape = reshapeOp.getType().getShape(); + for (int i = 0; i < loadRank - reshapeRank; ++i) { + // Only rank reduce unit dims. + if (loadShape[i] != 1) + return failure(); + } + if (loadShape.take_back(reshapeRank) != reshapeShape) + return failure(); + rewriter.modifyOpInPlace( + loadDef, [&]() { loadDef.getResult().setType(reshapeOp.getType()); }); + rewriter.replaceOp(reshapeOp, loadDef.getResult()); + return success(); + } +}; + +template +class CombineDotAddPattern : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(OpTy addOp, mlir::PatternRewriter &rewriter) const override { + auto dotOp = addOp.getRhs().template getDefiningOp(); + bool isDotLHS = false; + if (!dotOp) { + dotOp = addOp.getLhs().template getDefiningOp(); + if (!dotOp) { + return failure(); + } + isDotLHS = true; + } + if (!dotOp->hasOneUse()) { + return failure(); + } + if (!isZero(dotOp.getC())) + return failure(); + if constexpr (std::is_same_v) { + if (dotOp.getMaxNumImpreciseAcc() != 0) { + return failure(); + } + } + rewriter.modifyOpInPlace(dotOp, [&] { + dotOp.getCMutable().assign(isDotLHS ? addOp.getRhs() : addOp.getLhs()); + dotOp->moveBefore(addOp); + }); + rewriter.replaceAllUsesWith(addOp, dotOp.getResult()); + return success(); + } +}; + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +using CombineDotAddIPattern = CombineDotAddPattern; +using CombineDotAddFPattern = CombineDotAddPattern; + +} // anonymous namespace + +class CombineOpsPass : public impl::TritonCombineOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.td b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 0000000000..d8302f0ac1 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,23 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithDialect +// (ref: ArithCanonicalization.td) +defvar DefOverflow = ConstantEnumCase; + +def CopyDiscardableAttrs: NativeCodeCallVoid< + "$1.getOwner()->setDiscardableAttrs(triton::filterDiscardableAttrs($0.getOwner(), " + "{\"tt.divisibility\", \"tt.contiguity\", \"tt.constancy\", \"tt.pointee_type\"}))">; + +def CombineAddPtrPattern : Pat< + (TT_AddPtrOp:$src (TT_AddPtrOp $ptr, $idx0), $idx1), + (TT_AddPtrOp:$dest $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), + [(Constraint> $idx0, $idx1)], + [(CopyDiscardableAttrs $src, $dest)]>; + +#endif diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp new file mode 100644 index 0000000000..f3a454abea --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp @@ -0,0 +1,163 @@ +#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h" + +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include + +namespace mlir::triton { + +namespace { + +SmallVector flattenValues(ArrayRef values) { + SmallVector ret; + for (const auto &vs : values) { + llvm::append_range(ret, vs); + } + return ret; +} + +struct CallOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector resultReplacementGrouping; + llvm::SmallVector convertedResults; + + for (auto type : callOp->getResultTypes()) { + const auto oldNumFlattenedResults = convertedResults.size(); + if (failed(getTypeConverter()->convertTypes(type, convertedResults))) { + return failure(); + } + resultReplacementGrouping.push_back(convertedResults.size() - + oldNumFlattenedResults); + } + + auto newCallOp = + CallOp::create(rewriter, callOp->getLoc(), callOp.getCallee(), + convertedResults, flattenValues(adaptor.getOperands())); + // Preserve any additional attributes that may have been set on the op + newCallOp->setAttrs(callOp->getAttrs()); + + SmallVector replacements; + std::size_t offset = 0; + for (auto groupSize : resultReplacementGrouping) { + replacements.push_back(newCallOp->getResults().slice(offset, groupSize)); + offset += groupSize; + } + + rewriter.replaceOpWithMultiple(callOp, replacements); + return success(); + } +}; + +struct ReturnOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp returnOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReturnOp = ReturnOp::create(rewriter, returnOp->getLoc(), + flattenValues(adaptor.getOperands())); + // Preserve any additional attributes that may have been set on the op + newReturnOp->setAttrs(returnOp->getAttrs()); + + rewriter.replaceOp(returnOp, newReturnOp); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// FunctionOpInterfaceSignatureConversion +//===----------------------------------------------------------------------===// +// NOTE: Forked from mlir to support remapping argument attributes correctly in +// a one-to-many type conversion. + +SmallVector +convertFuncOpAttrs(FunctionOpInterface funcOp, + TypeConverter::SignatureConversion &sigConv, + FunctionType newType) { + if (newType.getNumInputs() == funcOp.getNumArguments()) { + return {}; + } + ArrayAttr allArgAttrs = funcOp.getAllArgAttrs(); + if (!allArgAttrs) + return {}; + + SmallVector newAttrs(newType.getNumInputs()); + for (auto i : llvm::seq(allArgAttrs.size())) { + auto mapping = sigConv.getInputMapping(i); + assert(mapping.has_value()); + auto outIdx = mapping->inputNo; + newAttrs[outIdx] = allArgAttrs[i]; + } + return newAttrs; +} + +LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, + const TypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { + FunctionType type = dyn_cast(funcOp.getFunctionType()); + if (!type) + return failure(); + + // Convert the original function types. + TypeConverter::SignatureConversion result(type.getNumInputs()); + SmallVector newResults; + if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || + failed(typeConverter.convertTypes(type.getResults(), newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + typeConverter, &result))) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + result.getConvertedTypes(), newResults); + + auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType); + + rewriter.modifyOpInPlace(funcOp, [&] { + funcOp.setType(newType); + if (!newArgAttrs.empty()) { + funcOp.setAllArgAttrs(newArgAttrs); + } + }); + + return success(); +} + +/// Create a default conversion pattern that rewrites the type signature of a +/// FunctionOpInterface op. This only supports ops which use FunctionType to +/// represent their type. +struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { + FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, + MLIRContext *ctx, + const TypeConverter &converter, + PatternBenefit benefit = 1) + : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + FunctionOpInterface funcOp = cast(op); + return convertFuncOpTypes(funcOp, *typeConverter, rewriter); + } +}; + +} // namespace + +void populateFunctionTypeConversions(const TypeConverter &converter, + RewritePatternSet &patterns) { + auto context = patterns.getContext(); + patterns.add( + triton::FuncOp::getOperationName(), context, converter); + patterns.add(converter, context); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp new file mode 100644 index 0000000000..ad9ca7f396 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp @@ -0,0 +1,178 @@ +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/EquivalenceClasses.h" + +using namespace mlir; + +namespace mlir::triton { +#define GEN_PASS_DEF_TRITONLOOPAWARECSE +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" +} // namespace mlir::triton + +namespace { +class ValueEquivalence { +public: + std::optional getKnownEquivalence(Value a, Value b) { + if (auto it = equalValues.find(normalizeKey(a, b)); it != equalValues.end()) + return it->second; + return std::nullopt; + } + void setKnownEquivalence(Value a, Value b, bool eq) { + equalValues.insert_or_assign(normalizeKey(a, b), eq); + } + +private: + // Commutatively query the equivalence of two values by sorting the key by + // pointer value. + std::pair normalizeKey(Value a, Value b) { + if ((uintptr_t)a.getAsOpaquePointer() < (uintptr_t)b.getAsOpaquePointer()) + return {a, b}; + return {b, a}; + } + + DenseMap, bool> equalValues; +}; + +struct LoopCSEDriver { + LoopCSEDriver(scf::ForOp loop) : loop(loop) {} + + bool areIterArgsEqual(int i, int j); + bool areEqualInLoop(Value a, Value b); + + scf::ForOp loop; + SmallVector> argStack; +}; +} // namespace + +bool LoopCSEDriver::areIterArgsEqual(int i, int j) { + if (i == j) + return true; + if (loop.getInitArgs()[i] != loop.getInitArgs()[j]) + return false; + if (llvm::is_contained(argStack, std::make_pair(i, j))) + return true; + + // First, assume the arguments are equal. This is how recursion is broken. + argStack.push_back({i, j}); + bool result = + areEqualInLoop(loop.getYieldedValues()[i], loop.getYieldedValues()[j]); + argStack.pop_back(); + return result; +} + +bool LoopCSEDriver::areEqualInLoop(Value a, Value b) { + // Check trivial case. + if (a == b) + return true; + if (a.getType() != b.getType()) + return false; + + Block *aBlock = a.getParentBlock(); + Block *bBlock = b.getParentBlock(); + // Values from outside the loop must have been equal. + if (aBlock != loop.getBody() || bBlock != loop.getBody()) { + return false; + } + // Both must be block arguments or not. + if (isa(a) != isa(b)) + return false; + // Both must be the inductor var or not. + if (a == loop.getInductionVar() || b == loop.getInductionVar()) + return false; + + if (auto aArg = dyn_cast(a)) { + auto bArg = cast(b); + bool result = + areIterArgsEqual(aArg.getArgNumber() - 1, bArg.getArgNumber() - 1); + return result; + } + + Operation *aDef = a.getDefiningOp(); + Operation *bDef = b.getDefiningOp(); + if (cast(a).getResultNumber() != + cast(b).getResultNumber()) + return false; + // For it to be known that the operation results have the same value, they + // must be side effect free. + if (!isMemoryEffectFree(aDef) || !isMemoryEffectFree(bDef)) + return false; + // Don't bother with operations with regions. + if (aDef->getNumRegions() || bDef->getNumRegions()) + return false; + + bool result = OperationEquivalence::isEquivalentTo( + aDef, bDef, + [&](Value a, Value b) { return success(areEqualInLoop(a, b)); }, + /*markEquivalent=*/nullptr, OperationEquivalence::IgnoreLocations); + return result; +} + +static void loopCSE(scf::ForOp loop) { + int numIterArgs = loop.getNumRegionIterArgs(); + // Group equivalent iter args together. + llvm::EquivalenceClasses equivalentArgs; + LoopCSEDriver driver(loop); + for (int i = 0; i != numIterArgs; ++i) { + for (int j = i + 1; j != numIterArgs; ++j) { + if (driver.areIterArgsEqual(i, j)) + equivalentArgs.unionSets(i, j); + } + } + + // For each equivalence class, replace all other args in the class with one. + for (auto it = equivalentArgs.begin(), end = equivalentArgs.end(); it != end; + ++it) { + if (!(*it)->isLeader()) + continue; + SmallVector eqArgs; + for (auto mIt = equivalentArgs.member_begin(**it); + mIt != equivalentArgs.member_end(); ++mIt) + eqArgs.push_back(*mIt); + assert(eqArgs.size() > 1); + // Sort the indices so the pass is deterministic. + llvm::sort(eqArgs); + BlockArgument unique = loop.getRegionIterArg(eqArgs.front()); + Value uniqueResult = loop.getResult(eqArgs.front()); + for (int j : llvm::drop_begin(eqArgs)) { + BlockArgument other = loop.getRegionIterArg(j); + other.replaceAllUsesWith(unique); + // Short-circuit the value. The canonicalizer will clean this up. Leftover + // subcomputations can now be removed by normal CSE. + (*loop.getYieldedValuesMutable())[j].set(other); + loop.getResult(j).replaceAllUsesWith(uniqueResult); + } + } +} + +namespace { +struct LoopAwareCSE + : public triton::impl::TritonLoopAwareCSEBase { + using TritonLoopAwareCSEBase::TritonLoopAwareCSEBase; + + void runOnOperation() override { + // LoopAwareCSE doesn't recursively CSE ops outside of loops, so run CSE + // first to make sure values from outside loops that are equivalent are made + // pointer equal. + IRRewriter rewriter(&getContext()); + auto &domInfo = getAnalysis(); + eliminateCommonSubExpressions(rewriter, domInfo, getOperation()); + + // CSE region iter args within loop bodies. + getOperation().walk(loopCSE); + + // Now that equivalent iter args have been made pointer equal, run CSE again + // to clean up the loop body. + eliminateCommonSubExpressions(rewriter, domInfo, getOperation()); + + // Run the `scf.for` canonicalizer to clean up the loops (short-circuited + // values, unused results, etc.). + RewritePatternSet patterns(&getContext()); + scf::ForOp::getCanonicalizationPatterns(patterns, &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp new file mode 100644 index 0000000000..a1de3bf845 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp @@ -0,0 +1,82 @@ +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONLOOPINVARIANTCODEMOTION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "triton-licm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +class LoopInvariantCodeMotionPass + : public impl::TritonLoopInvariantCodeMotionBase< + LoopInvariantCodeMotionPass> { + + DenseMap isLoopMemoryEffectFreeOrOnlyRead; + + bool isMemoryEffectFreeOrOnlyRead(Operation *op) { + std::optional> effects = + getEffectsRecursively(op); + if (!effects) + return false; + return llvm::all_of(*effects, + [&](const MemoryEffects::EffectInstance &effect) { + return isa(effect.getEffect()); + }); + } + + void runOnOperation() override { + // Walk through all loops in a function in innermost-loop-first order. + // This way, we first LICM from the inner loop, and place the ops in the + // outer loop, which in turn can be further LICM'ed. + getOperation()->walk([&](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode( + loopLike.getLoopRegions(), + // isDefinedOutsideOfRegion + [&](Value value, Region *region) { + return loopLike.isDefinedOutsideOfLoop(value); + }, + // shouldMoveOutOfRegion + [&](Operation *op, Region *region) { + if (!isa(op)) + return isSpeculatable(op) && isMemoryEffectFree(op); + if (!isLoopMemoryEffectFreeOrOnlyRead.contains(loopLike)) + isLoopMemoryEffectFreeOrOnlyRead[loopLike] = + isMemoryEffectFreeOrOnlyRead(loopLike); + return isMemoryEffectFreeOrOnlyRead(op) && + isLoopMemoryEffectFreeOrOnlyRead[loopLike]; + }, + // moveOutOfRegion + [&](Operation *op, Region *) { + // Create the new mask for load op. + if (auto loadOp = dyn_cast(op)) { + IRRewriter rewriter(loopLike); + Location loc = loopLike->getLoc(); + Value cond; + if (auto forOp = dyn_cast(loopLike.getOperation())) { + cond = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, + forOp.getLowerBound(), forOp.getUpperBound()); + } else if (auto whileOp = + dyn_cast(loopLike.getOperation())) { + // TODO: Support Load Op hoisting for while loop. + return; + } else { + return; + } + Value newMask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), cond); + loadOp.getMaskMutable().assign(newMask); + } + loopLike.moveOutOfLoop(op); + }); + }); + } +}; + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopPeeling.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopPeeling.cpp new file mode 100644 index 0000000000..ed887bfee0 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopPeeling.cpp @@ -0,0 +1,67 @@ +#include "triton/Dialect/Triton/Transforms/LoopPeeling.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +using namespace mlir; + +namespace mlir { +namespace triton { + +void peelLoopEpilogue( + scf::ForOp forOp, + function_ref + processPeeledOp) { + SmallVector loopBodyOps; + IRRewriter rewriter(forOp); + Location loc = forOp.getLoc(); + Type type = forOp.getStep().getType(); + + // Fetch loop bounds and step + Value lowerBound = forOp.getLowerBound(); + Value upperBound = forOp.getUpperBound(); + Value step = forOp.getStep(); + Value newUpperBound = arith::SubIOp::create(rewriter, loc, upperBound, step); + + rewriter.setInsertionPointAfter(forOp); + Value lastIV = getLastInductionValue(rewriter, forOp); + + auto cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + lowerBound, upperBound); + + // Create an if op to execute the peeled iteration + IRMapping map; + map.map(forOp.getRegionIterArgs(), forOp.getResults()); + map.map(forOp.getInductionVar(), lastIV); + auto ifOp = scf::IfOp::create(rewriter, loc, forOp.getResultTypes(), cond); + forOp.getBodyRegion().cloneInto(&ifOp.getThenRegion(), map); + auto newElseBlock = rewriter.createBlock(&ifOp.getElseRegion()); + rewriter.setInsertionPointToStart(newElseBlock); + scf::YieldOp::create(rewriter, loc, forOp.getResults()); + + forOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) { + return !ifOp->isAncestor(operand.getOwner()); + }); + + forOp.getUpperBoundMutable().assign(newUpperBound); + + if (processPeeledOp) { + for (auto &op : + llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { + Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/false); + if (newOp && newOp != &op) { + op.replaceAllUsesWith(newOp); + } + } + for (auto &op : llvm::make_early_inc_range( + ifOp.getThenRegion().front().without_terminator())) { + Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/true); + if (newOp && newOp != &op) { + op.replaceAllUsesWith(newOp); + } + } + } +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopUnroll.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopUnroll.cpp new file mode 100644 index 0000000000..294dff873e --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/LoopUnroll.cpp @@ -0,0 +1,62 @@ +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONLOOPUNROLL +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "triton-loop-unroll" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +class LoopUnrollPass : public impl::TritonLoopUnrollBase { + + int getUnrollFactorOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise set the + // factor to 1 to suppress the unrolling. + if (auto factor = + forOp->getAttrOfType(loopUnrollFactorAttrName)) + return factor.getInt(); + return 1; + } + + const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; + const char *pipelineStagesAttrName = "tt.num_stages"; + +public: + void runOnOperation() override { + LDBG("Loop unroll pass"); + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with unroll factor <= 1. + if (getUnrollFactorOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + auto ctx = getOperation()->getContext(); + for (auto loop : loops) { + auto unrollFactor = getUnrollFactorOrDefault(loop); + loop->removeAttr(loopUnrollFactorAttrName); + LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop); + auto resultLoops = loopUnrollByFactor(loop, unrollFactor); + // Do not pipeline the epilog loop. + if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) { + (*resultLoops->epilogueLoopOp) + ->setAttr(pipelineStagesAttrName, + mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1)); + } + } + } +}; + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp new file mode 100644 index 0000000000..bdb8e527f9 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -0,0 +1,230 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONREORDERBROADCAST +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter, + Operation *op, ValueRange newOperands, + TypeRange newTypes) { + OperationState newElementwiseState(op->getLoc(), op->getName()); + newElementwiseState.addOperands(newOperands); + newElementwiseState.addTypes(newTypes); + newElementwiseState.addAttributes(op->getAttrs()); + return rewriter.create(newElementwiseState); +} + +bool isSplat(Operation *op) { + if (auto splatOp = llvm::dyn_cast(op)) { + return true; + } + DenseElementsAttr constAttr; + return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat()); +} + +// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) +struct MoveSplatAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveSplatAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + if (!isSplat(definingOp)) { + return failure(); + } + } + + if (op->getNumOperands() <= 0) + return failure(); + + auto loc = op->getLoc(); + auto operands = op->getOperands(); + + llvm::SmallVector scalarOperands(operands.size()); + for (unsigned iOp = 0; iOp < operands.size(); ++iOp) { + auto definingOp = operands[iOp].getDefiningOp(); + + DenseElementsAttr constAttr; + if (auto splatOp = llvm::dyn_cast(definingOp)) { + scalarOperands[iOp] = splatOp.getSrc(); + } else if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto value = constAttr.getSplatValue(); + scalarOperands[iOp] = arith::ConstantOp::materialize( + rewriter, value, constAttr.getElementType(), loc); + } else { + llvm_unreachable("Expected a splat"); + } + } + + auto resultTypes = op->getResultTypes(); + llvm::SmallVector scalarResultTys; + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + scalarResultTys.push_back(elemTy); + } + + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands, + scalarResultTys); + + for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) { + auto newResult = SplatOp::create(rewriter, loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + return success(); + } +}; + +// elementwise(broadcast(a)) => broadcast(elementwise(a)) +// This also generalizes to multiple arguments when the rest are splat-like +// Not handled: multiple broadcasted arguments +struct MoveBroadcastAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveBroadcastAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + auto operands = op->getOperands(); + bool seenBroadcast = false; + ArrayRef srcShape; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + return failure(); + } + auto getSrcShape = [](BroadcastOp b) { + return b.getSrc().getType().getShape(); + }; + if (auto broadcastOp = llvm::dyn_cast(definingOp)) { + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. + return failure(); + } + } else if (!isSplat(definingOp)) { + // Not splat or broadcast + return failure(); + } + } + if (!seenBroadcast) + return failure(); + + auto loc = op->getLoc(); + + // Find broadcast op + BroadcastOp broadcastOp; + for (auto operand : operands) { + broadcastOp = operand.getDefiningOp(); + if (broadcastOp) { + break; + } + } + + auto srcTy = broadcastOp.getSrc().getType(); + auto bcSrcShape = srcTy.getShape(); + + // Reshape operands to match srcShape + llvm::SmallVector newOperands; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (auto broadcastSrcOp = llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); + continue; + } + auto elemTy = + dyn_cast(operand.getType()).getElementType(); + auto newTy = srcTy.clone(bcSrcShape, elemTy); + if (auto splatOp = llvm::dyn_cast(definingOp)) { + auto newSplat = SplatOp::create(rewriter, loc, newTy, splatOp.getSrc()); + newOperands.push_back(newSplat); + continue; + } + DenseElementsAttr constAttr; + if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto scalarValue = constAttr.getSplatValue(); + auto splatValue = SplatElementsAttr::get(newTy, scalarValue); + auto newConstant = + arith::ConstantOp::create(rewriter, loc, newTy, splatValue); + newOperands.push_back(newConstant); + continue; + } + llvm_unreachable("Expected broadcast or splat"); + } + + // Reshape results to match srcShape + llvm::SmallVector newResultTypes; + auto resultTypes = op->getResultTypes(); + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + newResultTypes.push_back(srcTy.clone(bcSrcShape, elemTy)); + } + + // Create new op and broadcast results + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands, + newResultTypes); + for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) { + auto newResult = BroadcastOp::create(rewriter, loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + return success(); + } +}; + +} // namespace + +class ReorderBroadcastPass + : public impl::TritonReorderBroadcastBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + BroadcastOp::getCanonicalizationPatterns(patterns, context); + ExpandDimsOp::getCanonicalizationPatterns(patterns, context); + // elementwise(broadcast(a)) => broadcast(elementwise(a)) + patterns.add(context); + // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp new file mode 100644 index 0000000000..3a671b4095 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp @@ -0,0 +1,616 @@ +#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h" +#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include + +#include + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONREWRITETENSORDESCRIPTORTOPOINTER +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +bool hasATensorDescriptorType(mlir::TypeRange types) { + return llvm::any_of(types, [](mlir::Type t) { + return llvm::isa(t); + }); +} + +using namespace mlir; + +/** + * @brief Filter out operand segment sizes from the list of attributes since + * this attribute is operation specific and shouldn't be set arbitrarily. + */ +mlir::SmallVector +filterSegmentSizes(mlir::ArrayRef attrs) { + mlir::SmallVector ret; + llvm::copy_if(attrs, std::back_inserter(ret), [](const NamedAttribute &attr) { + auto attrName = attr.getName().getValue(); + return attrName != "operandSegmentSizes"; + }); + return ret; +} + +struct Descriptor { + Value base; + ValueRange shape; + ValueRange strides; + Value paddingOption; +}; + +Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) { + int rank = type.getBlockType().getRank(); + assert(pack.size() == 1 + 2 * static_cast(rank) + 1 && + "Expected tensor descriptors to consist of a pointer, " + "followed by 'rank' shape values and 'rank' stride values, " + "followed by a padding option value."); + + Descriptor res; + res.base = pack[0]; + res.shape = pack.slice(1, rank); + res.strides = pack.slice(1 + rank, rank); + res.paddingOption = pack[1 + 2 * rank]; + return res; +} + +Value expandOffsets(OpBuilder &builder, Location loc, + ArrayRef blockShape, Value offsets, unsigned dim) { + Value expandedResult = offsets; + for (size_t j = 0; j < blockShape.size(); ++j) { + if (j == dim) { + continue; + } + expandedResult = + triton::ExpandDimsOp::create(builder, loc, expandedResult, j); + } + + return expandedResult; +} + +Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, + Value offset, unsigned dim) { + auto offsetType = mlir::dyn_cast(offset.getType()); + assert(offsetType && "expected integer offset type"); + unsigned width = offsetType.getWidth(); + auto indexElemType = builder.getIntegerType(width); + // Add range. + auto indexI32RowType = + RankedTensorType::get({blockShape[dim]}, builder.getI32Type()); + auto indexRowType = RankedTensorType::get({blockShape[dim]}, indexElemType); + Value splatOffset = + triton::SplatOp::create(builder, loc, indexRowType, offset); + Value range = triton::MakeRangeOp::create(builder, loc, indexI32RowType, 0, + blockShape[dim]); + Value typedRange = range; + if (width > 32) { + typedRange = arith::ExtSIOp::create(builder, loc, indexRowType, range); + } else if (width < 32) { + typedRange = arith::TruncIOp::create(builder, loc, indexRowType, range); + } + + Value offsets = arith::AddIOp::create(builder, loc, splatOffset, typedRange); + return expandOffsets(builder, loc, blockShape, offsets, dim); +} + +Value generatePtrFromOffsetRanges(OpBuilder &builder, Location loc, + ArrayRef blockShape, + Descriptor &desc, ValueRange offsets) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsets.size()); + auto ptrType = cast(desc.base.getType()); + auto ptrTensorType = RankedTensorType::get(blockShape, ptrType); + + // Generate offsets per dimension + Value ptr = triton::SplatOp::create(builder, loc, ptrTensorType, desc.base); + for (unsigned i = 0; i < blockShape.size(); ++i) { + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = triton::SplatOp::create( + builder, loc, offsets[i].getType(), desc.strides[i]); + Value offsetWithStride = + arith::MulIOp::create(builder, loc, offsets[i], splatStride); + auto indexTensorType = RankedTensorType::get( + blockShape, + mlir::cast(offsetWithStride.getType()).getElementType()); + Value broadcasted = triton::BroadcastOp::create( + builder, loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = + triton::AddPtrOp::create(builder, loc, ptrTensorType, ptr, broadcasted); + } + + return ptr; +} + +Value generatePtr(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, Descriptor &desc, + ValueRange offsets) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsets.size()); + SmallVector offsetRanges; + for (unsigned i = 0; i < blockShape.size(); ++i) { + auto offsetWithRange = + getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i); + offsetRanges.push_back(offsetWithRange); + } + + return generatePtrFromOffsetRanges(builder, loc, blockShape, desc, + offsetRanges); +} + +Value generateMaskFromOffsetRanges(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, + Descriptor &desc, ValueRange offsetRanges) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsetRanges.size()); + + // Generate mask per dimension + auto maskTensorType = RankedTensorType::get(blockShape, builder.getI1Type()); + Value mask; + for (std::size_t i = 0; i < blockShape.size(); ++i) { + auto offsetWithRange = offsetRanges[i]; + auto offsetElemType = mlir::dyn_cast( + mlir::cast(offsetWithRange.getType()).getElementType()); + assert(offsetElemType && "expected integer offset tensor type"); + auto offsetWidth = offsetElemType.getWidth(); + + // Compare with lower bound + Value lowerBound = + mlir::arith::ConstantIntOp::create(builder, loc, 0, offsetWidth); + Value splatLowerBound = triton::SplatOp::create( + builder, loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sge, + offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value upperBound = desc.shape[i]; + if (upperBound.getType() != offsetElemType) { + auto upperIntTy = mlir::dyn_cast(upperBound.getType()); + assert(upperIntTy && "expected integer shape type"); + if (upperIntTy.getWidth() > offsetWidth) { + upperBound = + arith::TruncIOp::create(builder, loc, offsetElemType, upperBound); + } else if (upperIntTy.getWidth() < offsetWidth) { + upperBound = + arith::ExtSIOp::create(builder, loc, offsetElemType, upperBound); + } + } + Value splatUpperBound = triton::SplatOp::create( + builder, loc, offsetWithRange.getType(), upperBound); + Value cmpUpper = + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt, + offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = arith::AndIOp::create(builder, loc, cmpLower, cmpUpper); + Value broadcasted = + triton::BroadcastOp::create(builder, loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = arith::AndIOp::create(builder, loc, mask, broadcasted); + } + } + + return mask; +} + +Value generateMask(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, Descriptor &desc, + ValueRange offsets) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsets.size()); + SmallVector offsetRanges; + for (unsigned i = 0; i < blockShape.size(); ++i) { + auto offsetWithRange = + getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i); + offsetRanges.push_back(offsetWithRange); + } + + return generateMaskFromOffsetRanges(builder, loc, blockShape, desc, + offsetRanges); +} + +Value generateOther(OpBuilder &builder, Location loc, Type scalarTy, + ArrayRef blockShape, + Value paddingOption = nullptr) { + auto blockTy = RankedTensorType::get(blockShape, scalarTy); + if (paddingOption && mlir::isa(scalarTy)) { + auto floatTy = mlir::cast(scalarTy); + auto nan = llvm::APFloat::getNaN(floatTy.getFloatSemantics()); + auto nanValue = arith::ConstantOp::create( + builder, loc, + SplatElementsAttr::get(blockTy, builder.getFloatAttr(floatTy, nan))); + auto zeroValue = arith::ConstantOp::create( + builder, loc, + SplatElementsAttr::get(blockTy, builder.getZeroAttr(floatTy))); + return mlir::arith::SelectOp::create(builder, loc, paddingOption, nanValue, + zeroValue); + } else { + auto attr = builder.getZeroAttr(blockTy); + return arith::ConstantOp::create(builder, loc, attr); + } +} + +Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy, + Value paddingOption = nullptr) { + auto blockTy = descTy.getSignlessBlockType(); + return generateOther(builder, loc, blockTy.getElementType(), + blockTy.getShape(), paddingOption); +} + +SmallVector castToI64(OpBuilder &builder, + mlir::ValueRange values) { + auto i64Type = builder.getI64Type(); + return llvm::map_to_vector(values, [&](mlir::Value v) { + return builder.createOrFold(v.getLoc(), i64Type, v); + }); +} + +SmallVector castToI32(OpBuilder &builder, + mlir::ValueRange values) { + auto i32Type = builder.getI32Type(); + return llvm::map_to_vector(values, [&](mlir::Value v) -> mlir::Value { + auto vType = mlir::dyn_cast(v.getType()); + if (!vType) + return v; + if (vType == i32Type) + return v; + if (vType.getWidth() > 32) + return builder.createOrFold(v.getLoc(), i32Type, v); + return builder.createOrFold(v.getLoc(), i32Type, v); + }); +} + +struct RewriteMakeTensorDesc : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector ptrShapeStridesPaddingOption; + llvm::append_values(ptrShapeStridesPaddingOption, adaptor.getBase()); + llvm::append_range(ptrShapeStridesPaddingOption, + castToI64(rewriter, adaptor.getShape())); + llvm::append_range(ptrShapeStridesPaddingOption, adaptor.getStrides()); + auto paddingOption = mlir::arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(adaptor.getPadding() == + triton::PaddingOption::PAD_NAN)); + llvm::append_values(ptrShapeStridesPaddingOption, paddingOption); + rewriter.replaceOpWithMultiple(op, {ptrShapeStridesPaddingOption}); + return mlir::success(); + } +}; + +struct RewriteLoadPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const auto blockShape = op.getDesc().getType().getBlockType().getShape(); + auto descTy = op.getDesc().getType(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto offsets = castToI64(rewriter, op.getIndices()); + bool useI32Path = + blockShape.size() == 2 && + llvm::all_of(op.getIndices().getTypes(), [](mlir::Type ty) { + auto intTy = dyn_cast(ty); + return intTy && intTy.getWidth() <= 32; + }); + SmallVector offsetsI32; + SmallVector shapeI32; + SmallVector stridesI32; + Descriptor descI32; + if (useI32Path) { + offsetsI32 = castToI32(rewriter, op.getIndices()); + shapeI32 = castToI32(rewriter, desc.shape); + stridesI32 = castToI32(rewriter, desc.strides); + descI32 = Descriptor{desc.base, shapeI32, stridesI32, desc.paddingOption}; + } + auto other = generateOther(rewriter, loc, descTy, desc.paddingOption); + auto newLoad = rewriter.replaceOpWithNewOp( + op, + useI32Path ? generatePtr(rewriter, loc, blockShape, descI32, offsetsI32) + : generatePtr(rewriter, loc, blockShape, desc, offsets), + useI32Path + ? generateMask(rewriter, loc, blockShape, descI32, offsetsI32) + : generateMask(rewriter, loc, blockShape, desc, offsets), + other, triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, + false); + newLoad->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +struct RewriteStorePattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorStoreOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = descTy.getBlockType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto offsets = castToI64(rewriter, op.getIndices()); + bool useI32Path = + blockShape.size() == 2 && + llvm::all_of(op.getIndices().getTypes(), [](mlir::Type ty) { + auto intTy = dyn_cast(ty); + return intTy && intTy.getWidth() <= 32; + }); + SmallVector offsetsI32; + SmallVector shapeI32; + SmallVector stridesI32; + Descriptor descI32; + if (useI32Path) { + offsetsI32 = castToI32(rewriter, op.getIndices()); + shapeI32 = castToI32(rewriter, desc.shape); + stridesI32 = castToI32(rewriter, desc.strides); + descI32 = Descriptor{desc.base, shapeI32, stridesI32, desc.paddingOption}; + } + + auto newStore = rewriter.replaceOpWithNewOp( + op, + useI32Path ? generatePtr(rewriter, loc, blockShape, descI32, offsetsI32) + : generatePtr(rewriter, loc, blockShape, desc, offsets), + op.getSrc(), + useI32Path + ? generateMask(rewriter, loc, blockShape, descI32, offsetsI32) + : generateMask(rewriter, loc, blockShape, desc, offsets), + triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL); + newStore->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +std::pair +generateGatherScatterPtrMask(OpBuilder &builder, Location loc, + ArrayRef blockShape, Descriptor &desc, + Value xOffsets, Value yOffset) { + Value xOffsetRange = + expandOffsets(builder, loc, blockShape, xOffsets, /*dim=*/0); + yOffset = castToI64(builder, {yOffset})[0]; + auto xOffsetI64Ty = RankedTensorType::get( + cast(xOffsetRange.getType()).getShape(), + yOffset.getType()); + xOffsetRange = + arith::ExtSIOp::create(builder, loc, xOffsetI64Ty, xOffsetRange); + auto yOffsetRange = + getExpandedOffsetWithRange(builder, loc, blockShape, yOffset, /*dim=*/1); + auto ptr = generatePtrFromOffsetRanges(builder, loc, blockShape, desc, + {xOffsetRange, yOffsetRange}); + auto mask = generateMaskFromOffsetRanges(builder, loc, blockShape, desc, + {xOffsetRange, yOffsetRange}); + return {ptr, mask}; +} + +struct RewriteGatherPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorGatherOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = op.getResult().getType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto [ptr, mask] = generateGatherScatterPtrMask( + rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset()); + auto other = generateOther(rewriter, loc, + descTy.getSignlessBlockType().getElementType(), + blockShape, desc.paddingOption); + auto newLoad = rewriter.replaceOpWithNewOp( + op, ptr, mask, other, triton::CacheModifier::NONE, + triton::EvictionPolicy::NORMAL, false); + newLoad->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +struct RewriteScatterPattern + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorScatterOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = op.getSrc().getType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto [ptr, mask] = generateGatherScatterPtrMask( + rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset()); + auto newStore = rewriter.replaceOpWithNewOp( + op, ptr, op.getSrc(), mask, triton::CacheModifier::NONE, + triton::EvictionPolicy::NORMAL); + newStore->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +std::optional translateReduceKind(DescriptorReduceKind kind, + TensorDescType ty) { + auto scalarTy = ty.getBlockType().getElementType(); + switch (kind) { + case DescriptorReduceKind::ADD: + return scalarTy.isInteger() ? RMWOp::ADD : RMWOp::FADD; + case DescriptorReduceKind::MIN: + if (scalarTy.isUnsignedInteger()) { + return RMWOp::UMIN; + } else if (scalarTy.isSignedInteger()) { + return RMWOp::MIN; + } + return {}; + case DescriptorReduceKind::MAX: + if (scalarTy.isUnsignedInteger()) { + return RMWOp::UMAX; + } else if (scalarTy.isSignedInteger()) { + return RMWOp::MAX; + } + return {}; + case DescriptorReduceKind::AND: + return RMWOp::AND; + case DescriptorReduceKind::OR: + return RMWOp::OR; + case DescriptorReduceKind::XOR: + return RMWOp::XOR; + default: + break; + } + return {}; +} + +struct RewriteReducePattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorReduceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = descTy.getBlockType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto offsets = castToI64(rewriter, op.getIndices()); + auto rmwOp = translateReduceKind(op.getKind(), descTy); + if (!rmwOp) { + std::string msgstring; + llvm::raw_string_ostream msg(msgstring); + msg << "Cannot fallback on descriptor atomic op, unsupported for type " + << descTy.getBlockType().getElementType(); + return op->emitError(msgstring); + } + + triton::AtomicRMWOp::create( + rewriter, loc, descTy.getSignlessBlockType(), *rmwOp, + generatePtr(rewriter, loc, blockShape, desc, offsets), op.getSrc(), + generateMask(rewriter, loc, blockShape, desc, offsets), + MemSemantic::RELEASE, MemSyncScope::GPU); + op.erase(); + return success(); + } +}; + +/** + * @brief This implements the pass for converting triton tensor descriptor + * loads/stores into indexed loads/stores. + * + * The key idea is that each tensor descriptor can be broken down into multiple + * values. Suppose we have a tensor pointer with rank r, we can cast that tensor + * descriptor value to and from 1+2r values: a tensor pointer value and two i32 + * value for each dimension representing the dynamic shape and strides. + * + * As in normal conversion patterns, individual operations can be converted + * using casted tensor descriptors and offsets and casting the results back to + * tensor pointers. + * + * We have special handling for TMA loads/stores and the make tensor descriptor + * op. + * + * @note Why use the conversion pattern rewriter? In most cases the defining + * operation of a tensor descriptor will be a make tensor descriptor op. + * However, this isn't always true - for example, if the tensor descriptor is a + * function argument or is in a conditional statement, we need better tracking + * of the pointer, shape, and strides. + */ +class TritonRewriteTensorDescriptorToPointerPass + : public impl::TritonRewriteTensorDescriptorToPointerBase< + TritonRewriteTensorDescriptorToPointerPass> { + void runOnOperation() override { + auto op = getOperation(); + + mlir::ConversionTarget target(getContext()); + target.addDynamicallyLegalDialect( + [](mlir::Operation *op) { + return !hasATensorDescriptorType(op->getOperandTypes()) && + !hasATensorDescriptorType(op->getResultTypes()); + }); + target.addDynamicallyLegalOp([](triton::FuncOp funcOp) { + return !hasATensorDescriptorType(funcOp.getFunctionType().getInputs()) && + !hasATensorDescriptorType(funcOp.getFunctionType().getResults()); + }); + + mlir::TypeConverter converter; + + converter.addConversion([](mlir::Type t) { + // Most types don't require any conversion + return t; + }); + converter.addConversion([](mlir::triton::TensorDescType t, + llvm::SmallVectorImpl &out) { + // We convert a tensor descriptor into an pointer, and a shape and stride + // for each dimension, and padding option. i.e., we create 1+2*rank+1 + // values. Note that tensor descriptors may be signed/unsigned integers + // whereas pointers should always be signless. + auto tensorType = t.getSignlessBlockType(); + out.push_back(triton::getPointerType(tensorType.getElementType())); + out.insert(out.end(), 2 * tensorType.getRank(), + mlir::IntegerType::get(t.getContext(), 64)); + out.push_back(mlir::IntegerType::get(t.getContext(), 1)); + return mlir::success(); + }); + + mlir::RewritePatternSet patterns(op->getContext()); + + // Populate conversion patterns to handle loops, function calls, and arith + // ops. + triton::populateFunctionTypeConversions(converter, patterns); + mlir::scf::populateSCFStructuralTypeConversions(converter, patterns); + triton::populateArithTypeConversions(converter, patterns); + + patterns + .add( + converter, &getContext()); + + ConversionConfig config; + config.buildMaterializations = false; + + if (mlir::failed(mlir::applyPartialConversion( + op, target, std::move(patterns), config))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 0000000000..7c85ccb999 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,566 @@ +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONREWRITETENSORPOINTER +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo &operator=(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + Value splatOffset = + triton::SplatOp::create(builder, loc, indexRowType, offsets[i]); + Value range = triton::MakeRangeOp::create(builder, loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = arith::ExtSIOp::create(builder, loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + arith::AddIOp::create(builder, loc, splatOffset, i64Range); + for (size_t j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + expandedResult = + triton::ExpandDimsOp::create(builder, loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI64Type()); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + + // Generate offsets per dimension + Value ptr = triton::SplatOp::create(builder, loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = triton::SplatOp::create( + builder, loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + arith::MulIOp::create(builder, loc, offsetWithRange, splatStride); + Value broadcasted = triton::BroadcastOp::create( + builder, loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = triton::AddPtrOp::create(builder, loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type()); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = mlir::arith::ConstantIntOp::create( + builder, loc, builder.getI64Type(), 0); + Value splatLowerBound = triton::SplatOp::create( + builder, loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sge, + offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = triton::SplatOp::create( + builder, loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt, + offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = arith::AndIOp::create(builder, loc, cmpLower, cmpUpper); + Value broadcasted = + triton::BroadcastOp::create(builder, loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = arith::AndIOp::create(builder, loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + + // Set zero padding value + TypedAttr attr = builder.getZeroAttr(elementType); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = arith::ConstantOp::create(builder, loc, attr); + return triton::SplatOp::create(builder, loc, otherTensorType, constant); + } +}; + +} // namespace + +// TODO: this pass relies on assumptions of how block pointers are created and +// on pattern matches that walks the SSA links to find the base/strides. This is +// very fragile and to solve we should expose convert Ptr of tensor to a +// structure containins all values and not only offsets. +class RewriteTensorPointerPass + : public impl::TritonRewriteTensorPointerBase { +private: + DenseMap rewritedInfo; + +public: + static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); + } + + static void generateNewOperands(SmallVector &oldOperands, + unsigned index, ArrayRef newValues) { + size_t size = oldOperands.size(); + assert(index < size); + SmallVector operands = oldOperands; + oldOperands.reserve(size - 1 + newValues.size()); + oldOperands.clear(); + if (index != 0) { + oldOperands.append(operands.begin(), operands.begin() + index); + } + oldOperands.append(newValues.begin(), newValues.end()); + if (index != size - 1) { + oldOperands.append(operands.begin() + index + 1, operands.end()); + } + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = arith::ExtSIOp::create(builder, op.getLoc(), + builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape()); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (size_t i = 0; i < info.length(); ++i) { + Value i64Offset = arith::ExtSIOp::create( + builder, op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = arith::AddIOp::create(builder, op.getLoc(), + info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = triton::LoadOp::create( + builder, loadOp.getLoc(), newPtr, newMask, newOther, + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); + } else if (auto storeOp = dyn_cast(op)) { + triton::StoreOp::create(builder, storeOp.getLoc(), newPtr, + storeOp.getValue(), newMask, storeOp.getCache(), + storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = triton::getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + const auto &info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = scf::IfOp::create(builder, op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + auto opResults = op.getResults(); + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx)); + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = triton::getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + // Generate new iteration operands and set rewritten information + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = + scf::ForOp::create(builder, op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), newIterOperands); + newForOp->setAttrs(op->getAttrs()); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewritten info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + builder.clone(opInFor, mapping); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewritten info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (isa(op->getDialect())) { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : llvm::make_early_inc_range(block)) { + if (auto newOp = rewriteOp(&nestedOp, eraser)) { + visitOperation(newOp, eraser); + } + } + } + } + } + + void runOnOperation() override { + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + } +}; + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..782b66b686 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +add_triton_library(TritonGPUIR + Dialect.cpp + LinearLayoutConversions.cpp + Ops.cpp + Types.cpp + + DEPENDS + TritonGPUCGAAttrIncGen + TritonGPUTableGen + TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen + TritonGPUOpInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRGPUDialect + TritonIR + TritonTools +) diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 0000000000..55e9c0b553 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,4561 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include +#include +#include + +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/MathExtras.h" + +// Include TableGen'erated code +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpInterfaces.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +static SmallVector +basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName, + size_t rank, bool skipBroadcast = true); + +// Utility +namespace mlir { +namespace triton { +namespace gpu { + +LinearEncodingAttr TritonGPUDialect::toLinearEncoding(ArrayRef shape, + Attribute layout) { + // LinearEncoding is a DistributedLayout + std::vector allocationShape; + CacheKey key{std::vector(shape.begin(), shape.end()), layout}; + if (auto result = leCache.get(key)) { + return *result; + } + auto linearLayout = toLinearLayout(shape, layout); + auto linearEncoding = + LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout)); + leCache.set(key, linearEncoding); + return linearEncoding; +} + +LinearEncodingAttr toLinearEncoding(DistributedEncodingTrait layout, + ArrayRef shape) { + auto *ctx = layout.getContext(); + return ctx->getLoadedDialect()->toLinearEncoding(shape, + layout); +} + +LinearEncodingAttr toLinearEncoding(RankedTensorType type) { + auto *ctx = type.getContext(); + return ctx->getLoadedDialect()->toLinearEncoding( + type.getShape(), type.getEncoding()); +} + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape) { + return toLinearEncoding(cast(layout), shape) + .getTotalElemsPerThread(shape); +} + +SmallVector getElemsPerThread(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(cast(layout), shape) + .getElemsPerThread(shape); +} + +SmallVector getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return SmallVector(1, 1); + if (auto vecType = dyn_cast(type)) { + return SmallVector(vecType.getShape().begin(), + vecType.getShape().end()); + } + auto tensorType = cast(type); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + if (auto vecType = dyn_cast(type)) + return vecType.getNumElements(); + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), + tensorType.getShape()); +} + +SmallVector getThreadsPerWarp(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(cast(layout), shape) + .getThreadsPerWarp(); +} + +SmallVector getWarpsPerCTA(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(cast(layout), shape) + .getWarpsPerCTA(); +} + +SmallVector getContigPerThread(RankedTensorType type) { + return toLinearEncoding(type).getContigPerThread(); +} + +bool isExpensiveView(Type srcType, Type dstType) { + auto tensorSrcType = cast(srcType); + auto tensorDstType = cast(dstType); + auto llSrc = toLinearLayout(tensorSrcType); + auto llDst = toLinearLayout(tensorDstType); + // In case there are replicated value we need to make sure the new and old + // layout have matching masks. + for (auto [srcMask, dstMask] : + llvm::zip(llSrc.getFreeVariableMasks(), llDst.getFreeVariableMasks())) { + assert(srcMask.first == dstMask.first); + if (srcMask.second != dstMask.second) + return true; + } + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + +/* Utility function used by get.*Order methods of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +static SmallVector eraseOrder(ArrayRef order, + unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { + // Return the order that represents that the batch is in row-major or + // column-major order for a batch of matrices of shape [*, m, n] with + // len(shape) == rank. + SmallVector order(rank); + if (rank < 2) { + return order; + } + std::iota(order.rbegin(), order.rend(), 0); + if (!rowMajor) { + std::swap(order[0], order[1]); + } + return order; +} + +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kContig) { + // kContig: if true, the matrix is fastest-running on k, + // otherwise it is on m (resp. n) + // opIdx=0: [*batch, m, k] + // opIdx=1: [*batch, k, n] + assert(opIdx == 0 || opIdx == 1); + auto rowMajor = bool(opIdx) != kContig; + return getMatrixOrder(rank, rowMajor); +} + +SmallVector getRepOrder(RankedTensorType type) { + auto layout = type.getEncoding(); + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getRepOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getRepOrder"); + return {}; +} + +// Legacy impl for now +// This one's not terribly bad as we don't broadcast ShareEncodings +SmallVector getOrder(SharedEncodingTrait layout, + ArrayRef shape) { + if (auto swizzledLayout = dyn_cast(layout)) { + return llvm::to_vector(swizzledLayout.getOrder()); + } + if (auto paddedEnc = dyn_cast(layout)) { + return paddedEnc.getOrder(); + } + if (auto linearEnc = dyn_cast(layout)) { + return linearEnc.getOrder(); + } + if (auto sharedLayout = dyn_cast(layout)) { + if (shape.size() == 1) { + return {0}; + } + return getMatrixOrder(shape.size(), !sharedLayout.getTransposed()); + } + if (auto sharedLayout = dyn_cast(layout)) { + return llvm::to_vector(sharedLayout.getOrder()); + } + llvm::report_fatal_error("Unimplemented usage of getOrder for MemDescType"); + return {}; +} + +SmallVector getOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getOrder(); +} + +SmallVector getOrderForMemory(DistributedEncodingTrait layout, + ArrayRef shape) { + auto linear = toLinearEncoding(layout, shape); + auto order = linear.getOrder(); + auto threadOrder = linear.getThreadOrder(); + if (order == threadOrder) { + return order; + } + // Heuristic: + // If the element contiguity does not align with the thread order + // because the thread order dimension has contiguity of 1---meaning that + // the order position of this dimension is irrelevant---we prefer + // to use the thread order for the memory layout + auto contig = linear.getElemsPerThread(shape); + if (contig[threadOrder[0]] == 1) { + return threadOrder; + } + return order; +} + +SmallVector getThreadOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getThreadOrder(); +} + +SmallVector getWarpOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getWarpOrder(); +} + +CGAEncodingAttr getCGALayout(Attribute layout) { + if (auto ttgLayout = mlir::dyn_cast(layout)) + return ttgLayout.getCGALayout(); + llvm::report_fatal_error("Unimplemented usage of getCGALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + if (auto ttgLayout = mlir::dyn_cast(layout)) + return ttgLayout.getCGALayout().getCTAsPerCGA(); + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + if (auto ttgLayout = mlir::dyn_cast(layout)) { + return ttgLayout.getCGALayout().getCTASplitNum(); + } else if (auto tmemLayout = + mlir::dyn_cast( + layout)) { + res.resize(2); + res[0] = tmemLayout.getCTASplitM(); + res[1] = tmemLayout.getCTASplitN(); + } else if (auto tmemScaleLayout = mlir::dyn_cast< + triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(layout)) { + res.resize(2); + res[0] = tmemScaleLayout.getCTASplitM(); + res[1] = tmemScaleLayout.getCTASplitN(); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + SmallVector res; + if (auto ttgLayout = mlir::dyn_cast(layout)) { + res = ttgLayout.getCGALayout().getCTAOrder(); + } else { + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); + } + return res; +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + auto splitNum = llvm::to_vector(CTASplitNum); + if (splitNum.size() <= rank) { // pipelining + splitNum.insert(splitNum.begin(), rank - splitNum.size(), 1); + } else { // memory slicing + splitNum = + llvm::to_vector(llvm::drop_begin(splitNum, splitNum.size() - rank)); + } + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + shapePerCTA[i] = shape[i] / std::min(shape[i], splitNum[i]); + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + return getShapePerCTA(getCTASplitNum(layout), shape); +} + +SmallVector getAllocationShapePerCTA(Attribute layout, + ArrayRef shapeLogical) { + SmallVector shape(shapeLogical); + if (auto sharedMMALayout = dyn_cast(layout)) { + if (sharedMMALayout.getFp4Padded()) { + auto packedAxis = getOrder(sharedMMALayout, shapeLogical)[0]; + shape[packedAxis] *= 2; + } + } + return getShapePerCTA(layout, shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = cast(type); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +SmallVector getAllocationShapePerCTA(Type type) { + auto tensorType = cast(type); + return getAllocationShapePerCTA(tensorType.getEncoding(), + tensorType.getShape()); +} + +unsigned getNumCTAs(Attribute layout) { + return product(getCTAsPerCGA(layout)); +} + +SmallVector orderPerDimImpl(const LinearLayout &ll, + StringAttr dimName, + ArrayRef defaultOrder) { + assert(ll.getBases().contains(dimName)); + const auto &bases = ll.getBases().find(dimName)->second; + llvm::SetVector order; + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &basis : bases) { + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + if (it != basis.end()) { + auto i = it - basis.begin(); + order.insert(i); + } + } + // If any dim is missing, we add them in the defaultOrder + for (auto i : defaultOrder) { + order.insert(i); + } + return order.takeVector(); +} + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to + // do convert encoding that goes through shared memory anyway. So we + // consider it as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape); + return newTotalElemsPerThread < totalElemsPerThread; +} + +static LogicalResult +verifyLayoutOrder(function_ref emitError, + ArrayRef order) { + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +LogicalResult +CGAEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + if (linearLayout.getNumInDims() != 1) { + return emitError() << "CGA encoding must have exactly one input dimension " + "named 'block'."; + } + auto dim = *linearLayout.getInDimNames().begin(); + auto ctx = dim.getContext(); + if (dim != StringAttr::get(ctx, "block")) { + return emitError() << "CGA encoding must have exactly one input dimension " + "named 'block'."; + } + + auto outDimNames = linearLayout.getOutDimNames(); + auto expected = standardOutDimNames(ctx, linearLayout.getNumOutDims()); + if (!llvm::equal(outDimNames, expected)) { + return emitError() << "CGA encoding output dims must be [dim0, dim1, ...], " + "but got [" + << outDimNames << "]."; + } + + return success(); +} + +CGAEncodingAttr CGAEncodingAttr::get1CTALayout(MLIRContext *ctx, int rank) { + auto kBlock = StringAttr::get(ctx, "block"); + LinearLayout::BasesT bases; + bases[kBlock] = {}; + auto dims = standardOutDimNames(ctx, rank); + return get(ctx, LinearLayout(std::move(bases), dims)); +} + +CGAEncodingAttr CGAEncodingAttr::get1DLayout(MLIRContext *ctx, int numCTAs) { + auto kBlock = StringAttr::get(ctx, "block"); + auto dims = standardOutDimNames(ctx, /*rank=*/1); + auto layout = LinearLayout::identity1D(numCTAs, kBlock, dims[0]); + return get(ctx, std::move(layout)); +} + +CGAEncodingAttr CGAEncodingAttr::fromSplitParams(MLIRContext *ctx, + ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, + ArrayRef CTAOrder) { + int rank = CTAOrder.size(); + auto outDimNames = standardOutDimNames(ctx, rank); + StringAttr kBlock = StringAttr::get(ctx, "block"); + + LinearLayout layout = LinearLayout::empty(); + SmallVector splitNums(CTASplitNum.begin(), CTASplitNum.end()); + SmallVector ctas(CTAsPerCGA.begin(), CTAsPerCGA.end()); + + for (int i = 0; i < rank; ++i) { + int dim = CTAOrder[i]; + unsigned split = splitNums[dim]; + unsigned total = ctas[dim]; + assert(total % split == 0 && "invalid CGA encoding parameters"); + layout *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(total / split, kBlock, outDimNames[dim]); + } + + layout = layout.transposeOuts(outDimNames); + return CGAEncodingAttr::get(ctx, std::move(layout)); +} + +SmallVector CGAEncodingAttr::getCTAsPerCGA() const { + const auto &ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), StringAttr::get(getContext(), "block"), + rank, /*skipBroadcast=*/false); +} + +SmallVector CGAEncodingAttr::getCTASplitNum() const { + const auto &ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), StringAttr::get(getContext(), "block"), + rank); +} + +SmallVector CGAEncodingAttr::getCTAOrder() const { + auto rank = getRank(); + SmallVector defaultOrder(rank); + std::iota(defaultOrder.begin(), defaultOrder.end(), 0); + return orderPerDimImpl(getLinearLayout(), + StringAttr::get(getContext(), "block"), defaultOrder); +} + +LogicalResult BlockedEncodingAttr::verify( + function_ref emitError, + ArrayRef sizePerThread, ArrayRef threadsPerWarp, + ArrayRef warpsPerCTA, ArrayRef order, + CGAEncodingAttr CGALayout) { + if (!llvm::all_equal({sizePerThread.size(), threadsPerWarp.size(), + warpsPerCTA.size(), order.size()})) { + return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and " + "order must all have the same rank."; + } + if (llvm::any_of(sizePerThread, + [](unsigned x) { return !llvm::isPowerOf2_64(x); })) { + return emitError() + << "Every element in sizePerThread must be a power of two."; + } + if (llvm::any_of(threadsPerWarp, + [](unsigned x) { return !llvm::isPowerOf2_64(x); })) { + return emitError() + << "Every element in threadsPerWarp must be a power of two."; + } + if (llvm::any_of(warpsPerCTA, + [](unsigned x) { return !llvm::isPowerOf2_64(x); })) { + return emitError() + << "Every element in warpsPerCTA must be a power of two."; + } + + // Empty CGALayout is allowed, but if it's present its rank must match the + // BlockedEncodingAttr's rank. + if (order.size() != CGALayout.getRank()) { + return emitError() << "BlockedEncodingAttr and CGALayout's fields must " + "have the same rank."; + } + return verifyLayoutOrder(emitError, order); +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + llvm::SmallVector sizePerThread(rank, 1); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread, + order, numWarps, threadsPerWarp, + numCTAs); + return encoding; +} + +LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl, + LinearLayout &outLl, bool fwdInference, int axis, + std::optional loc) { + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = llvm::to_vector(inLl.getOutDimNames()); + if (fwdInference) { + auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]); + outLl = split * inLl; + } else { + // Assert that there is a dimension with size 2 in the axis + // that has contiguous elements + // Note that this is more general than the fwdInference case in that + // - It allows the dimension not to be the fastest running + // - It allows broadcasting + // In general, this allows us to split along any axis as long as + // the basis (0, 0, ..., 0, 1, 0, ..., 0) is in the registers. + bool found = false; + LinearLayout::BasesT newBases; + for (const auto &basesDim : inLl.getBases()) { + std::vector> newBasesDim; + for (auto base : basesDim.second) { + if (base[axis] == 1 && basesDim.first == kRegister) { + found = true; + continue; + } + base[axis] /= 2; + newBasesDim.push_back(std::move(base)); + } + newBases.insert({basesDim.first, std::move(newBasesDim)}); + } + if (!found) + return emitOptionalError(loc, + "Fp4ToFpOp/SplitOp requires at least 2 elements " + "per thread in the axis/last dimension"); + outLl = LinearLayout(std::move(newBases), std::move(outDims)); + } + return success(); +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseType(AsmParser &parser, const NamedAttribute &attr, + Type &value, StringRef desc) { + auto typeAttr = mlir::dyn_cast(attr.getValue()); + if (!typeAttr) { + parser.emitError(parser.getNameLoc(), "expected a Type in ") << desc; + return failure(); + } + value = typeAttr.getValue(); + return success(); +} + +std::optional parseLinearLayout(const DictionaryAttr &dict, + AsmParser &parser, + ArrayRef inDimNames, + int serializedRank = 0) { + LinearLayout::BasesT bases; + + // Parse the basis names in order (the order is relevant) + for (const auto &inDimNameStr : inDimNames) { + auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); + Attribute value = dict.get(inDimName); + if (!value) { + parser.emitError(parser.getCurrentLocation(), "Expected basis of '") + << inDimName.getValue() << "' not found"; + return {}; + } + // Expecting an array of arrays + auto arrayOfArraysAttr = mlir::dyn_cast(value); + if (!arrayOfArraysAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of arrays for basis of '") + << inDimName.getValue() << "'"; + return {}; + } + + std::vector> inDimBases; + for (Attribute arrayAttr : arrayOfArraysAttr) { + auto intArrayAttr = mlir::dyn_cast(arrayAttr); + if (!intArrayAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of integers in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + std::vector basis; + for (Attribute intAttr : intArrayAttr) { + auto intValueAttr = mlir::dyn_cast(intAttr); + if (!intValueAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected integer in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + basis.push_back(intValueAttr.getInt()); + } + inDimBases.push_back(std::move(basis)); + } + bases[inDimName] = std::move(inDimBases); + } + size_t rank = 0; + for (const auto &basesDim : llvm::make_second_range(bases)) { + if (!basesDim.empty()) { + rank = basesDim[0].size(); + break; + } + } + + if (rank == 0 && serializedRank == 0) { + parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); + return {}; + } + + if (rank == 0) { + rank = serializedRank; + } else if (serializedRank != 0 && serializedRank != rank) { + parser.emitError(parser.getCurrentLocation(), + "Serialized rank and rank deduced from LL need to match"); + return {}; + } + + // Generate standared outDimNames (dim0, dim1, ...) + SmallVector outDimNames; + for (int i = 0; i < rank; ++i) { + outDimNames.push_back( + StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i))); + } + + // Create LinearLayout + return LinearLayout(std::move(bases), std::move(outDimNames)); +} + +// We don't use the default implementation as it's a bit too verbose +// This prints in the following format that is shape agnostic, in the sense +// that we don't print explicitly the outShape of the LL +// We always assume LLs to be surjective +// <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], +// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], +// warp = [[16, 0], [32, 0]], +// block = []}> +static void printLinearLayout(AsmPrinter &printer, const LinearLayout &ll, + bool skipEmptyBases = false) { + auto bases = ll.getBases(); + if (skipEmptyBases) { + decltype(bases) filtered; + for (auto &kv : bases) + if (!kv.second.empty()) + filtered.insert(kv); + bases = std::move(filtered); + } + + // Printing code unchanged (just prints `bases` instead of `ll.getBases()`). + printer << join(bases, ", ", [](const auto &base) { + return base.first.str() + " = " + "[" + + join(base.second, ", ", + [](const std::vector &vec) { + return "[" + join(vec, ", ") + "]"; + }) + + "]"; + }); +} + +// Print the CGA encoding as `CGALayout = [[...]]` when the layout is +// non-trivial. +static void maybePrintCGALayout(mlir::MLIRContext *context, + mlir::AsmPrinter &printer, + CGAEncodingAttr layout) { + if (layout.getLinearLayout().getTotalInDimSize() == 1) + return; + + auto kBlock = StringAttr::get(context, "block"); + const auto &basesMap = layout.getLinearLayout().getBases(); + auto it = basesMap.find(kBlock); + assert(it != basesMap.end()); + const auto &bases = it->second; + // This is the default layout + assert(!bases.empty()); + + printer << ", CGALayout = ["; + llvm::interleaveComma(bases, printer, [&](const std::vector &vec) { + printer << "["; + llvm::interleaveComma(vec, printer); + printer << "]"; + }); + printer << "]"; +} + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" +#undef GET_ATTRDEF_CLASSES + +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +std::optional parseCGAAttr(AsmParser &parser, Attribute attr, + unsigned rank) { + if (!attr) + return CGAEncodingAttr::get1CTALayout(parser.getContext(), rank); + + auto array = llvm::dyn_cast(attr); + if (!array) { + parser.emitError(parser.getNameLoc(), + "expected array value for 'CGALayout'"); + return {}; + } + + auto ctx = parser.getContext(); + auto cgaName = StringAttr::get(ctx, "CGALayout"); + std::vector> bases; + bases.reserve(array.size()); + for (Attribute vecAttr : array) { + SmallVector basisValues; + NamedAttribute basisAttr(cgaName, vecAttr); + if (parseIntArrayAttr(parser, basisAttr, basisValues, "CGALayout entry") + .failed()) + return {}; + if (basisValues.size() != rank) { + parser.emitError(parser.getNameLoc()) + << "'CGALayout' entry length does not match rank " << rank; + return {}; + } + std::vector basis; + basis.reserve(basisValues.size()); + for (unsigned value : basisValues) + basis.push_back(static_cast(value)); + bases.push_back(std::move(basis)); + } + + LinearLayout::BasesT namedBases; + namedBases.insert( + std::make_pair(StringAttr::get(ctx, "block"), std::move(bases))); + LinearLayout ll(namedBases, standardOutDimNames(ctx, rank)); + return CGAEncodingAttr::get(ctx, std::move(ll)); +} + +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + Attribute cgaAttr = nullptr; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") + .failed()) + return {}; + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") + .failed()) + return {}; + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, /*rank=*/sizePerThread.size()); + if (!CGALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, *CGALayout); +} + +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]" + << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]" + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" + << ", order = [" << getOrder() << "]"; + + maybePrintCGALayout(getContext(), printer, getCGALayout()); + + printer << "}>"; +} + +// FIXME Can we take the LinearLayout by const&? +LogicalResult +LinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + // Example of LinearEncodingAttr + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + // The input dims must be {register, lane, warp, block} + // The output dims of the linear layout should be dim0..dim[rank-1] + + static const auto expectedInDims = + SmallVector({"register", "lane", "warp", "block"}); + for (const auto &[i, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expectedDimStr] = dims; + if (dim.str() != expectedDimStr) { + return emitError() << "Expected input dimension " << i << " to be '" + << expectedDimStr << "'. Got " << dim; + } + } + + // outDims are ['dim0', 'dim1', ...] + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got " + << dim << " at position " << i; + } + } + + const auto &bases = linearLayout.getBases(); + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; + })) { + return emitError() + << "In a distributed layout, each base must move in at most one " + "dimension."; + } + } + + LinearLayout withoutBroadcast = linearLayout; + for (auto inDim : linearLayout.getInDimNames()) { + withoutBroadcast = withoutBroadcast.removeZeroBasesAlongDim(inDim); + } + if (!withoutBroadcast.isInvertible()) { + return emitError() + << "After removing the zero bases the layout must be bijective"; + } + + return success(); +} + +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getRepOrder() const { + return SmallVector(getOrder()); +} + +//===----------------------------------------------------------------------===// +// Linear Encoding +//===----------------------------------------------------------------------===// + +void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{"; + printLinearLayout(printer, getLinearLayout()); + printer << "}>"; +} + +Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + std::vector inDimNames = {"register", "lane", "warp", "block"}; + auto maybeLL = parseLinearLayout(dict, parser, inDimNames); + if (!maybeLL.has_value()) + return {}; + + // Create and return the LinearEncodingAttr + return parser.getChecked(parser.getContext(), + std::move(*maybeLL)); +} + +static SmallVector +basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName, + size_t rank, bool skipBroadcast) { + const auto &bases = namedBases.find(dimName)->second; + + if (bases.empty()) { + return SmallVector(rank, 1); + } + + SmallVector ret(rank, 1); + auto nonZero = [](auto val) { return val != 0; }; + int nonZeroIdx = 0; + for (const auto &basis : bases) { + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + if (it != basis.end()) { + nonZeroIdx = it - basis.begin(); + ret[nonZeroIdx] *= 2; + } else if (!skipBroadcast) { + // If we've seen a non-zero basis, we double the size of the previous dim + // This is just needed to count the CTAsPerCGA + ret[nonZeroIdx] *= 2; + } + } + return ret; +} + +SmallVector +LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const { + const auto &ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast); +} + +CGAEncodingAttr linearToCGAEncodingAttr(const LinearLayout &ll, + ArrayRef cgaLogicalShape) { + // Compute the shapePerCTA + auto shape = ll.getOutDims(); + for (int i = 0; i < shape.size(); ++i) { + shape[i].second /= cgaLogicalShape[i]; + } + auto inDims = to_vector(ll.getInDimNames()); + auto kBlock = inDims.back(); + assert(kBlock.str() == "block"); + inDims.pop_back(); + auto outDims = to_vector(ll.getOutDimNames()); + auto subLl = ll.sublayout(inDims, outDims); + // sublayout returns the same output size. We trim it to the + // real size + subLl = LinearLayout(subLl.getBases(), shape, false); + // The cgaLayout is what we get after dividing on the left by + // the layout in a single CTA. + auto maybeCgaLayout = divideLeft(ll, subLl); + assert(maybeCgaLayout.has_value()); + auto *ctx = inDims[0].getContext(); + auto cgaLayout = maybeCgaLayout->sublayout({kBlock}, outDims); + return CGAEncodingAttr::get(ctx, std::move(cgaLayout)); +} + +SmallVector +LinearEncodingAttr::orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const { + return orderPerDimImpl(getLinearLayout(), dimName, defaultOrder); +} + +// [Note. Divergence of methods wrt. legacy layouts] +// For smaller shapes where the CTATile is larger than the output +// tensor, some methods return different values than the legacy layouts. I think +// this is benign tho. An example: what is the vector of `warpsPerCTA` if +// all the warps hold the same data? I think it should be [1, 1], even if we +// have 4 warps. But perhaps for this we have to add some masking in some +// places... We'll see +SmallVector LinearEncodingAttr::getRepOrder() const { + // This is not correct, but: + // - It happens to agree in most places with the legacy layout + // - getRepOrder does not make sense for LinearEncodingAttr as it already has + // the same shape as the tensor that uses it + return getOrder(); +} + +CGAEncodingAttr LinearEncodingAttr::getCGALayout() const { + auto splitNum = basesPerDim(StringAttr::get(getContext(), "block")); + return linearToCGAEncodingAttr(getLinearLayout(), splitNum); +} +SmallVector LinearEncodingAttr::getWarpsPerCTA() const { + return basesPerDim(StringAttr::get(getContext(), "warp")); +} +SmallVector LinearEncodingAttr::getWarpOrder() const { + return orderPerDim(StringAttr::get(getContext(), "warp"), getOrder()); +} +SmallVector LinearEncodingAttr::getThreadsPerWarp() const { + return basesPerDim(StringAttr::get(getContext(), "lane")); +} +SmallVector LinearEncodingAttr::getThreadOrder() const { + return orderPerDim(StringAttr::get(getContext(), "lane"), getOrder()); +} + +SmallVector LinearEncodingAttr::getSizePerThread() const { + auto rank = getOrder().size(); + const auto &ll = getLinearLayout(); + auto ctx = getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto splitNum = getCGALayout().getCTASplitNum(); + + // We canonicalize on the spot, as if we use CGAs the regs are not in + // canonical form The order is [reg, lane, warp, rep, block], so we first + // remove the blocks + llvm::SmallVector ctaShape; + for (auto [shape, cgaNum] : llvm::zip(ll.getOutDimSizes(), splitNum)) { + ctaShape.push_back(shape / cgaNum); + } + LinearLayout::BasesT bases = ll.getBases(); + + llvm::SetVector reverseRepOrder; + auto nonZero = [](auto val) { return val != 0; }; + auto ®isters = bases[kRegister]; + while (!registers.empty()) { + auto &basis = registers.back(); + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // If there's broadcasting (base == zeros) there are no more reps + if (it == basis.end()) { + break; + } + auto dim = it - basis.begin(); + reverseRepOrder.insert(dim); + // As soon as we stop finding reps, we stop + if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) { + break; + } + ctaShape[dim] /= 2; + registers.pop_back(); + } + return basesPerDimImpl(bases, kRegister, rank); +} + +SmallVector LinearEncodingAttr::getOrder() const { + auto rank = getLinearLayout().getNumOutDims(); + SmallVector order(rank); + // Choose [rank-1, rank-2, ... 0] as the default order in case + // there are dims that do not move in the register + // This order is as good as any really + std::iota(order.rbegin(), order.rend(), 0); + + return orderPerDim(StringAttr::get(getContext(), "register"), order); +} + +LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ll = getLinearLayout(); + auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); + llvm::SmallDenseMap namedShape; + llvm::SmallVector permutedDims; + for (auto dim : getRepOrder()) { + permutedDims.push_back(canonicalDims[dim]); + namedShape[canonicalDims[dim]] = shape[dim]; + } + ll = ll.transposeOuts(permutedDims); + ll = ensureLayoutNotSmallerThan(ll, namedShape); + ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false); + ll = ll.transposeOuts(canonicalDims); + return ll; +} + +SmallVector +LinearEncodingAttr::getElemsPerThread(ArrayRef shape) const { + // When broadcasting the layout the shape changes, otherwise the shape is + // the same as the shape of the tensor + // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep + // the invariant that the shape of the LL is that of the tensor + // We choose the former for BC + auto scaledLayout = get(getContext(), toLinearLayout(shape)); + auto kRegister = StringAttr::get(getContext(), "register"); + return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false); +} + +SmallVector +LinearEncodingAttr::getContig(const char *inDim, + SmallVector lowerContig) const { + const auto &ll = getLinearLayout(); + const auto &bases = + ll.getBases().find(StringAttr::get(getContext(), inDim))->second; + auto order = getOrder(); + auto rank = order.size(); + + SmallVector contig(lowerContig); + auto basisIt = bases.begin(); + for (unsigned dim : order) { + std::vector basis(rank, 0); + basis[dim] = contig[dim]; + + while (basisIt != bases.end() && *basisIt == basis) { + contig[dim] *= 2; + basis[dim] *= 2; + ++basisIt; + } + } + return contig; +} + +SmallVector LinearEncodingAttr::getContigPerThread() const { + SmallVector contig(getOrder().size(), 1); + return getContig("register", contig); +} + +SmallVector LinearEncodingAttr::getContigPerWarp() const { + return getContig("lane", getContigPerThread()); +} + +unsigned +LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape) const { + return product(getElemsPerThread(shape)); +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// + +Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + Attribute cgaAttr = nullptr; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + continue; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, /*rank=*/warpsPerCTA.size()); + if (!CGALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CGALayout, + instrShape); +} + +void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCGALayout(getContext(), printer, getCGALayout()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// MUSA WMMA / SQMMA encoding +//===----------------------------------------------------------------------===// + +Attribute MUSAWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + Attribute cgaAttr = nullptr; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + continue; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, /*rank=*/warpsPerCTA.size()); + if (!CGALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CGALayout, + instrShape); +} + +void MUSAWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() << ", warpsPerCTA = [" + << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCGALayout(getContext(), printer, getCGALayout()); + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +LogicalResult MUSAWmmaEncodingAttr::verify( + function_ref emitError, unsigned versionMajor, + unsigned versionMinor, ArrayRef warpsPerCTA, + CGAEncodingAttr cgaLayout, ArrayRef instrShape) { + if (warpsPerCTA.empty()) + return emitError() << "warpsPerCTA must be non-empty"; + if (warpsPerCTA.size() > 3) + return emitError() << "warpsPerCTA rank must be <= 3"; + if (instrShape.size() != 3) + return emitError() << "instrShape must have rank 3"; + if (versionMajor == 0) + return emitError() << "MUSA WMMA versionMajor must be non-zero"; + if (instrShape[0] == 0 || (instrShape[0] % 8) != 0) + return emitError() << "WMMA instrShape[M] must be a non-zero multiple of 8"; + if (instrShape[1] == 0 || (instrShape[1] % 8) != 0) + return emitError() << "WMMA instrShape[N] must be a non-zero multiple of 8"; + if (instrShape[2] == 0 || ((instrShape[2] % 8) != 0 && instrShape[2] != 4)) + return emitError() + << "WMMA instrShape[K] must be 4 or a non-zero multiple of 8"; + (void)versionMinor; + (void)cgaLayout; + return success(); +} + +Attribute MUSASqmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + Attribute cgaAttr = nullptr; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + continue; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, /*rank=*/warpsPerCTA.size()); + if (!CGALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CGALayout, + instrShape); +} + +void MUSASqmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() << ", warpsPerCTA = [" + << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCGALayout(getContext(), printer, getCGALayout()); + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +LogicalResult MUSASqmmaEncodingAttr::verify( + function_ref emitError, unsigned versionMajor, + unsigned versionMinor, ArrayRef warpsPerCTA, + CGAEncodingAttr cgaLayout, ArrayRef instrShape) { + if (warpsPerCTA.empty()) + return emitError() << "warpsPerCTA must be non-empty"; + if (warpsPerCTA.size() > 3) + return emitError() << "warpsPerCTA rank must be <= 3"; + if (instrShape.size() != 3) + return emitError() << "instrShape must have rank 3"; + if (versionMajor != 3) { + return emitError() << "unsupported MUSA SQMMA versionMajor: " + << versionMajor; + } + if (warpsPerCTA.size() < 2) + return emitError() << "SQMMA expects warpsPerCTA rank >= 2"; + if (warpsPerCTA[0] % 4 != 0) + return emitError() << "SQMMA expects warpsPerCTA[0] to be a multiple of 4"; + // Keep instrShape in logical (M, N, K). PH1 still executes with 4-warp + // squads, but the public encoding should not expose the historical M/4 + // compression used by older lowering paths. + if (instrShape[0] == 0 || (instrShape[0] % 8) != 0) + return emitError() + << "SQMMA instrShape[M] must be a non-zero multiple of 8"; + for (unsigned idx = 1; idx < instrShape.size(); ++idx) { + if (instrShape[idx] == 0 || (instrShape[idx] % 8) != 0) + return emitError() + << "SQMMA instrShape[N/K] must be non-zero multiples of 8"; + } + (void)versionMinor; + (void)cgaLayout; + return success(); +} + +//===----------------------------------------------------------------------===// +// MFMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned version = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + bool isTransposed; + SmallVector tilesPerWarp = {}; + unsigned elementBitWidth = 32; + Attribute cgaAttr = nullptr; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "version") { + if (parseUInt(parser, attr, version, "version").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } + if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + continue; + } + if (attr.getName() == "tilesPerWarp") { + if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp") + .failed()) + return {}; + } + if (attr.getName() == "elementBitWidth") { + if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed()) + return {}; + } + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, /*rank=*/warpsPerCTA.size()); + if (!CGALayout.has_value()) + return {}; + + if (tilesPerWarp.empty()) + tilesPerWarp = SmallVector(instrShape.size(), 1); + + return parser.getChecked( + parser.getContext(), version, warpsPerCTA, instrShape, isTransposed, + *CGALayout, tilesPerWarp, elementBitWidth); +} + +void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "version = " << getVersion() // + << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" // + << ", instrShape = [" << getInstrShape() << "]"; + + printer << ", isTransposed = " << getIsTransposed(); + + maybePrintCGALayout(getContext(), printer, getCGALayout()); + + auto tilesPerWarp = getTilesPerWarp(); + if (!hasUnitTilesPerWarp()) + printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]"; + + auto elementBitWidth = getElementBitWidth(); + if (elementBitWidth != 32) + printer << ", elementBitWidth = " << elementBitWidth; + + printer << "}>"; +} + +LogicalResult AMDMfmaEncodingAttr::verify( + function_ref emitError, unsigned version, + llvm::ArrayRef warpsPerCTA, + llvm::ArrayRef instrShape, bool isTransposed, + mlir::triton::gpu::CGAEncodingAttr, + llvm::ArrayRef tilesPerWarp, unsigned elementBitWidth) { + if (!(version >= 0 && version <= 4)) { + return emitError() << "version must be in the [0, 4] range"; + } + + auto mDim = instrShape[0]; + auto nDim = instrShape[1]; + const std::array, 4> validDims = { + {{32, 32}, {16, 16}, {64, 4}, {4, 64}}}; + if (!llvm::is_contained(validDims, std::make_pair(mDim, nDim))) { + return emitError() << "invalid (mDim, nDim) combination: (" << mDim << ", " + << nDim << ")"; + } + + if (!(elementBitWidth == 32 || elementBitWidth == 64)) + return emitError() << "elementBitWidth must be 32 or 64"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned version = 0; + unsigned rank = 2; + bool isTransposed = false; + SmallVector instrShape = getDefaultInstrShape(); + Attribute cgaAttr = nullptr; + Attribute warpLayAttr = nullptr; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "version") { + if (parseUInt(parser, attr, version, "version").failed()) + return {}; + } + if (attr.getName() == "rank") { + if (parseUInt(parser, attr, rank, "rank").failed()) + return {}; + } + if (attr.getName() == "ctaLayout") { + warpLayAttr = attr.getValue(); + continue; + } + if (attr.getName() == "isTranspose") { + if (parseBool(parser, attr, isTransposed, "isTranspose").failed()) + return {}; + } + if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + continue; + } + if (attr.getName() == "instrShape") { + instrShape.clear(); + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + if (!warpLayAttr) { + return {}; + } + + auto dictWarpLay = llvm::dyn_cast(warpLayAttr); + if (!dictWarpLay) { + parser.emitError(parser.getNameLoc(), + "expected dictionary value for 'ctaLayout'"); + return {}; + } + + // Enable optional parsing of register dimension, since it's almost always + // size 1 dim. + auto ctx = parser.getContext(); + LinearLayout ctaLL; + std::vector inDimNames; + auto kReg = StringAttr::get(ctx, "register"); + Attribute value = dictWarpLay.get(kReg); + if (!value) { + ctaLL = parseLinearLayout(dictWarpLay, parser, {"warp"}, rank).value(); + auto outDims = standardOutDimNames(ctx, rank); + auto regsLL = LinearLayout::identity1D(1, kReg, outDims[rank - 1]); + ctaLL = regsLL * ctaLL; + } else { + ctaLL = parseLinearLayout(dictWarpLay, parser, {"register", "warp"}, rank) + .value(); + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, /*rank=*/rank); + if (!CGALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), version, + std::move(ctaLL), isTransposed, + *CGALayout, instrShape); +} + +void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "version = " << getVersion() + << ", isTranspose = " << getIsTransposed() << ", ctaLayout = {"; + + printLinearLayout(printer, getCtaLayout(), /*skipEmptyBases*/ true); + + printer << "}"; + + maybePrintCGALayout(getContext(), printer, getCGALayout()); + + if (getInstrShape() != ArrayRef(getDefaultInstrShape())) { + printer << ", instrShape = [" << getInstrShape() << "]"; + } + printer << "}>"; +} + +LogicalResult +AMDWmmaEncodingAttr::verify(function_ref emitError, + unsigned version, LinearLayout ctaLayout, + bool isTransposed, CGAEncodingAttr cgaLayout, + llvm::ArrayRef instrShape) { + if (!(version >= 1 && version <= 3)) + return emitError() << "WMMA version must be in the [1, 3] range"; + + auto shape = SmallVector(instrShape); + auto validShapesV1 = std::vector>{{16, 16, 16}}; + if (version == 1 && !llvm::is_contained(validShapesV1, shape)) + return emitError() << "invalid WMMA v1 instruction shape"; + + auto validShapesV2 = + std::vector>{{16, 16, 16}, {16, 16, 32}}; + if (version == 2 && !llvm::is_contained(validShapesV2, shape)) + return emitError() << "invalid WMMA v2 instruction shape"; + + auto validShapesV3 = std::vector>{ + {16, 16, 4}, {16, 16, 32}, {16, 16, 64}, {16, 16, 128}}; + if (version == 3 && !llvm::is_contained(validShapesV3, shape)) + return emitError() << "invalid WMMA v3 instruction shape"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned dim = mlir::cast(attrs.get("dim")).getInt(); + auto parent = mlir::dyn_cast(attrs.get("parent")); + if (!parent) { + parser.emitError(parser.getNameLoc(), + "expected a distributed encoding trait"); + return {}; + } + return parser.getChecked(parser.getContext(), dim, parent); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "parent = " << getParent() << "}>"; +} + +LogicalResult +SliceEncodingAttr::verify(function_ref emitError, + unsigned dim, DistributedEncodingTrait parent) { + unsigned rank = cast(parent).getRank(); + if (rank <= 1) + return emitError() << "parent layout must have at least rank >= 2"; + if (dim >= rank) { + return emitError() << "slice dim=" << dim + << " must be less than the parent rank=" << rank; + } + return success(); +} + +SmallVector SliceEncodingAttr::getRepOrder() const { + auto parentRepOrder = getParent().getRepOrder(); + return eraseOrder(parentRepOrder, getDim()); +} + +CGAEncodingAttr SliceEncodingAttr::getCGALayout() const { + auto layout = ::getCGALayout(getParent()).getLinearLayout(); + layout = removeStandardDim(layout, getDim()); + return CGAEncodingAttr::get(getContext(), std::move(layout)); +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; + +template +Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector order; + Attribute cgaAttr = nullptr; + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(parser, attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else { + if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + } + + if (auto CGALayout = parseCGAAttr(parser, cgaAttr, order.size())) + return parser.getChecked( + parser.getContext(), vec, perPhase, maxPhase, order, *CGALayout); + return {}; +} + +//===----------------------------------------------------------------------===// +// SwizzledShared encoding +//===----------------------------------------------------------------------===// + +LogicalResult +SwizzledSharedEncodingAttr::verify(function_ref emitError, + unsigned vec, unsigned perPhase, + unsigned maxPhase, ArrayRef order, + CGAEncodingAttr cgaLayout) { + if (order.size() != cgaLayout.getRank()) { + return emitError() << "order size (" << order.size() + << ") must match CGALayout rank (" << cgaLayout.getRank() + << ")"; + } + return verifyLayoutOrder(emitError, order); +} + +Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) { + return parseSwizzledEncoding(parser, type); +} + +void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCGALayout(getContext(), printer, getCGALayout()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// SharedLinear encoding +//===----------------------------------------------------------------------===// + +LogicalResult +SharedLinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout, + unsigned layoutAlignment) { + if (layoutAlignment == 0 || !llvm::isPowerOf2_32(layoutAlignment)) { + return emitError() << "alignment must be a positive power of two"; + } + static const auto expectedInDims = + SmallVector({"offset", "block"}); + for (const auto &[index, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expected] = dims; + if (dim.str() != expected) { + return emitError() << "Expected input dimension " << index << " to be '" + << expected << "'. Got " << dim; + } + } + + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got " + << dim << " at position " << i; + } + } + + SmallVector outDimNames = + llvm::to_vector(linearLayout.getOutDimNames()); + if (outDimNames.empty()) { + return emitError() + << "SharedLinearEncodingAttr requires at least one output" + " dimension."; + } + + auto *ctx = outDimNames.front().getContext(); + auto kOffset = StringAttr::get(ctx, "offset"); + auto kBlock = StringAttr::get(ctx, "block"); + + if (!linearLayout.isSurjective()) { + return emitError() << "The layout must be surjective"; + } + + LinearLayout withoutBroadcast = + linearLayout.removeZeroBasesAlongDim(kOffset).removeZeroBasesAlongDim( + kBlock); + if (!withoutBroadcast.isInvertible()) { + return emitError() + << "After removing the zero bases the layout must be bijective"; + } + + return success(); +} + +void SharedLinearEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{"; + auto layout = getLinearLayout(); + auto kBlock = StringAttr::get(getContext(), "block"); + auto kOffset = StringAttr::get(getContext(), "offset"); + if (layout.getBases().lookup(kBlock).empty()) { + layout = + layout.sublayout({kOffset}, llvm::to_vector(layout.getOutDimNames())); + } + printLinearLayout(printer, layout); + printer << "}, alignment = " << getAlignment() << ">"; +} + +Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr layoutDictRaw; + if (parser.parseAttribute(layoutDictRaw).failed()) + return {}; + + if (layoutDictRaw.get("alignment")) { + parser.emitError(parser.getCurrentLocation()) + << "alignment must be specified outside of the linear layout braces"; + return {}; + } + + NamedAttrList layoutAttrList(layoutDictRaw.getValue()); + auto *ctx = parser.getContext(); + auto kBlock = StringAttr::get(ctx, "block"); + if (!layoutAttrList.get(kBlock)) { + layoutAttrList.push_back({kBlock, ArrayAttr::get(ctx, {})}); + } + + DictionaryAttr layoutDict = layoutAttrList.getDictionary(ctx); + + // Parse alignment + unsigned layoutAlignment; + if (parser.parseComma().failed()) + return {}; + if (parser.parseKeyword("alignment").failed() || parser.parseEqual().failed()) + return {}; + if (parser.parseInteger(layoutAlignment).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + std::vector inDimNames = {"offset", "block"}; + auto maybeLL = parseLinearLayout(layoutDict, parser, inDimNames); + if (!maybeLL.has_value()) + return {}; + + // Special case for cleaner errors + if (layoutDict.get("alignment")) { + parser.emitError(parser.getCurrentLocation()) + << "alignment must be specified outside of the linear layout braces"; + return {}; + } + + if (layoutDict.size() != 2) { + parser.emitError(parser.getCurrentLocation()) + << "SharedLinearEncodingAttr must have exactly two attributes: offset " + "and block"; + return {}; + } + + return parser.getChecked( + parser.getContext(), std::move(*maybeLL), layoutAlignment); +} + +SmallVector +SharedLinearEncodingAttr::basesPerDim(StringAttr dimName, + bool skipBroadcast) const { + const auto &ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast); +} + +SmallVector +SharedLinearEncodingAttr::orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const { + return orderPerDimImpl(getLinearLayout(), dimName, defaultOrder); +} + +SmallVector SharedLinearEncodingAttr::getOrder() const { + const auto &ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + SmallVector defaultOrder(rank); + std::iota(defaultOrder.rbegin(), defaultOrder.rend(), 0); + return orderPerDim(StringAttr::get(getContext(), "offset"), defaultOrder); +} + +CGAEncodingAttr SharedLinearEncodingAttr::getCGALayout() const { + auto splitNum = basesPerDim(StringAttr::get(getContext(), "block")); + return linearToCGAEncodingAttr(getLinearLayout(), splitNum); +} +LinearLayout +SharedLinearEncodingAttr::toLinearLayout(ArrayRef shape) const { + const auto &ll = getLinearLayout(); + auto outDimNames = llvm::to_vector(ll.getOutDimNames()); + assert(shape.size() == outDimNames.size()); + // We don't support automatic broadcasting for shared linear layouts + for (auto [size, llSize] : llvm::zip(shape, ll.getOutDimSizes())) { + assert(size == llSize); + } + return ll; +} + +//===----------------------------------------------------------------------===// +// PaddedShared encoding +//===----------------------------------------------------------------------===// + +Attribute PaddedSharedEncodingAttr::parse(AsmParser &parser, Type type) { + // <[ + if (failed(parser.parseLess()) || failed(parser.parseLSquare())) + return {}; + + // :+ + SmallVector intervals, paddings; + auto parseIntervalPaddingPair = [&]() { + unsigned interval = 0, padding = 0; + if (failed(parser.parseInteger(interval)) || failed(parser.parseColon()) || + failed(parser.parsePlus()) || failed(parser.parseInteger(padding))) + return failure(); + intervals.push_back(interval); + paddings.push_back(padding); + return success(); + }; + // ] + if (failed(parser.parseCommaSeparatedList(parseIntervalPaddingPair)) || + failed(parser.parseRSquare())) + return {}; + + // {} + auto attrList = DictionaryAttr::get(parser.getContext()); + if (failed(parser.parseAttribute(attrList))) + return {}; + + // We have 2 possible formats for the attr-dict: + // 1) offset=[..], block=[..] handled by parseLinearLayout + // 2) order=[..], shape=[..] which creates an identity mapping + + std::optional maybeLL; + // Assume it's the first variant if offset or block is defined + if (attrList.contains("offset") || attrList.contains("block")) { + std::vector inDimNames = {"offset", "block"}; + // Error out on additional attribute names + for (const NamedAttribute &attr : attrList) { + if (!llvm::is_contained(inDimNames, attr.getName())) { + parser.emitError(parser.getCurrentLocation(), "Unexpected attribute ") + << attr.getName() << " found"; + } + } + maybeLL = parseLinearLayout(attrList, parser, inDimNames); + } else { + // Parse the second form + SmallVector order; + SmallVector shape; + for (const NamedAttribute &attr : attrList) { + if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "shape") { + if (parseIntArrayAttr(parser, attr, shape, "shape").failed()) + return {}; + } else { + parser.emitError(parser.getCurrentLocation(), "Unexpected attribute ") + << attr.getName() << " found"; + return {}; + } + } + + if (order.size() != shape.size()) { + parser.emitError(parser.getCurrentLocation(), + "Mismatch of shape and order ranks in padded layout"); + return {}; + } + + // Create identity mapping based on shape and order + auto kOffset = StringAttr::get(parser.getContext(), "offset"); + maybeLL = identityStandardND(kOffset, shape, order); + maybeLL = combineCtaCgaWithShape( + *maybeLL, + CGAEncodingAttr::get1CTALayout(parser.getContext(), shape.size()), + SmallVector(ArrayRef(shape))); + } + + if (!maybeLL.has_value()) + return {}; + + // > + if (parser.parseGreater().failed()) + return {}; + + return parser.getChecked( + parser.getContext(), intervals, paddings, *maybeLL); +} + +void PaddedSharedEncodingAttr::print(AsmPrinter &printer) const { + + auto *ctx = getContext(); + const auto &ll = getLinearComponent(); + + printer << "<["; + llvm::interleaveComma(llvm::zip(getIntervals(), getPaddings()), printer, + [&](std::tuple intervalPad) { + printer << std::get<0>(intervalPad) << ":+" + << std::get<1>(intervalPad); + }); + printer << "] {"; + + // We have a short hand form if linearComponent: + // 1) does have an empty CGA layout (empty block dim) + // 2) offsets are an identity mapping + auto kOffset = StringAttr::get(ctx, "offset"); + auto kBlock = StringAttr::get(ctx, "block"); + auto shape = SmallVector(ll.getOutDimSizes()); + + bool hasEmptyBlock = ll.getInDimSizeLog2(kBlock) == 0; + + LinearLayout identity = identityStandardND(kOffset, shape, getOrder()) + .transposeOuts(to_vector(ll.getOutDimNames())); + auto offsetLayout = ll.sublayout({kOffset}, to_vector(ll.getOutDimNames())); + + if (hasEmptyBlock && offsetLayout == identity) { + printer << "order = [" << ArrayRef(getOrder()) << "], shape = [" + << ArrayRef(shape) << "]"; + } else { + printLinearLayout(printer, getLinearComponent()); + } + + printer << "}>"; +} + +LogicalResult PaddedSharedEncodingAttr::verify( + function_ref emitError, ArrayRef intervals, + ArrayRef paddings, LinearLayout linearComponent) { + if (intervals.size() != paddings.size()) + return emitError() << "intervals size (" << intervals.size() + << ") must match paddings size (" << paddings.size() + << ")"; + + if (intervals.empty()) + return emitError() << "must have at least one interval-padding pair"; + + if (!llvm::all_of(intervals, llvm::isPowerOf2_32)) + return emitError() << "interval values must all be power of two"; + if (!llvm::all_of(paddings, llvm::isPowerOf2_32)) + return emitError() << "padding values must all be power of two"; + + llvm::SmallSet intervalValues(intervals.begin(), + intervals.end()); + if (intervalValues.size() != intervals.size()) + return emitError() << "interval values cannot have duplicates"; + + const auto &ll = linearComponent; + // The linear layout should map from [offset, block] to [dim0..dimN). All + // bases should be 0 or power of twos and move in a single direction without + // broadcasting + + if (ll == LinearLayout::empty()) + return emitError() << "linearComponent cannot be empty"; + + assert(!ll.getInDimNames().empty()); + auto *ctx = ll.getInDimNames().begin()->getContext(); + + if (!llvm::equal(ll.getInDimNames(), + std::array{StringAttr::get(ctx, "offset"), + StringAttr::get(ctx, "block")})) { + return emitError() + << "linearComponent must have [offset, block] as input dims"; + } + + if (!llvm::equal(ll.getOutDimNames(), + standardOutDimNames(ctx, ll.getNumOutDims()))) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]."; + } + + const auto &bases = ll.getBases(); + + // Check that we are not broadcasting or having repeated bases + if (!ll.isInvertible()) { + return emitError() << "Broadcasting is not supported."; + } + + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return llvm::count_if(basis, nonZero) <= 1; + })) { + return emitError() + << "Each offset basis must move in at most one dimension."; + } + // Ensure all non zero elements are a power of 2. Combined with the + // broadcast check above this prevents per element swizzling. The intent of + // the linear component is to rearrange whole rows or cache-line sized + // chunks of rows. + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return llvm::all_of( + basis, [](auto v) { return v == 0 || llvm::isPowerOf2_32(v); }); + })) { + return emitError() << "Each offset basis must be 0 or a power of two."; + } + } + + return success(); +} + +PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get( + MLIRContext *context, ArrayRef> intervalPads, + ArrayRef order, ArrayRef shape, + CGAEncodingAttr cgaLayout) { + auto outDimNames = standardOutDimNames(context, shape.size()); + StringAttr kOffset = StringAttr::get(context, "offset"); + + // Create identity mapping based on shape and order + LinearLayout linearComponent = + identityStandardND(kOffset, SmallVector(shape), order); + linearComponent = combineCtaCgaWithShape(linearComponent, cgaLayout, shape); + + return get(context, intervalPads, std::move(linearComponent)); +} + +PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get( + MLIRContext *context, ArrayRef> intervalPads, + LinearLayout linearComponent) { + SmallVector intervals, paddings; + intervals.reserve(intervalPads.size()); + paddings.reserve(intervalPads.size()); + for (auto [interval, padding] : intervalPads) { + intervals.push_back(interval); + paddings.push_back(padding); + } + return get(context, intervals, paddings, std::move(linearComponent)); +} + +SmallVector +PaddedSharedEncodingAttr::basesPerDim(StringAttr dimName, + bool skipBroadcast) const { + const auto &ll = getLinearComponent(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast); +} + +int64_t PaddedSharedEncodingAttr::getPaddedSize(ArrayRef shape) const { + int64_t unpaddedSize = product(shape); + int64_t paddingSize = 0; + for (auto [interval, padding] : + llvm::zip_equal(getIntervals(), getPaddings())) { + paddingSize += (unpaddedSize >> llvm::Log2_32(interval)) + << llvm::Log2_32(padding); + // There is no need for padding after the last element + if (unpaddedSize % interval == 0) + paddingSize -= padding; + } + return unpaddedSize + paddingSize; +} + +SmallVector +PaddedSharedEncodingAttr::orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const { + return orderPerDimImpl(getLinearComponent(), dimName, defaultOrder); +} + +SmallVector PaddedSharedEncodingAttr::getOrder() const { + auto rank = getLinearComponent().getNumOutDims(); + SmallVector order(rank); + // Choose [rank-1, rank-2, ... 0] as the default order in case + // there are dims that do not move in the offsets + std::iota(order.rbegin(), order.rend(), 0); + + return orderPerDim(StringAttr::get(getContext(), "offset"), order); +} + +CGAEncodingAttr PaddedSharedEncodingAttr::getCGALayout() const { + auto splitNum = basesPerDim(StringAttr::get(getContext(), "block")); + return linearToCGAEncodingAttr(getLinearComponent(), splitNum); +} +//===----------------------------------------------------------------------===// +// NVMMAShared encoding +//===----------------------------------------------------------------------===// + +Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned swizzlingByteWidth; + bool transposed = false; + bool fp4Padded = false; + unsigned elementBitWidth; + unsigned layoutRank = 2; + Attribute cgaAttr = nullptr; + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "swizzlingByteWidth") { + if (parseUInt(parser, attr, swizzlingByteWidth, "swizzlingByteWidth") + .failed()) + return {}; + } else if (attr.getName() == "transposed") { + if (parseBool(parser, attr, transposed, "transposed").failed()) + return {}; + } else if (attr.getName() == "elementBitWidth") { + if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed()) + return {}; + } else if (attr.getName() == "fp4Padded") { + if (parseBool(parser, attr, fp4Padded, "fp4Padded").failed()) + return {}; + } else if (attr.getName() == "CGALayout") { + cgaAttr = attr.getValue(); + } else if (attr.getName() == "rank") { + if (parseUInt(parser, attr, layoutRank, "rank").failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CGALayout = + parseCGAAttr(parser, cgaAttr, layoutRank); + if (!CGALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), swizzlingByteWidth, transposed, elementBitWidth, + fp4Padded, *CGALayout); +} + +void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "swizzlingByteWidth = " << getSwizzlingByteWidth() // + << ", transposed = " << getTransposed() // + << ", elementBitWidth = " << getElementBitWidth(); + if (getFp4Padded()) { + // Print only in this case to reduce the noise for the more common case. + printer << ", fp4Padded = true"; + } + unsigned rank = getCGALayout().getCTAOrder().size(); + auto *ctx = getContext(); + auto defaultLayout = CGAEncodingAttr::get1CTALayout(ctx, rank); + if (getCGALayout() == defaultLayout && rank != 2) { + printer << ", rank = " << rank; + } else { + maybePrintCGALayout(ctx, printer, getCGALayout()); + } + printer << "}>"; +} + +LogicalResult +NVMMASharedEncodingAttr::verify(function_ref emitError, + unsigned swizzlingByteWidth, bool transposed, + unsigned elementBitWidth, bool fp4Padded, + CGAEncodingAttr CGALayout) { + if (elementBitWidth == 0) + return emitError() << "elementBitWidth must be non-zero"; + if (!llvm::is_contained({0, 32, 64, 128}, swizzlingByteWidth)) + return emitError() << "swizzlingByteWidth must be 0, 32, 64, or 128"; + return success(); +} + +int NVMMASharedEncodingAttr::getVec() const { + if (getSwizzlingByteWidth() == 0) + return 1; + return 128 / getElementBitWidth(); +} + +int NVMMASharedEncodingAttr::getPerPhase() const { + if (getSwizzlingByteWidth() == 0) + return 1; + return 128 / getSwizzlingByteWidth(); +} + +int NVMMASharedEncodingAttr::getMaxPhase() const { + if (getSwizzlingByteWidth() == 0) + return 1; + return getSwizzlingByteWidth() / 16; +} + +int32_t NVMMASharedEncodingAttr::getAlignment() const { + return 128 * getMaxPhase(); +} + +//===----------------------------------------------------------------------===// +// AMDRotatingShared encoding +//===----------------------------------------------------------------------===// + +Attribute AMDRotatingSharedEncodingAttr::parse(AsmParser &parser, Type type) { + return parseSwizzledEncoding(parser, type); +} + +void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCGALayout(getContext(), printer, getCGALayout()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Mfma encoding +//===----------------------------------------------------------------------===// +// TODO: there is a lot of common code with MmaEncoding here + +bool AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const { + return llvm::all_of(getTilesPerWarp(), [](int x) { return x == 1; }); +} + +SmallVector +AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { + auto mnkDim = getInstrShape(); + unsigned mDim = mnkDim[0]; + unsigned nDim = mnkDim[1]; + assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + + constexpr int warpSize = 64; // MFMA is always based on the 64-wide warps. + int kGroups = warpSize / std::min(mDim, nDim); // for 64x4 and 4x64, + // kGroups = 16 + int64_t kDim = kWidth * kGroups; + + if (opIdx == 0) + return {mDim, kDim}; + else + assert(opIdx == 1); + return {kDim, nDim}; +} + +SmallVector AMDMfmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx); + auto rank = operandShape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + auto tilesPerWarp = getTilesPerWarp(); + + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * tilesPerWarp[rank - 2] * + warpsPerCTA[rank - 2])) * + tilesPerWarp[rank - 2], + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / + (operandTileShape[1] * tilesPerWarp[rank - 1] * + warpsPerCTA[rank - 1])) * + tilesPerWarp[rank - 1]}; + } +} + +SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned vectorSize, unsigned elemBitWidth, + bool needTrans) const { + int kDimIndex = operandIdx == 0 ? 1 : 0; + + // Disable swizzling for scales + if (operandIdx >= 2) { + return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder, + cgaLayout); + } + + if (needTrans) + kDimIndex = 1 - kDimIndex; + + bool isKContig = sharedOrder[0] == kDimIndex; + // GFX950 supports LDS transpose load instructions, so we need swizzling even + // when K dimension is not the contiguous dimension. + bool isGFX950 = getVersion() == 4; + bool swizzleNonKContig = + isGFX950 && (elemBitWidth == 8 || elemBitWidth == 16); + + if (!isKContig && !swizzleNonKContig) { + // Do not swizzle. In this case accesses will go in different banks even + // without swizzling. + return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder, + cgaLayout); + } + + const unsigned numBanks = isGFX950 ? 64 : 32; + const unsigned bankBitWidth = 32; + const unsigned simdWidth = 16; + + // Number of inner dimension rows per one pattern repeat + int innerDimLength = operandShape[sharedOrder[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int maxPhase = + std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u); + + // TODO (zhanglx): figure out better parameters for mfma4 + if (getInstrShape()[0] == 4) + maxPhase = 4; + + return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase, + maxPhase, sharedOrder, cgaLayout); +} + +//===----------------------------------------------------------------------===// +// Wmma encoding +//===----------------------------------------------------------------------===// + +SmallVector AMDWmmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +//===----------------------------------------------------------------------===// +// MUSA WMMA / SQMMA encoding +//===----------------------------------------------------------------------===// + +bool MUSAWmmaEncodingAttr::isPH1() const { return getVersionMajor() == 3; } + +bool MUSASqmmaEncodingAttr::isPH1() const { return getVersionMajor() == 3; } + +SmallVector MUSAWmmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +MUSAWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector MUSASqmmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +MUSASqmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector +MUSASqmmaEncodingAttr::getElemsPerThread(ArrayRef shape) const { + unsigned rank = shape.size(); + SmallVector elemsPerThread(rank, 1); + if (rank < 2) + return elemsPerThread; + + auto instrMNK = getInstrShape(); + auto warpsPerCTA = getWarpsPerCTA(); + auto shapePerCTA = getShapePerCTA(*this, shape); + auto ceilDiv = [](unsigned x, unsigned y) { + return y == 0 ? 0 : (x + y - 1) / y; + }; + unsigned squadsPerCTA = std::max(1u, warpsPerCTA[0] / 4); + + unsigned repM = ceilDiv(shapePerCTA[rank - 2], instrMNK[0] * squadsPerCTA); + unsigned repN = ceilDiv(shapePerCTA[rank - 1], instrMNK[1] * warpsPerCTA[1]); + repM = std::max(repM, 1u); + repN = std::max(repN, 1u); + + // PH1 SQMMA C/D matrix layout: each thread owns a logical (M/16)x(N/8) + // fragment per repetition. The 4-warp squad decomposition stays in the warp + // basis, not in the public instrShape contract. + elemsPerThread[rank - 2] = (instrMNK[0] / 16) * repM; + elemsPerThread[rank - 1] = (instrMNK[1] / 8) * repN; + return elemsPerThread; +} + +unsigned +MUSASqmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape) const { + return product(getElemsPerThread(shape)); +} + +static SwizzledSharedEncodingAttr +composeMusaSharedLayout(MLIRContext *ctx, CGAEncodingAttr cgaLayout, + int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned elemBitWidth, + bool needTrans) { + // Determine whether the operand is laid out as MN-major (transpose) or + // K-major (non-transpose). For dot operands, row-major means the last + // logical dim is contiguous. + if (sharedOrder.empty()) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + + unsigned rank = sharedOrder.size(); + bool isRowMajor = sharedOrder[0] == (rank - 1); + if (needTrans && rank >= 2) + isRowMajor = !isRowMajor; + + bool isMNMajor = + ((operandIdx == 0) && !isRowMajor) || ((operandIdx == 1) && isRowMajor); + + auto shapePerCTA = getShapePerCTA(cgaLayout.getCTASplitNum(), operandShape); + if (sharedOrder[0] >= shapePerCTA.size()) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + + unsigned elemBytes = std::max(1u, elemBitWidth / 8); + unsigned sgBytes = 16; + if (isMNMajor) { + if (elemBytes == 2) { + sgBytes = 32; + } else if (elemBytes == 4) { + sgBytes = 64; + } + } + + // For non-pipelined SQMMA operands, choose the canonical swizzled_shared + // layout whose LinearLayout is exactly representable by the PH1 TME + // SG/SS/SL tuple that both descriptor landing and local_alloc lowering use. + constexpr int64_t kSwizzleLineBytes = 256; + int64_t leadingWidthBytes = shapePerCTA[sharedOrder[0]] * elemBytes; + if (leadingWidthBytes <= 0) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + + unsigned vec = std::max(1u, sgBytes / elemBytes); + unsigned perPhase = 1; + unsigned maxPhase = + std::max(1u, static_cast(kSwizzleLineBytes / sgBytes)); + + if (leadingWidthBytes >= kSwizzleLineBytes) { + if ((leadingWidthBytes % kSwizzleLineBytes) != 0) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + int64_t factor = leadingWidthBytes / kSwizzleLineBytes; + if (!llvm::isPowerOf2_64(static_cast(factor))) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + int64_t vecElems = factor * static_cast(sgBytes / elemBytes); + int64_t phases = + kSwizzleLineBytes / (factor * static_cast(sgBytes)); + if (vecElems <= 0 || phases <= 0) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + vec = static_cast(vecElems); + perPhase = 1; + maxPhase = static_cast(phases); + } else { + if ((kSwizzleLineBytes % leadingWidthBytes) != 0) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + int64_t ratio = kSwizzleLineBytes / leadingWidthBytes; + if (!llvm::isPowerOf2_64(static_cast(ratio))) + return SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, sharedOrder, + cgaLayout); + perPhase = static_cast(ratio); + } + + return SwizzledSharedEncodingAttr::get(ctx, vec, perPhase, maxPhase, + sharedOrder, cgaLayout); +} + +SwizzledSharedEncodingAttr MUSAWmmaEncodingAttr::composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned kWidth, unsigned elemBitWidth, + bool needTrans) const { + (void)operandIdx; + (void)operandShape; + (void)kWidth; + (void)elemBitWidth; + (void)needTrans; + // Keep WMMA operand shared layout consistent with the 3.2 PH1 path: + // canonical non-swizzled shared tiles for both A/B operands. + return SwizzledSharedEncodingAttr::get(getContext(), /*vec=*/1, + /*perPhase=*/1, /*maxPhase=*/1, + sharedOrder, cgaLayout); +} + +SwizzledSharedEncodingAttr MUSASqmmaEncodingAttr::composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned kWidth, unsigned elemBitWidth, + bool needTrans) const { + (void)kWidth; + + // Triton 3.2 MUSA pipelined SQMMA local allocs (shape rank includes the + // stage dimension) used MMAv3-style shared swizzle selection based on the + // contiguous byte width. Keep that behavior for rank-expanded operands. + if (operandShape.size() > sharedOrder.size() && elemBitWidth > 0) { + auto shapePerCTA = getShapePerCTA(cgaLayout.getCTASplitNum(), operandShape); + int64_t contigBytes = shapePerCTA[sharedOrder[0]] * elemBitWidth / 8; + int perPhase = 1; + int maxPhase = 1; + if (contigBytes >= 128 && contigBytes % 128 == 0) { + perPhase = 1; + maxPhase = 8; + } else if (contigBytes >= 64 && contigBytes % 64 == 0) { + perPhase = 2; + maxPhase = 4; + } else if (contigBytes >= 32 && contigBytes % 32 == 0) { + perPhase = 4; + maxPhase = 2; + } else if (contigBytes >= 16 && contigBytes % 16 == 0) { + perPhase = 8; + maxPhase = 1; + } + unsigned vec = std::max(1u, 128u / elemBitWidth); + return SwizzledSharedEncodingAttr::get(getContext(), vec, perPhase, + maxPhase, sharedOrder, cgaLayout); + } + + return composeMusaSharedLayout(getContext(), cgaLayout, operandIdx, + operandShape, sharedOrder, elemBitWidth, + needTrans); +} + +SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand( + CGAEncodingAttr cgaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned kWidth, unsigned elemBitWidth, + bool needTrans) const { + int kDimIndex = operandIdx == 0 ? 1 : 0; + bool isKContig = sharedOrder[0] == kDimIndex; + + if (!isKContig) { + // Do not swizzle. In this case accesses will go in different banks even + // without swizzling. + return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder, + cgaLayout); + } + + // max vectorization size for ds_load is 128 bits + int vectorSize = std::min(kWidth * elemBitWidth, 128u) / elemBitWidth; + + const int numBanks = 32; + const int bankBitWidth = 32; + + // Number of inner dimension rows per one pattern repeat + int innerDimLength = operandShape[sharedOrder[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + // for both RDNA3 and RDNA4, the M/N dimension of wmma is 16 + // This represents the max number of rows that can be accessed + // at the same time + int mDim = getInstrShape()[0]; + int maxPhase = + std::max(std::min(mDim / perPhase, innerDimLength / vectorSize), 1); + + return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase, + maxPhase, sharedOrder, cgaLayout); +} + +//===----------------------------------------------------------------------===// +// Mma encoding +//===----------------------------------------------------------------------===// + +bool NvidiaMmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } + +bool NvidiaMmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + +bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } + +bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + +SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector +NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, + int kWidth, int opIdx) const { + assert(kWidth >= std::max(32 / bitwidth, 1) && + "kWidth must be >= max(32 / bitwidth, 1) for this function to be " + "well-defined"); + auto rank = shape.size(); + // Broadcast long K + auto warpsPerCTA = to_vector(getWarpsPerCTA()); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + warpsPerCTA[kDim] = 1; + + SmallVector tileSize; + if (rank == 3) { + tileSize.push_back(1); + } + // warpSizeK * (warpRepK * VecBitWidth) + auto tileBitWidthK = (isAmpere() && bitwidth == 64) ? (4 * 256) : (4 * 64); + if (opIdx == 0) { + // m x k + tileSize.push_back(16); + tileSize.push_back(tileBitWidthK / bitwidth); + } else { + // k x n + // Hopper path never uses the n value, since this method is only invoked + // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF + // so it's fine if the n is incorrect here + tileSize.push_back(tileBitWidthK / bitwidth); + tileSize.push_back(8); + } + + SmallVector numRep; + // Lezcano: This is odd. Why do we always return a vector of size 3? + if (rank != 3) { + numRep.push_back(1); + } + for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) { + numRep.push_back(std::max(1, s / (size * warp))); + } + return numRep; +} + +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// + +SmallVector DotOperandEncodingAttr::getRepOrder() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getRepOrderForOperand(getOpIdx()); + } else if (auto blocked = mlir::dyn_cast(getParent())) { + return to_vector(blocked.getOrder()); + } + llvm::report_fatal_error( + "getRepOrder not implemented for DotOperandEncodingAttr"); + return {}; +} + +CGAEncodingAttr DotOperandEncodingAttr::getCGALayout() const { + const auto &layout = ::getCGALayout(getParent()).getLinearLayout(); + auto bases = layout.getBases(); + auto kBlock = StringAttr::get(getContext(), "block"); + auto &blockBases = bases[kBlock]; + auto rank = layout.getNumOutDims(); + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + for (auto &basis : blockBases) { + basis[kDim] = 0; + } + auto dims = layout.getOutDims(); + dims[kDim].second = 1; + return CGAEncodingAttr::get(getContext(), + LinearLayout(std::move(bases), dims, true)); +} +LogicalResult DotOperandEncodingAttr::verify( + function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned opIdx, + Attribute parent, unsigned kWidth) { + if (opIdx != 0 && opIdx != 1) { + return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "ttg.dot_op parent parameter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter can only be " + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "ttg.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (parentAttr.getVersion() == 1 && (kWidth != 8 && kWidth != 16)) + return emitError() + << "ttg.dot_op kWidth parameter must be 8/16 for WMMA v1 " + "(including packed cases for `scaled_dot`)"; + if (parentAttr.getVersion() == 2 && !llvm::is_contained({4, 8, 16}, kWidth)) + return emitError() + << "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 " + "(including packed cases for `scaled_dot`)"; + if (parentAttr.getVersion() == 3 && kWidth == 0) + return emitError() + << "ttg.dot_op kWidth parameter is mandatory for WMMA v3 "; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "for MUSA WMMA parent"; + (void)parentAttr; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "for MUSA SQMMA parent"; + (void)parentAttr; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "ttg.dot_op unexpected parent layout: " << parent; +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + // Encoding attributes + if (auto mmaAttr = mlir::dyn_cast(attr)) { + os << "mma"; + return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "shared"; + return AliasResult::FinalAlias; + } else if (auto blockedAttr = mlir::dyn_cast(attr)) { + os << "blocked"; + return AliasResult::FinalAlias; + } else if (auto linearAttr = mlir::dyn_cast(attr)) { + os << "linear"; + return AliasResult::FinalAlias; + } /* else if (auto sliceAttr = dyn_cast(attr)) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + // Memory space attributes + if (auto smem = mlir::dyn_cast(attr)) { + os << "smem"; + return AliasResult::FinalAlias; + } + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional loc) const override { + resultEncoding = + SliceEncodingAttr::get(getDialect()->getContext(), axis, + cast(operandEncoding)); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef shape, + ArrayRef order, Attribute &resultEncoding, + std::optional loc) const override { + // Note: inferFooOpEncoding should not crash if given invalid inputs, which + // happens when someone creates invalid IR. If we return failure() on + // error, then MLIR will generate a helpful error message. + if (isIota(order)) { + resultEncoding = operandEncoding; + return success(); + } + if (shape.size() != order.size()) { + return emitOptionalError(loc, "shape and order rank do not match: ", + shape.size(), " vs ", order.size()); + } + auto checkRank = [&](unsigned rank) { + if (rank != order.size()) { + return emitOptionalError(loc, "rank of encoding does not match order: ", + rank, " vs ", order.size()); + } + return success(); + }; + auto *ctx = getDialect()->getContext(); + + auto permuteCGALayout = [ctx](CGAEncodingAttr layout, + ArrayRef order) { + auto ll = transposeLinearLayout(layout.getLinearLayout(), order); + return CGAEncodingAttr::get(ctx, std::move(ll)); + }; + + auto invOrder = inversePermutation(order); + SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); + + if (auto enc = dyn_cast(operandEncoding)) { + if (failed(checkRank(enc.getCGALayout().getRank()))) + return failure(); + + CGAEncodingAttr cgaLayout = permuteCGALayout(enc.getCGALayout(), order); + resultEncoding = SwizzledSharedEncodingAttr::get( + ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), + applyPermutation(invOrderUnsigned, enc.getOrder()), cgaLayout); + return success(); + } + + if (auto enc = dyn_cast(operandEncoding)) { + if (order == ArrayRef({1, 0})) { + if (failed(checkRank(enc.getCGALayout().getRank()))) + return failure(); + + CGAEncodingAttr cgaLayout = permuteCGALayout(enc.getCGALayout(), order); + resultEncoding = NVMMASharedEncodingAttr::get( + ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(), + enc.getElementBitWidth(), enc.getFp4Padded(), cgaLayout); + return success(); + } + } + + if (auto enc = dyn_cast(operandEncoding)) { + if (failed(checkRank(enc.getCGALayout().getRank()))) + return failure(); + + CGAEncodingAttr cgaLayout = permuteCGALayout(enc.getCGALayout(), order); + resultEncoding = BlockedEncodingAttr::get( + ctx, applyPermutation(enc.getSizePerThread(), order), + applyPermutation(enc.getThreadsPerWarp(), order), + applyPermutation(enc.getWarpsPerCTA(), order), + applyPermutation(invOrderUnsigned, enc.getOrder()), cgaLayout); + return success(); + } + // Generic case + auto padded = dyn_cast(operandEncoding); + + auto ll = padded ? padded.getLinearComponent() + : toLinearLayout(shape, operandEncoding); + if (failed(checkRank(ll.getNumOutDims()))) + return failure(); + auto transposedLl = transposeLinearLayout(ll, order); + if (isa(operandEncoding)) { + resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl)); + } else if (padded) { + resultEncoding = PaddedSharedEncodingAttr::get(ctx, padded.getIntervals(), + padded.getPaddings(), + std::move(transposedLl)); + } else { + auto shared = cast(operandEncoding); + resultEncoding = SharedLinearEncodingAttr::get( + ctx, std::move(transposedLl), shared.getAlignment()); + } + return success(); + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + auto mmaRetEncoding = mlir::dyn_cast(retEncoding); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!mlir::isa( + operandEncoding) && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + mlir::isa(dotOpEnc.getParent()))) { + return emitOptionalError( + location, "unexpected operand layout for NvidiaMmaEncodingAttr v3"); + } + } else if (mlir::isa(retEncoding)) { + // SQMMA operands cross an explicit shared-memory or dot-operand + // boundary. Do not accept raw MMA accumulator/fragment layouts here: + // chained dots must first re-enter a logical tensor/shared descriptor + // layout before they can be staged as the next operand. + if (mlir::isa( + operandEncoding)) + return success(); + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!dotOpEnc) + return emitOptionalError( + location, "unexpected operand layout for MUSASqmmaEncodingAttr"); + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + return success(); + } else if (auto dotOpEnc = + mlir::dyn_cast(operandEncoding)) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + auto aEncoding = + mlir::dyn_cast(operandEncodingA); + auto bEncoding = + mlir::dyn_cast(operandEncodingB); + if (!aEncoding && !bEncoding) + return mlir::success(); + if (!aEncoding || !bEncoding) + return op->emitError("mismatching encoding between A and B operands"); + // Verify that the encodings are valid. + if (aEncoding.getKWidth() != bEncoding.getKWidth()) + return op->emitError("mismatching kWidth between A and B operands"); + + // Check if we have already selected an MMA version for Nvidia. If so, + // validate that the encodings are correct and compatible. + auto mmaAEncoding = + dyn_cast_or_null(aEncoding.getParent()); + auto mmaBEncoding = + dyn_cast_or_null(bEncoding.getParent()); + auto dotOp = dyn_cast(op); + if (!dotOp) + return op->emitError( + "expected a dot-like operation for encoding compatibility checks"); + auto resEnc = cast(dotOp.getD().getType()) + .getEncoding(); + auto mmaResEncoding = dyn_cast(resEnc); + if (mmaAEncoding || mmaBEncoding || mmaResEncoding) { + // Check that they are all set and have the same version. + if (!mmaAEncoding || !mmaBEncoding || !mmaResEncoding) + return op->emitError("mismatching MMA encoding"); + auto mmaBEncoding = cast(bEncoding.getParent()); + if (mmaAEncoding.getVersionMajor() != mmaBEncoding.getVersionMajor() || + mmaAEncoding.getVersionMajor() != mmaResEncoding.getVersionMajor()) { + return op->emitError("mismatched MMA version."); + } + // Verify that the operands are supported on the selected MMA version. + auto ttDot = dyn_cast(op); + if (!ttDot) + return op->emitError( + "expected tt.dot when verifying Nvidia MMA operand support"); + if (!supportMMA(ttDot, mmaResEncoding.getVersionMajor())) + return op->emitError("unsupported MMA version"); + } + return success(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // Using legacy layouts, a dst encoding that satisfies this property may not + // exist. Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // With linear layouts, we can always find a dst encoding that satisfies + // this property. See inferReshapeOpEncoding. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + Attribute &dstEnc) const { + auto src = mlir::dyn_cast(srcEnc); + if (!src) { + return failure(); + } + + // Nop reshape; we can always infer an encoding. + if (srcShape == dstShape) { + dstEnc = srcEnc; + return success(); + } + + // default -> default encoding is always a nop. + auto context = srcEnc.getContext(); + int32_t numWarps = product(src.getWarpsPerCTA()); + int32_t threadsPerWarp = product(src.getThreadsPerWarp()); + int32_t numCTAs = product(src.getCGALayout().getCTAsPerCGA()); + if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps, + threadsPerWarp, numCTAs)) { + dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps, + threadsPerWarp, numCTAs); + return success(); + } + + // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA + // should be like the other fields in blocked encoding, but I'm not sure how + // to handle CTASplitNum. + auto srcCGALayout = src.getCGALayout(); + if (!all_of(srcCGALayout.getCTAsPerCGA(), + [](int32_t x) { return x == 1; }) || + !all_of(srcCGALayout.getCTASplitNum(), + [](int32_t x) { return x == 1; })) { + return failure(); + } + + // Cowardly refuse to handle encodings where shape[dim] is not divisible by + // sizePerThread[dim], threadsPerWarp[dim], and warpsPerCTA[dim]. (We make + // an exception if the block is larger than the shape.) + auto checkDivisibility = [&](StringRef name, ArrayRef subblock) { + for (int dim = 0; dim < srcShape.size(); dim++) { + if (srcShape[dim] >= subblock[dim] && + srcShape[dim] % subblock[dim] != 0) { + return failure(); + } + } + return success(); + }; + if (!succeeded( + checkDivisibility("sizePerThread", src.getSizePerThread())) || + !succeeded( + checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) || + !succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) { + return failure(); + } + + SmallVector, SmallVector>> decomp = + getReshapeDecomposition(srcShape, dstShape); + + // enc.order[i] == j means that dimension j is the enc.order[i]'th most + // minor. But what we usually want is the inverse: inverse(enc.order)[i] = j + // means that dimension i is the j'th most minor (larger means more major). + auto srcInvOrder = inversePermutation(src.getOrder()); + + // If src dims [a,b,c] are to be merged, then they must be consecutive in + // physical order, with `a` being the most major. + for (const auto &[srcDims, dstDims] : decomp) { + if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { + return failure(); + } + } + + // If src dims [a,b,c] are to be merged, then `c` must fill up sizePerThread + // / threadsPerWarp / blocksPerCTA before `b` can have any non-1 values. + // Examples: + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,2,2]. + // The total sizePerThread for dim 2 is 2, which is less than dim 2's + // size of 4. Therefore dim 1 cannot have non-1 sizePerThread. + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4]. + // Dim 2's sizePerThread covers its whole size, so dim 1 is allowed to + // have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[2,1,4]. + // Dim 1's sizePerThread does not cover its whole size, so dim 0 is not + // allowed to have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,1,2], + // threadsPerWarp=[1,2,1]. + // Dim 2 has 2 elems per thread and 1 thread per warp. 2*1 is less than + // dim 2's size. Therefore dim 1 must have threadsPerWarp=1. + // + // In addition, the encoding's block can be larger than the shape, but only + // in the most-major dimension of each decomposed chunk, and only after + // we've "used up" the more minor dims. Examples: + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4], threadsPerWarp=[16,2,1], + // warpsPerCTA=[4,1,1]. + // The whole size of dims 0 and 1 are covered by sizePerThread * + // threadsPerWarp. Therefore dim 2 is allowed to have threadsPerWarp and + // warpsPerCTA larger than its size. + for (const auto &[srcDims, dstDims] : decomp) { + auto shapeRemaining = gather(srcShape, srcDims); + auto checkSubblock = [&, srcDims = srcDims](ArrayRef subblock) { + // Iterate minor-to-major (i==0 is most major). + for (int i = srcDims.size() - 1; i >= 0; i--) { + int dim = srcDims[i]; + if (subblock[dim] == 1) { + continue; + } + + // Check that more-minor dims all have 1 in shapeRemaining. + for (int j = i + 1; j < srcDims.size(); j++) { + if (shapeRemaining[j] != 1) { + return failure(); + } + } + + if (shapeRemaining[i] >= subblock[dim]) { + assert(shapeRemaining[i] % subblock[dim] == 0); // checked earlier + shapeRemaining[i] /= subblock[dim]; + } else { + shapeRemaining[i] = 0; + } + + // Is the block larger than the shape in this dimension? This is OK + // only if we're the most-major dimension of the chunk and in all + // future chunks, only this most-major dim has a non-1 size. + if (shapeRemaining[i] == 0 && i != 0) { + return failure(); + } + } + return success(); + }; + if (!succeeded(checkSubblock(src.getSizePerThread())) || + !succeeded(checkSubblock(src.getThreadsPerWarp())) || + !succeeded(checkSubblock(src.getWarpsPerCTA()))) { + return failure(); + } + } + + // Given e.g. src.getSizePerThread(), computeSubblockSize computes e.g. + // dst.getSizePerThread(). This should be called for each of sizePerThread, + // threadsPerWarp, and warpsPerCTA, in that order. + SmallVector dstShapeRemaining(dstShape); + auto computeSubblockSize = [&](ArrayRef srcSubblock, + SmallVector &dstSubblock, + StringRef fieldName) -> LogicalResult { + // The dst subblock is "filled up" greedily starting with the most minor + // dim. When we're done, we are left with a smaller shape, of size + // dstShape / dstSubblock, which we store in dstShapeRemaining and use for + // the next call to computeSubblockSize. + dstSubblock.resize(dstShape.size()); + for (const auto &[srcDims, dstDims] : decomp) { + int64_t subblockRemaining = product(gather(srcSubblock, srcDims)); + for (int i = dstDims.size() - 1; i >= 0; i--) { + auto &val = dstSubblock[dstDims[i]]; + auto &shapeRemaining = dstShapeRemaining[dstDims[i]]; + val = std::min(subblockRemaining, shapeRemaining); + + assert(shapeRemaining % val == 0); // Checked earlier. + subblockRemaining /= val; + shapeRemaining /= val; + } + + // If there are any elems remaining in the subblock, it must be because + // the block is larger than the shape. This excess goes into the + // most-major dim of the subblock. + dstSubblock[dstDims[0]] *= subblockRemaining; + } + return success(); + }; + + SmallVector dstSizePerThread; + SmallVector dstThreadsPerWarp; + SmallVector dstWarpsPerCTA; + if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread, + "sizePerThread")) || + !succeeded(computeSubblockSize(src.getThreadsPerWarp(), + dstThreadsPerWarp, "threadsPerWarp")) || + !succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA, + "warpsPerCTA"))) { + return failure(); + } + + // Since we know that each set of srcDims is consecutive, we can + // meaningfully sort decomp by the physical order of the src dimensions, + // major-to-minor. This will also be the order of the dst dimensions. + llvm::sort(decomp, [&](const auto &a, const auto &b) { + const auto &[srcDimsA, dstDimsA] = a; + const auto &[srcDimsB, dstDimsB] = b; + return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()]; + }); + + // Compute the dst order. Make the dimensions appear in the same order as + // their corresponding src dimensions. + SmallVector dstInvOrder(dstShape.size()); + int i = 0; + for (const auto &[srcDims, dstDims] : decomp) { + for (auto dim : reverse(dstDims)) { + dstInvOrder[dim] = i++; + } + } + auto dstOrder = inversePermutation(dstInvOrder); + + // CGALayout can be all 1's because we bailed on multi-CGA layouts above. + auto CGALayout = + CGAEncodingAttr::get1CTALayout(src.getContext(), dstShape.size()); + + dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread, + dstThreadsPerWarp, dstWarpsPerCTA, + dstOrder, CGALayout); + + return success(); + } + + LogicalResult + verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, + Attribute got, + std::optional loc) const override { + if (expected == got) { + return success(); + } + if (!expected || !got) + return failure(); + + // Check whether the encodings are structurally the same. + if (!areLayoutsEquivalent(shape, cast(expected), + cast(got))) { + return emitOptionalError(loc, "Expected result encoding ", expected, + " but was ", got); + } + return success(); + } + + LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + if (product(srcShape) != product(dstShape)) { + return emitOptionalError(loc, "numel of dst shape does not match " + "numel of src shape"); + } + auto result = + inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); + if (succeeded(result)) { + return result; + } + if (!isa(srcEnc)) { + return emitOptionalError(loc, + "Failed MemDescReshapeOp encoding inference"); + } + // If the legacy encoding failed use LinearLayouts. + // Once LinearLayouts are more widely used, we can remove + // inferReshapeOpLegacyEncoding and simply use LLs. + + // HACK: We create a dummy tensor type to pass to inferReshapeLinearLayout. + auto ctx = srcEnc.getContext(); + auto fp32Type = IntegerType::get(ctx, 32, IntegerType::Unsigned); + auto srcTy = RankedTensorType::get(srcShape, fp32Type, srcEnc); + LinearLayout ll = + inferReshapeLinearLayout(cast(srcTy), dstShape); + + dstEnc = LinearEncodingAttr::get(srcEnc.getContext(), std::move(ll)); + return success(); + } + + LogicalResult + inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + auto ctx = getContext(); + if (auto enc = mlir::dyn_cast(srcEnc); + enc && enc.getDim() == shape.size()) { + SmallVector joinedShape(shape); + joinedShape.push_back(2); + auto parent = enc.getParent(); + auto parentLL = toLinearLayout(joinedShape, parent); + + Attribute splitEnc; + auto result = inferSplitOpEncoding(parent, splitEnc, joinedShape, loc); + if (succeeded(result) && + areLayoutsEquivalent(shape, cast(splitEnc), + cast(srcEnc))) { + dstEnc = parent; + return success(); + } + } else if (auto enc = mlir::dyn_cast(srcEnc)) { + // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape + // AxBxCx2. The encoding is the same as the input, but with 2 elems per + // thread in the new dimension. The new dimension is the fastest running + // dimension. + auto append = [](ArrayRef vals, int val) { + SmallVector ret(vals); + ret.push_back(val); + return ret; + }; + auto appendMajorDim = [](ArrayRef order) { + SmallVector ret(order); + ret.insert(ret.begin(), ret.size()); + return ret; + }; + auto ctall = enc.getCGALayout().getLinearLayout(); + auto kBlock = StringAttr::get(enc.getContext(), "block"); + auto newDim = standardOutDimNames( + enc.getContext(), ctall.getNumOutDims() + 1)[ctall.getNumOutDims()]; + ctall *= LinearLayout::identity1D(1, kBlock, newDim); + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), append(enc.getSizePerThread(), 2), + append(enc.getThreadsPerWarp(), 1), append(enc.getWarpsPerCTA(), 1), + appendMajorDim(enc.getOrder()), + CGAEncodingAttr::get(enc.getContext(), std::move(ctall))); + return success(); + } + + // Append dim to shape + auto ll = toLinearLayout(shape, srcEnc); + SmallVector dstShape(shape.begin(), shape.end()); + dstShape.push_back(1); + ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape)); + + // Try join on last dim + auto axis = dstShape.size() - 1; + auto newLl = LinearLayout::empty(); + auto result = + tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc); + + assert(result.succeeded()); + dstEnc = LinearEncodingAttr::get(ctx, std::move(newLl)); + return success(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of + // shape AxBxC. The input must have 2 elements per thread in the last + // dimension, which must be the fastest running dimension. The result + // encoding is the same as the input, but with the last dimension removed. + auto enc = mlir::dyn_cast(srcEnc); + bool isSimpleSplit = (enc && (enc.getSizePerThread().back() == 2) && + (enc.getThreadsPerWarp().back() == 1) && + (enc.getWarpsPerCTA().back() == 1) && + (enc.getCGALayout().getCTAsPerCGA().back() == 1)); + if (isSimpleSplit) { + SmallVector newOrder(enc.getOrder()); + auto ctall = enc.getCGALayout().getLinearLayout(); + int splitDim = newOrder.size() - 1; + // Remove splitDim from order. + newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim), + newOrder.end()); + // Remove last dimension from ctall. + ctall = ctall.unsqueezeOut(to_vector(ctall.getOutDimNames()).back()); + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + ArrayRef(enc.getSizePerThread()).drop_back(1), + ArrayRef(enc.getThreadsPerWarp()).drop_back(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder), + CGAEncodingAttr::get(enc.getContext(), std::move(ctall))); + return success(); + } + + auto axis = shape.size() - 1; + if (shape[axis] != 2) { + return emitOptionalError( + loc, "SplitOp input shape should have 2 in the last dim"); + } + + auto ctx = getContext(); + + // Split on last dim + auto ll = toLinearLayout(shape, srcEnc); + auto newLl = LinearLayout::empty(); + auto result = + tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc); + if (!result.succeeded()) { + return failure(); + } + // Remove last dim from newLl (which should be 1) + SmallVector dstShape(shape.begin(), shape.end()); + dstShape.pop_back(); + newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape)); + dstEnc = LinearEncodingAttr::get(ctx, std::move(newLl)); + return success(); + } + + LogicalResult + inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute inEnc, + Attribute &outEnc, bool fwdInference, + std::optional loc) const override { + // We implement two legacy layout propagations + // Once we fully migrate to LinearLayouts, we can remove these. + auto *ctx = getContext(); + // The output encoding will only be a legacy encoding if the axis is the + // fastest running dimension. + // FIXME: We should make sure that there are enough elements along the axis + // axis whenever fwdInference is false + if (getOrder(cast(inEnc), shape)[axis] == 0) { + // Dot operand: double kWidth if kDim == axis. + if (auto dotEnc = mlir::dyn_cast(inEnc)) { + auto kWidth = dotEnc.getKWidth(); + if (fwdInference) { + kWidth *= 2; + } else { + if (kWidth > 1) { + // bwd inference + kWidth /= 2; + } else { + return emitOptionalError(loc, + "Fp4ToFpOp requires at least 2 elements " + "per thread in the axis dimension"); + } + } + outEnc = DotOperandEncodingAttr::get(ctx, dotEnc.getOpIdx(), + dotEnc.getParent(), kWidth); + return success(); + } + + // Blocked layout: double elemsPerThread[axis]. + if (auto blockedEnc = mlir::dyn_cast(inEnc)) { + auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread()); + if (fwdInference) { + sizePerThread[axis] *= 2; + } else { + if (sizePerThread[axis] > 1) { + sizePerThread[axis] /= 2; + } else { + return emitOptionalError( + loc, "Fp4ToFpOp requires at least 2 elements per " + "thread in the axis dimension"); + } + } + outEnc = BlockedEncodingAttr::get( + ctx, sizePerThread, blockedEnc.getThreadsPerWarp(), + blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(), + blockedEnc.getCGALayout()); + return success(); + } + } + + auto ll = toLinearLayout(shape, inEnc); + auto newLl = LinearLayout::empty(); + auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc); + if (!result.succeeded()) + return result; + outEnc = LinearEncodingAttr::get(ctx, std::move(newLl)); + return success(); + } +}; + +struct TritonGPUVerifyTensorLayoutInterface + : public triton::DialectVerifyTensorLayoutInterface { + using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface; + + LogicalResult verifyTensorLayout( + Attribute layout, RankedTensorType rankedTy, Operation *op, + function_ref makeErr) const override { + auto distr = dyn_cast(layout); + if (!distr) + return makeErr() + << "Non-distributed layout is not allowed in tensor type."; + auto rank = distr.getRepOrder().size(); + if (rank != rankedTy.getRank()) + return makeErr() << "Layout has rank " << rank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + if (llvm::any_of(rankedTy.getShape(), + [](int64_t i) { return !llvm::isPowerOf2_64(i); })) { + return makeErr() << "Layout has shape " << rankedTy.getShape() + << ", but the tensor it's attached to has shape " + << rankedTy.getShape() + << " which is not a power of two."; + } + auto ll = toLinearLayout(rankedTy); + ModuleOp module = op->getParentOfType(); + + // Number of threads per warp. + auto kLane = StringAttr::get(module.getContext(), "lane"); + int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module); + if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane) + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + // Number of warps per CTA. + std::optional moduleWarpsPerCTA = maybeLookupNumWarps(op); + if (!moduleWarpsPerCTA) { + return makeErr() + << "Could not determine the number of warps per CTA. Operation " + "is not in a context with `ttg.num-warps`."; + } + auto kWarp = StringAttr::get(module.getContext(), "warp"); + int layoutWarpsPerCTA = ll.getInDimSize(kWarp); + bool allowMusaWmmaDotWarpSubset = false; + if (layoutWarpsPerCTA != *moduleWarpsPerCTA) { + if (auto dotEnc = dyn_cast(layout)) { + if (auto musaWmma = + dyn_cast(dotEnc.getParent())) { + // PH1 WMMA dot operands can intentionally model only a subset of CTA + // warps along a single operand axis (e.g. B operand with tileN=1). + // Allow this as long as it evenly partitions the CTA warp count. + if (musaWmma.isPH1() && layoutWarpsPerCTA > 0 && + (*moduleWarpsPerCTA % layoutWarpsPerCTA) == 0) { + allowMusaWmmaDotWarpSubset = true; + } + } + } + } + if (!allowMusaWmmaDotWarpSubset && + layoutWarpsPerCTA != *moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp) + << " warps per CTA, but the context requires " + << *moduleWarpsPerCTA << " warps per CTA."; + } + + // Number of CTAs per CGA. + auto kBlock = StringAttr::get(module.getContext(), "block"); + int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module); + if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock) + << " CTAs per CGA, but the context requires " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + return success(); + } + + LogicalResult verifyMemDescLayout( + Attribute layout, Type type, Operation *op, + function_ref makeErr) const override { + auto memDescTy = dyn_cast(type); + if (!memDescTy) + return makeErr() << "Non-memdesc layout is not allowed in memdesc type."; + + auto kBlock = StringAttr::get(op->getContext(), "block"); + int nCTAsLayout; + int tensorSize = 1; + int tensorRank = 0; + if (auto sharedLinearEnc = dyn_cast(layout)) { + const auto &ll = sharedLinearEnc.getLinearLayout(); + nCTAsLayout = ll.getInDimSize(kBlock); + tensorSize = ll.getTotalOutDimSize(); + tensorRank = ll.getNumOutDims(); + } else { + // It'd be nice to be able to do toLinearLayout, but the multibuffering + // dimension breaks this left right and centre + nCTAsLayout = getCGALayout(layout).getLinearLayout().getInDimSize(kBlock); + tensorSize = getCGALayout(layout).getLinearLayout().getTotalOutDimSize(); + tensorRank = getCGALayout(layout).getLinearLayout().getNumOutDims(); + } + + ModuleOp module = op->getParentOfType(); + // Number of CTAs per CGA. + int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module); + if (nCTAsLayout != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has " << nCTAsLayout + << " CTAs per CGA, but the context requires " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + // Use the tensor rank to ignore the multibuffering dimension + auto numElements = product(memDescTy.getAllocShape().take_back(tensorRank)); + if (tensorSize > numElements) { + return makeErr() << layout << ".\nLayout has tensor size at least " + << tensorSize << ", but the memdesc type has " + << numElements << " elements."; + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Layout debug printing +//===----------------------------------------------------------------------===// + +// Return N-D delinearized indices from a linear index. +static SmallVector delinearizeIndex(int64_t idx, + ArrayRef shape) { + SmallVector ret(shape.size()); + for (int i = shape.size() - 1; i >= 0; i--) { + ret[i] = idx % shape[i]; + idx /= shape[i]; + } + return ret; +} + +// Returns how many padding characters are needed for the string representation +// of value to be the same as max. +static int numCharacterPadding(int value, int max) { + return std::to_string(max).size() - std::to_string(value).size(); +} + +// return the string padded to have the same length as max. +static std::string paddedString(int value, int max) { + int nbChar = numCharacterPadding(value, max); + std::string str; + for (int i = 0; i < nbChar; i++) + str += " "; + str += std::to_string(value); + return str; +} + +std::string mlir::triton::gpu::getSharedLayoutStr(LinearLayout &ll, + bool useHWPointOfView) { + // This RankedTensorType is a MemDescType (?!) + auto outDimNames = llvm::to_vector(ll.getOutDimNames()); + auto shape = convertType(llvm::to_vector(ll.getOutDimSizes())); + auto *ctx = outDimNames[0].getContext(); + + StringAttr kOffset = StringAttr::get(ctx, "offset"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + int64_t tensorSize = product(shape); + unsigned numBlocks = ll.getInDimSize(kBlock); + int32_t blockSize = tensorSize / numBlocks; + + // elementMapping is for the non-hw layout, offsetMapping for hw-layout + std::vector elementMapping(tensorSize); + std::vector offsetMapping; + + // Shared layouts are a mapping of (block, offset) --> (...) + + // We can just use a single int to index into elementMapping because + // the 'swizzle' operation rearranges the indices---and we want to keep it + // that way + int32_t idx = 0; + // Enumerate all the offsets for each block + for (int32_t block = 0; block < numBlocks; block++) { + for (int32_t offset = 0; offset < blockSize; offset++) { + SmallVector> inputs = { + {kBlock, block}, + {kOffset, offset}, + }; + + SmallVector> outputs = ll.apply(inputs); + + std::string sharedInfo = "("; + std::string &value = elementMapping[idx]; + + if (!value.empty()) + value += "|"; + + value += "("; + // We can build up both strings (for hw/non-hw layouts) concurrently + for (int i = 0; i < outputs.size(); i++) { + // Based on the formatting from LinearLayout::toString, the format for + // the hw layout is slightly different. HW layouts use "," vs ":". + if (i > 0) { + sharedInfo += ","; + value += ":"; + } + auto index = paddedString(outputs[i].second, shape[i]); + sharedInfo += index; + value += index; + } + value += ")"; + sharedInfo += ")"; + + offsetMapping.push_back(sharedInfo); + + idx++; + } + } + + std::string layoutStr; + + if (!useHWPointOfView) { + int rank = shape.size(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, shape); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % shape[j] != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, shape); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % shape[j] != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % shape.back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ","; + } + } + } else { + // For the HW view here, print the (block, offset) --> (r,c) mapping + uint32_t idx = 0; + for (int32_t block = 0; block < numBlocks; block++) { + layoutStr += "Block: " + std::to_string(block) + ":\n"; + for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) { + layoutStr += "Offset: " + std::to_string(offset) + " -> "; + layoutStr += offsetMapping[idx]; + layoutStr += "\n"; + idx++; + } + } + } + + return layoutStr; +} + +std::string mlir::triton::gpu::getDistributedLayoutStr(LinearLayout &ll, + bool useHWPointOfView) { + auto inDimNames = llvm::to_vector(ll.getInDimNames()); + auto *ctx = inDimNames[0].getContext(); + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + + int64_t tensorSize = ll.getTotalOutDimSize(); + std::vector elementMapping(tensorSize); + std::vector threadMapping; + auto shape = convertType(llvm::to_vector(ll.getOutDimSizes())); + unsigned threadsPerWarp = ll.getInDimSize(kLane); + unsigned numWarpsPerCTA = ll.getInDimSize(kWarp); + unsigned numBlocks = ll.getInDimSize(kBlock); + int numElementsPerThreads = ll.getInDimSize(kRegister); + for (int blockId = 0; blockId < numBlocks; ++blockId) { + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + SmallVector> inputs = { + {kBlock, blockId}, + {kWarp, warpId}, + {kLane, tid}, + {kRegister, idx}}; + SmallVector> outputs = + ll.apply(inputs); + int32_t linearizedIdx = 0; + int stride = 1; + for (int i = outputs.size() - 1; i >= 0; i--) { + linearizedIdx += outputs[i].second * stride; + stride *= shape[i]; + } + std::string &value = elementMapping[linearizedIdx]; + if (!value.empty()) + value += "|"; + int padding = numCharacterPadding(blockId, numBlocks) + + numCharacterPadding(tid + warpId * threadsPerWarp, + numWarpsPerCTA * threadsPerWarp) + + numCharacterPadding(idx, numElementsPerThreads); + for (int i = 0; i < padding; i++) + value += " "; + if (numBlocks > 1) + value += "B" + std::to_string(blockId) + ":"; + value += "T" + std::to_string(tid + warpId * threadsPerWarp) + ":" + + std::to_string(idx); + // Now also compute the thread mapping. + std::string threadInfo = "("; + for (int i = 0; i < outputs.size(); i++) { + if (i > 0) + threadInfo += ","; + threadInfo += paddedString(outputs[i].second, shape[i]); + } + threadInfo += ")"; + threadMapping.push_back(threadInfo); + } + } + } + } + std::string layoutStr; + if (!useHWPointOfView) { + // Printing the threads containing each elements of the tensor. + int rank = ll.getNumOutDims(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, shape); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % shape[j] != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, shape); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % shape[j] != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % shape.back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ", "; + } + } + } else { + // Printing the elements in each physical reg/warps/threads. + for (int blockId = 0; blockId < numBlocks; blockId++) { + if (numBlocks > 1) + layoutStr += "Block" + std::to_string(blockId) + ":\n"; + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + layoutStr += "Warp" + std::to_string(warpId) + ":\n"; + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + int linearizedIdx = + blockId * numWarpsPerCTA * threadsPerWarp * + numElementsPerThreads + + warpId * threadsPerWarp * numElementsPerThreads + + tid * numElementsPerThreads + idx; + layoutStr += threadMapping[linearizedIdx]; + if (tid < threadsPerWarp - 1) + layoutStr += ", "; + } + layoutStr += "\n"; + } + } + } + } + return layoutStr; +} + +template +llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch(llvm::ArrayRef s) { + auto rank = s.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(s); + return {1, s[0], s[1]}; +} + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +llvm::SmallVector +mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef o) { + int rank = o.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(o); + llvm::SmallVector expanded(3, 0); + for (int i = 0; i < rank; ++i) + expanded[i] += o[i] + 1; + return expanded; +} + +std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout); + + // tensorType is needed later on (e.g., getDimSize(j)), so we still have to + // pass it as a param + // TODO: Pass TensorOrMemDesc instead of RankedTensorType in + // triton-tensor-layout.cpp + if (mlir::isa(layout)) { + return getSharedLayoutStr(ll, useHWPointOfView); + } else if (mlir::isa(layout)) { + return getDistributedLayoutStr(ll, useHWPointOfView); + } + + // else unimplemented, return error + llvm::report_fatal_error("Unimplemented usage of getLayoutStr"); + return ""; +} + +void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false); +} + +void mlir::triton::gpu::dumpHWLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/true); +} + +namespace { +struct TensorModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; +} // namespace + +void TritonGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); + addInterfaces(); + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // Verify that dialect attributes are attached to the right ops. + if (llvm::is_contained( + {AttrNumCTAsName, AttrTargetName, AttrNumThreadsPerWarp}, + attr.getName()) && + !isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() << " which is expected only on `module` ops"; + } + if (attr.getName() == AttrNumWarpsName && !isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() + << " which is expected only on `module` or `tt.func` ops"; + } + + // Verify that all ops in a tt.warp_specialize op have partition ids + if (attr.getName() == "tt.warp_specialize") { + if (!isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() << " which is expected only on `scf.for` ops"; + } + Operation *failedOp = nullptr; + op->walk([&](Operation *childOp) { + if (!childOp->hasAttr(kPartitionAttrName)) { + failedOp = childOp; + WalkResult::interrupt(); + } + }); + if (failedOp) { + return failedOp->emitOpError("does not have expected attribute ") + << kPartitionAttrName + << " which is expected on all child ops of an op with " + "attribute `tt.warp_specialize`"; + } + } + + // Verify that partition id lists are non-empty, sorted and have no duplicates + auto verifyPartitionIds = + [&](const ArrayRef &partitionIds) -> LogicalResult { + SetVector idSet; + for (auto id : partitionIds) { + if (idSet.contains(id)) + return op->emitOpError("has duplicated partition ids in attribute ") + << attr.getName(); + idSet.insert(id); + } + if (idSet.empty()) + return op->emitOpError("has no partition ids in attribute ") + << attr.getName(); + auto ids = idSet.takeVector(); + SmallVector sortedIds(ids.begin(), ids.end()); + std::sort(sortedIds.begin(), sortedIds.end()); + if (ids != sortedIds) + return op->emitOpError("partition ids not in sorted order in attribute ") + << attr.getName(); + return success(); + }; + + if (attr.getName() == kPartitionAttrName) { + auto result = verifyPartitionIds( + cast(attr.getValue()).asArrayRef()); + if (failed(result)) + return result; + } + if (attr.getName() == kPartitionOutputsAttrName) { + auto arrayAttr = cast(attr.getValue()); + for (auto idx = 0; idx < arrayAttr.size(); idx++) { + auto result = verifyPartitionIds( + cast(arrayAttr[idx]).asArrayRef()); + if (failed(result)) + return result; + } + } + + // Verify that op partitions include partitions of all child ops + if (attr.getName() == kPartitionAttrName && op->getNumRegions() != 0) { + SetVector expectedIds; + for (auto ®ion : op->getRegions()) { + for (auto &block : region.getBlocks()) { + for (auto &childOp : block.getOperations()) { + if (isa(childOp)) { + // yield ops and ub.poison do not need partition ids + continue; + } + if (!childOp.hasAttr(kPartitionAttrName)) + return childOp.emitOpError("does not have expected attribute ") + << kPartitionAttrName + << " which is expected for ops whose parent has partitions"; + auto ids = getPartitionIds(&childOp); + expectedIds.insert(ids.begin(), ids.end()); + } + } + } + auto partitionIds = getPartitionIds(op); + for (auto id : expectedIds) { + if (!partitionIds.contains(id)) { + return op->emitOpError("partition ids in attr ") + << attr.getName() + << " does not contain partition ids of all child ops"; + } + } + } + + if (attr.getName() == kPartitionOutputsAttrName) { + if (!isa(op)) + return op->emitOpError("has unexpected attribute ") << attr.getName(); + + // Verify that number of output partitions matches number of For/If results + size_t numResults = 0; + if (isa(op)) { + numResults = cast(op).getResults().size(); + } else if (isa(op)) { + numResults = cast(op).getResults().size(); + } else { + numResults = cast(op).getResults().size(); + } + + if (cast(attr.getValue()).size() != numResults) { + return op->emitOpError("does not have expected number of output " + "partition sets in attr ") + << attr.getName() << "; should match number of results"; + } + + // Verify that union of op output partitions is a subset of op partitions + if (!op->hasAttr(kPartitionAttrName)) + return op->emitOpError("does not have expected attribute ") + << kPartitionAttrName << " which is expected for ops with attr " + << kPartitionOutputsAttrName; + auto partitionIds = getPartitionIds(op); + + SetVector outputPartitionIdsUnion; + for (auto outputPartitionIds : getPartitionOutputs(op)) { + outputPartitionIdsUnion.insert(outputPartitionIds.begin(), + outputPartitionIds.end()); + } + if (!std::all_of(outputPartitionIdsUnion.begin(), + outputPartitionIdsUnion.end(), + [&](int id) { return partitionIds.contains(id); })) { + return op->emitOpError("partition ids in attr ") + << kPartitionAttrName + << " must be the union of all partition ids in " << attr.getName(); + } + } + + return success(); +} + +int TritonGPUDialect::getNumCTAs(ModuleOp module) { + if (auto attr = module->getAttrOfType(AttrNumCTAsName)) + return attr.getInt(); + return 1; +} + +int TritonGPUDialect::getThreadsPerWarp(ModuleOp module) { + if (auto attr = module->getAttrOfType(AttrNumThreadsPerWarp)) + return attr.getInt(); + return 32; +} + +std::optional triton::gpu::maybeLookupNumWarps(Operation *op) { + if (isa(op)) { + if (auto attr = op->getAttrOfType(AttrNumWarpsName)) + return attr.getInt(); + } else if (auto partitions = + dyn_cast(op->getParentOp())) { + unsigned idx = op->getParentRegion()->getRegionNumber(); + return partitions.getParentOp().getPartitionNumWarps()[idx]; + } + if (Operation *parent = op->getParentOp()) + return maybeLookupNumWarps(parent); + return {}; +} + +int triton::gpu::lookupNumWarps(Operation *op) { + std::optional numWarps = maybeLookupNumWarps(op); + if (!numWarps) { + op->emitOpError( + "is not contained within a context that specifies the number of warps"); + llvm::report_fatal_error("failed to lookup the number of warps, the " + "surrounding module should contain a " + + Twine(AttrNumWarpsName) + " attribute"); + } + return *numWarps; +} + +int triton::gpu::lookupNumWarps(Region *region) { + if (auto partitions = + dyn_cast(region->getParentOp())) { + unsigned idx = region->getRegionNumber(); + return partitions.getParentOp().getPartitionNumWarps()[idx]; + } + return lookupNumWarps(region->getParentOp()); +} + +int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) { + assert(rewriter.getInsertionBlock() && "expected an insertion point"); + Operation *op = + rewriter.getInsertionBlock()->getParentOp()->getParentOfType(); + assert(op && "cannot check threads per warp outside of module"); + return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast(op)); +} + +int triton::gpu::lookupNumCTAs(Operation *op) { + auto mod = dyn_cast(op); + if (!mod) + mod = op->getParentOfType(); + + if (!mod) { + op->emitOpError( + "is not contained within a module, cannot lookup number of CTAs"); + llvm::report_fatal_error( + "failed to lookup the number of CTAs, the surrounding module should " + "contain a ModuleOp"); + } + return triton::gpu::TritonGPUDialect::getNumCTAs(mod); +} + +int triton::gpu::lookupNumCTAs(OpBuilder &rewriter) { + assert(rewriter.getInsertionBlock() && "expected an insertion point"); + Operation *op = + rewriter.getInsertionBlock()->getParentOp()->getParentOfType(); + assert(op && "cannot check number of CTAs outside of module"); + return triton::gpu::TritonGPUDialect::getNumCTAs(cast(op)); +} + +bool triton::gpu::areLayoutsEquivalent(ArrayRef shape, + LayoutEncodingTrait lhs, + LayoutEncodingTrait rhs) { + auto lhsLL = triton::gpu::toLinearLayout(shape, lhs); + auto rhsLL = triton::gpu::toLinearLayout(shape, rhs); + return lhsLL == rhsLL; +} + +bool triton::gpu::isInnermostContiguous(MemDescType type, unsigned numElems) { + ArrayRef shape = type.getShape(); + Attribute enc = type.getEncoding(); + MLIRContext *ctx = enc.getContext(); + + LinearLayout actual = toLinearLayout(type); + StringAttr fastestIn = *actual.getInDimNames().begin(); + + // Flatten actual outs in reverse order to produce a row-major flattening + // of the layout + auto outNames = actual.getOutDimNames(); + SmallVector revOut(outNames.begin(), outNames.end()); + std::reverse(revOut.begin(), revOut.end()); + actual = actual.transposeOuts(revOut).flattenOuts(); + + return actual.getNumConsecutiveInOut() >= numElems; +} + +LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy, + ArrayRef dstShape) { + auto *ctx = srcTy.getContext(); + auto src = toLinearLayout(srcTy); + assert(product(srcTy.getShape()) == product(dstShape)); + auto dst = reshapeLayout(ctx, src, dstShape); + return dst; +} + +FailureOr> triton::gpu::getTMABlockShape( + ArrayRef shapePerCTA, int elementBitWidth, int swizzleBytes, + bool fp4Padded, bool isTransposed, bool packedSize, + function_ref emitError) { + SmallVector blockShape(shapePerCTA); + int contigDim = isTransposed ? 0 : blockShape.size() - 1; + if (fp4Padded) + blockShape[contigDim] *= 2; + // All dimensions must be at most 256 + constexpr int64_t dimMax = 256; + for (auto &size : blockShape) + size = std::min(size, dimMax); + // Last dim must equal the swizzle byte size + if (swizzleBytes != 0) { + auto contigDimSize = (8 * swizzleBytes) / elementBitWidth; + if (blockShape[contigDim] < contigDimSize) { + return emitError() << "block shape along the contiguous dimension " + << contigDim + << " is too small for the swizzle byte size " + << swizzleBytes << " in an NVMMASharedLayout, got " + << blockShape[contigDim] << " but expected at least " + << contigDimSize; + } + blockShape[contigDim] = contigDimSize; + } + if (fp4Padded && packedSize) { + blockShape[contigDim] /= 2; + } + return blockShape; +} +SmallVector triton::gpu::getTMABlockShape( + ArrayRef shapePerCTA, int elementBitWidth, int swizzleBytes, + bool fp4Padded, bool isTransposed, bool packedSize) { + return *getTMABlockShape( + shapePerCTA, elementBitWidth, swizzleBytes, fp4Padded, isTransposed, + packedSize, []() -> InFlightDiagnostic { + llvm::report_fatal_error( + "Block shape is too small for the swizzle byte " + "size in NVMMA Shared Layout."); + }); +} + +SetVector triton::gpu::getPartitionIds(Operation *op) { + auto attrs = op->getAttr(kPartitionAttrName); + SmallVector partitionIds; + for (auto id : cast(attrs).asArrayRef()) { + partitionIds.push_back(id); + } + std::sort(partitionIds.begin(), partitionIds.end()); + return SetVector(partitionIds.begin(), partitionIds.end()); +} + +SmallVector, 4> triton::gpu::getPartitionOutputs(Operation *op) { + SmallVector, 4> partitionOutputsIds; + if (op->getNumResults() == 0) { + return partitionOutputsIds; + } + assert(op->hasAttr(kPartitionOutputsAttrName)); + auto arrayAttr = cast(op->getAttr(kPartitionOutputsAttrName)); + for (auto attr : arrayAttr) { + auto ids = cast(attr).asArrayRef(); + partitionOutputsIds.push_back(SetVector(ids.begin(), ids.end())); + } + return partitionOutputsIds; +} + +SetVector triton::gpu::getPartitionIds(OpOperand *use) { + auto owner = use->getOwner(); + if (isa(owner)) { + return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()]; + } else if (scf::ForOp forOp = dyn_cast(owner)) { + int idx = use->getOperandNumber() - forOp.getNumControlOperands(); + return idx >= 0 ? getPartitionOutputs(owner)[idx] : getPartitionIds(forOp); + } else { + return getPartitionIds(owner); + } +} + +bool triton::gpu::hasPartition(Operation *op) { + return op && op->hasAttr(kPartitionAttrName); +} + +bool triton::gpu::hasWarpSpecializeTag(Operation *op) { + return op && op->hasAttr(kWarpSpecializeTagAttrName); +} + +std::optional triton::gpu::getWarpSpecializeTag(Operation *op) { + if (hasWarpSpecializeTag(op)) { + return cast(op->getAttr(kWarpSpecializeTagAttrName)).getInt(); + } + return std::nullopt; +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp new file mode 100644 index 0000000000..880e53ab13 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -0,0 +1,1975 @@ +#include + +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr; +using mlir::triton::nvidia_gpu::TensorMemoryScalesEncodingAttr; + +namespace mlir::triton::gpu { +namespace { + +// We use the following nomenclature in this file. +// +// - ctaLayout: A layout for one CTA (one block), i.e. input dims +// [register, lane, warp] +// for register layouts, and input dims [offset] for shared layouts. +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block]. + +#define S(v) StringAttr::get(ctx, (v)) + +SmallVector getDefaultMmaOrder(MmaEncodingTrait layout) { + auto rank = layout.getRepOrderForOperand(0).size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +// TODO Have order be a mandatory argument of standardOutDimNames. +SmallVector permuteDimNames(const SmallVector &names, + const SmallVector &order) { + assert(names.size() == order.size()); + SmallVector ret; + for (unsigned i : order) { + ret.push_back(names[i]); + } + return ret; +} + +// Returns a 1D -> ND layout by composing identity1D along each dimension +// following the specified order. +LinearLayout identityND(StringAttr inDimName, ArrayRef shape, + ArrayRef order, + ArrayRef outDimNames) { + assert(shape.size() == order.size()); + LinearLayout ret = LinearLayout::empty(); + for (unsigned i = 0; i < order.size(); ++i) { + unsigned dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + +LinearLayout swizzledSharedToLinearLayout(ArrayRef shape, + SwizzledSharedEncodingAttr shared) { + MLIRContext *ctx = shared.getContext(); + + auto shapePerCTA = getShapePerCTA(shared, shape); + + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")), + shared.getCGALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shapePerCTA[colDim]; + int numRows = shapePerCTA[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int col = 1; col < numCols; col *= 2) { + bases2D.push_back({0, col}); + } + for (int row = 1; row < numRows; row *= 2) { + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, (vec * ((row / perPhase) % maxPhase)) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= LinearLayout::identity1D(shapePerCTA[dim], S("offset"), + outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCGALayout(), shape); +} + +LinearLayout +sharedToLinearLayoutAMDRotating(ArrayRef shape, + AMDRotatingSharedEncodingAttr shared) { + MLIRContext *ctx = shared.getContext(); + + auto shapePerCTA = getShapePerCTA(shared, shape); + + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")), + shared.getCGALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shape[colDim]; + int numRows = shape[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int col = 1; col < numCols; col *= 2) { + bases2D.push_back({0, col}); + } + for (int row = 1; row < numRows; row *= 2) { + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + + int phase = (row / perPhase) % maxPhase; + int blockNo = row / maxPhase / perPhase % maxPhase; + int combinedPhase = phase ^ blockNo; + bases2D.push_back({row, (vec * combinedPhase) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCGALayout(), shape); +} + +} // namespace + +// Returns the layout of a single core matrix which tiles the nvmma layout +LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, + bool disableSwizzle) { + auto *ctx = shared.getContext(); + + int elemBitWidth = shared.getElementBitWidth(); + int tileWidthBytes = shared.getSwizzlingByteWidth(); + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + + int tileRows = 8; + int tileCols = 8 * std::max(16, tileWidthBytes) / elemBitWidth; + bool isFp4Padded = shared.getFp4Padded(); + + std::vector> bases2D; + for (int col = 1; col < tileCols; col *= 2) { + if (isFp4Padded) { + // Each group of 16 offsets consists of 8 "real" and 8 "padded" offsets. + // We represent the padded layout by mapping 8 padded offsets to the same + // coordinates as the real ones. When computing the inverse of this LL, + // the offsets correspoding to the real ones are picked in the image by + // invertAndCompose. + int colPacked = col / 16 * 8 + col % 8; + bases2D.push_back({0, colPacked}); + } else { + bases2D.push_back({0, col}); + } + } + for (int row = 1; row < tileRows; row *= 2) { + if (disableSwizzle) { + bases2D.push_back({row, 0}); + } else if (isFp4Padded) { + int colPadded = vec * ((row / perPhase) % maxPhase); + int colPacked = colPadded / 16 * 8 + colPadded % 8; + bases2D.push_back({row, colPacked}); + } else { + bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)}); + } + } + auto outDimNames = standardOutDimNames(ctx, 2); + return LinearLayout({{S("offset"), bases2D}}, outDimNames); +} + +LinearLayout nvmmaSharedToLinearLayout(ArrayRef shape, + NVMMASharedEncodingAttr shared, + bool disableSwizzle) { + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + auto shapePerCTA = getShapePerCTA(shared, shape); + auto kOffset = S("offset"); + auto tmaShape = triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA, + /*packedSize=*/true); + if (shared.getSwizzlingByteWidth() == 0) { + auto outDimNames = standardOutDimNames(ctx, rank); + LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset, + outDimNames[rank - 1]); + for (int i = rank - 2; i >= 0; --i) { + layout *= LinearLayout::identity1D(tmaShape[i], kOffset, outDimNames[i]); + } + layout = ensureLayoutNotSmallerThan(layout, outDimNames, shapePerCTA); + return combineCtaCgaWithShape(layout, shared.getCGALayout(), shape); + } + assert(rank >= 2); + + // Collapse all the outer dim into one. We will then create a layout for this + // shape and reshape it to the original shape. + std::array collapsedTmaShape{1, tmaShape.back()}; + for (int i = 0; i + 1 < rank; i++) + collapsedTmaShape[0] *= tmaShape[i]; + if (shared.getTransposed()) { + std::swap(collapsedTmaShape[0], collapsedTmaShape[1]); + } + + auto tileLayout = getCoreMatrixLinearLayout(shared, disableSwizzle); + auto outDimNames = standardOutDimNames(ctx, 2); + auto kRow = outDimNames[0]; + auto kCol = outDimNames[1]; + auto tileRows = tileLayout.getOutDimSize(kRow); + auto tileCols = tileLayout.getOutDimSize(kCol); + + int packingFactor = shared.getFp4Padded() ? 2 : 1; + if (collapsedTmaShape[1] * packingFactor < tileCols || + collapsedTmaShape[0] < tileRows) { + llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to " + "be at least [" + << tileRows << ", " << (tileCols / packingFactor) + << "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", " + << collapsedTmaShape[1] << "]\n"; + llvm::report_fatal_error("Illegal shared layout"); + } + + // Distribute the remaining rows and cols. + auto layout = + ensureLayoutNotSmallerThan(tileLayout, outDimNames, collapsedTmaShape); + + // Reshape the layout to the N-D pre-transposed shape per CTA. + SmallVector maybeTransposedTmaShape = tmaShape; + if (shared.getTransposed()) { + // Move the outer dim to the inner position. + // TODO: we should move back to using `order` instead of transposed to make + // the order more explicit. + std::rotate(maybeTransposedTmaShape.begin(), + maybeTransposedTmaShape.begin() + 1, + maybeTransposedTmaShape.end()); + } + auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedTmaShape); + + if (shared.getTransposed()) { + SmallVector order = {rank - 1}; + for (int i = 0; i < rank - 1; i++) { + order.push_back(i); + } + reshapedLayout = transposeLinearLayout(reshapedLayout, order); + } + + reshapedLayout = ensureLayoutNotSmallerThan( + reshapedLayout, standardOutDimNames(ctx, shapePerCTA.size()), + shapePerCTA); + return combineCtaCgaWithShape(reshapedLayout, shared.getCGALayout(), shape); +} + +/// Function to generate lane and warp layout for dot operands. +static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, + ArrayRef shape, + ArrayRef order, + unsigned kDim, + StringAttr inDimName) { + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {1, 0} + // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In other words, we need to broadcast along K + auto rank = shape.size(); + auto dimNames = standardOutDimNames(ctx, rank); + LinearLayout layout = LinearLayout::empty(); + + // We have to broadcast along the inner dimension + // For A, when moving along M we go from 0 to 2. + // For B, when moving along N we go from 0 to 1. + // As such, choosing the order of A {1, 0}, gives us the correct broadcasting + // Same happens if the warpOrder is {0, 1}, like in Hopper + for (auto d : order) { + if (d == kDim) { + layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]); + } else { + layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]); + } + } + return layout; +} + +LinearLayout +AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getRank()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices, + // which will be [1, 0] / [2, 1, 0]. + SmallVector order = getDefaultMmaOrder(*this); + auto dimM = outDimNames[order[1]]; + auto dimN = outDimNames[order[0]]; + + auto mDim = getInstrShape()[0]; + auto nDim = getInstrShape()[1]; + auto elementBitWidth = getElementBitWidth(); + int height = elementBitWidth == 64 ? 1 : 4; + constexpr int warpSize = 64; + + bool isTransposed = getIsTransposed(); + // Special case for 64x4 mfma: we always transpose the output to turn + // the 64x4 mfma into a equalvalent 4x64 mfma and swap operand A and B, so + // that we can use the mfma broadcast. + if (mDim == 64 && nDim == 4) + assert(isTransposed && "64x4 mfma must be transposed"); + + int tiles = (mDim * nDim) / (warpSize * height); + + LinearLayout tileLayout = LinearLayout::empty(); + if (!isTransposed) { + // Each lane holds 'height' elements along the M dimension. + LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimM); + // First, distribute the lanes along the N dimension. + // Then, distribute the lanes along the M dimension. If the #elements + // exceeds the mDim, duplicate elements across lanes - this can happen for + // 4x4 output. + LinearLayout lanes = LinearLayout::identity1D(nDim, kLane, dimN) * + LinearLayout::identity1D(warpSize / nDim, kLane, dimM); + tileLayout = (regs * lanes); + + // Repeat the above distribution along the M dimension to fits the tile. + if (tiles > 0) + tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimM); + } else { + // For the transposed output, we will use the same method for layout but + // swap the order of the M and N dimensions. + LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimN); + LinearLayout lanes = LinearLayout::identity1D(mDim, kLane, dimM) * + LinearLayout::identity1D(warpSize / mDim, kLane, dimN); + tileLayout = (regs * lanes); + + if (tiles > 0) + tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimN); + } + + tileLayout = tileLayout.transposeOuts({dimN, dimM}); + + // Instead of defining the layout on a CTA tile and using the + // combineCtaCgaWithShape function to extend it to the whole tensor, we take a + // different approach. Suppose tilesPerWarp is 2x2—meaning a warp computes a + // 2x2 block of MFMA tiles. If we define the layout only on the CTA tile and + // extend it across the tensor, the resulting tile order won’t be N-contiguous + // (i.e., row-major). Due to the 2x2 shape, the third tile would fall in the M + // dimension. While defining the layout per CTA tile might seem more + // intuitive, the current dot op lowering assumes an N-contiguous ordering of + // MFMA tiles across the entire tensor. In other words, the lowering logic + // isn't layout-aware, it only supports a fixed N-contiguous MFMA tile + // ordering. Supporting other orderings would require extending the dot + // lowering implementation. For now, we conform to the current lowering + // algorithm by defining the MFMA linear layout globally, with N-contiguous + // tiles across the tensor and across CTA tile boundaries. + auto tilesPerWarp = getTilesPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(); + + const unsigned tilesPerWarpM = tilesPerWarp[mIndex]; + const unsigned tilesPerWarpN = tilesPerWarp[nIndex]; + const unsigned warpsPerCTAM = warpsPerCTA[mIndex]; + const unsigned warpsPerCTAN = warpsPerCTA[nIndex]; + + // First, extend the layout along the N dimension: + // - registers are distributed across tilesPerWarpN + // - then across warpsPerCTAN in the N dimension. + tileLayout *= LinearLayout::identity1D(tilesPerWarpN, kRegister, dimN); + tileLayout *= LinearLayout::identity1D(warpsPerCTAN, kWarp, dimN); + + // At this point, the layout is defined across the N dimension within a CTA + // tile. Instead of switching to the M dimension now, we continue extending + // the layout along the remaining N dimension, and only then proceed along M, + // following the tilesPerWarp configuration. + // If the N dimension is not large enough to span multiple CTA tiles (i.e., + // the first argument is 0), an empty layout is created, so this identity + // layout will not introduce any new registers. + tileLayout *= LinearLayout::identity1D( + shape[nIndex] / (nDim * warpsPerCTAN * tilesPerWarpN), kRegister, dimN); + tileLayout *= LinearLayout::identity1D(tilesPerWarpM, kRegister, dimM); + + // Finally, extend the layout across warps in the M dimension. + // After this step, the layout covers a sub-tensor of size ctaTileM × N, + // i.e., the full N dimension and a CTA tile's extent in M. + // The rest of the layout will be defined by combineCtaCgaWithShape. + tileLayout *= LinearLayout::identity1D(warpsPerCTAM, kWarp, dimM); + + // Adjust spatial ordering if batch dimension is present + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + tileLayout *= + LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[order[2]]); + } + + return combineCtaCgaWithShape(tileLayout, getCGALayout(), shape); +} + +static LinearLayout projectAwayOutDim(const LinearLayout &layout, + StringAttr dim) { + auto ctx = layout.getOutDimNames().begin()->getContext(); + auto bases = layout.getBases(); + auto idx = layout.getOutDimIndex(dim); + for (auto inDim : layout.getInDimNames()) { + auto &inDimBases = bases[inDim]; + for (auto &basis : inDimBases) { + basis[idx] = 0; + } + } + + auto outDimNames = standardOutDimNames(ctx, layout.getOutDims().size()); + return LinearLayout(std::move(bases), outDimNames); +} + +LinearLayout chooseWmmaCTALinearLayout(MLIRContext *ctx, unsigned rank, + ArrayRef warpsPerCTA, + ArrayRef tilesPerWarp) { + StringAttr kWarp = S("warp"); + StringAttr kRegister = S("register"); + auto dims = standardOutDimNames(ctx, rank); + + auto order = getMatrixOrder(rank, /*rowMajor*/ true); + LinearLayout ret; + for (auto d : order) { + ret *= LinearLayout::identity1D(tilesPerWarp[d], kRegister, dims[d]); + ret *= LinearLayout::identity1D(warpsPerCTA[d], kWarp, dims[d]); + } + return ret.transposeOuts(dims); +} + +std::optional +chooseDotDsReadTrLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape, int32_t elemBitWidth, + unsigned instBitWidth, + unsigned numLanesInShuffleGroup) { + if (instBitWidth != 64 || numLanesInShuffleGroup != 16) + return std::nullopt; + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + auto mDim = mfmaLayout.getInstrShape()[0]; + assert(mDim == 16 || mDim == 32); + + assert(elemBitWidth == 4); + // When doing ds_read_tr4 we actually write the LL as if it were on i8 + // elements this is becasue LL needs to be described for the i8 tensor + // elements. + elemBitWidth = 8; + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int32_t kWidthDot = dotMfmaLayout.getKWidth(); + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + + int32_t kSize = shape[kDim]; + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // Regular dot mfma order for both cases is [k, nonk]/[k, nonk, batch] + // For LDS transpose layout swap order to [nonk, k]/[nonk, k, batch] + SmallVector order = + getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ false); + + std::vector> registerBase; + std::vector> laneBase; + + const bool isMfma32 = (mDim == 32); + // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look + // at i8 values for the ownership of register/lane since it's the data type + // of the tensor. Register dimension: what i8 in the tile are held by thread + // 0? Lane dimension: what i8 in the tile are held in register 0 of each + // thread? + registerBase.push_back({1, 0}); + registerBase.push_back({2, 0}); + registerBase.push_back({4, 0}); + registerBase.push_back({0, 16}); + + // If more than one tile needs to be loaded, populate registerBase + // dimension for the other tiles + const int kTileSize = isMfma32 ? 64 : 128; + for (int reg = kTileSize; reg < kSize; reg *= 2) { + registerBase.push_back({0, reg}); + } + + // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64 + // The LL for the two is different + laneBase.push_back({0, 1}); + laneBase.push_back({0, 2}); + laneBase.push_back({0, 4}); + laneBase.push_back({0, 8}); + if (mDim == 16) { + laneBase.push_back({0, 32}); + laneBase.push_back({0, 64}); + } else { + assert(mDim == 32); + laneBase.push_back({8, 0}); + laneBase.push_back({0, 32}); + } + + // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. + // To assign them to actual matrix dimensions we associate with register + // `order` which is also [nonk, k] given we set kContig to false. + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + auto warpOrder = getDefaultMmaOrder(mfmaLayout); + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCGALayout(), shape); +} + +LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + + int32_t kWidth = dotMfmaLayout.getKWidth(); + auto nonKDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 2 : rank - 1; + + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto tilesPerWarp = mfmaLayout.getTilesPerWarp(); + auto tilePerWarpNonK = tilesPerWarp[nonKDimIndex]; + + auto mDim = mfmaLayout.getInstrShape()[0]; + auto nDim = mfmaLayout.getInstrShape()[1]; + auto opIdx = dotMfmaLayout.getOpIdx(); + auto nonKDim = opIdx == 0 ? mDim : nDim; + constexpr int warpSize = 64; + + auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int32_t kSize = shape[kDimIndex]; + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + auto order = + getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ true); + auto dimK = outDimNames[order[0]]; + auto dimNonK = outDimNames[order[1]]; + + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + auto warpOrder = getDefaultMmaOrder(mfmaLayout); + + // Each lane holds kWidth elements along the K dimension + LinearLayout regs = LinearLayout::identity1D(kWidth, kRegister, dimK); + // First distribute nonKDim elements along the non-K dimension, + // then distribute remaining elements along the K dimension + LinearLayout lanes = + LinearLayout::identity1D(nonKDim, kLane, dimNonK) * + LinearLayout::identity1D(warpSize / nonKDim, kLane, dimK); + LinearLayout tileLayout = regs * lanes; + + int kTileSize = warpSize / nonKDim * kWidth; + // Special case for 4x64 and 64x4 mfma: for the 64x64 operand, + // we need to repeat the layout 16 times along the K dimension + if ((mDim == 64 && nDim == 4 && opIdx == 0) || + (mDim == 4 && nDim == 64 && opIdx == 1)) { + tileLayout *= LinearLayout::identity1D(16, kRegister, dimK); + kTileSize *= 16; + } + + // If shape K is larger than the tile size, repeat the tile + // along the K dimension. + if (kSize > kTileSize) { + tileLayout *= LinearLayout::identity1D(kSize / kTileSize, kRegister, dimK); + } + + // Follow the tiles per warp property, repeat the tile layout + // along the non-K dimension. + tileLayout *= LinearLayout::identity1D(tilePerWarpNonK, kRegister, dimNonK); + + tileLayout = tileLayout.transposeOuts({dimK, dimNonK}); + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); + LinearLayout ctaLayout = tileLayout * warpLayout; + + // Note the current the output order is [k, nonk]/[k, nonk, batch]. If the + // layout's out-size is smaller than the shape, we follow this order to + // extend each dimension to match the shape. After that, we can transpose + // to match the standard output order. + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCGALayout(), shape) + .transposeOuts(outDimNames); +} + +LinearLayout AMDWmmaEncodingAttr::getTileLayout(unsigned rank) const { + assert(rank == getRank()); + assert(rank <= 3); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + auto threadOrder = getMatrixOrder(rank, /*rowMajor*/ !getIsTransposed()); + assert(threadOrder[0] == mIndex || threadOrder[0] == nIndex); + assert(threadOrder[1] == mIndex || threadOrder[1] == nIndex); + + // For wmma with 16x16 output, each of the 32 threads holds 8 elements. + // + // The first version of WMMA layout has following specific: + // for the register (i.e., element) dimension, these 8 elements are + // along the matrix C's M dimension, with 1 consecutive elements + // spanning 1 row and then the next 1 row being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start at the next row. + // + // The second version of wmma layout is less tricky: + // for the register dimension 8 elements are along the matrix C's M + // dimension. First 16 lanes take 0-8 elems along M, second 16 take 8-15. + // We have 16 pair of threads in each warp, one pair covers the whole + // column. + // + // Please also check explaining comments in TritonGPUAttrDefs.td at the + // AMDWmmaEncodingAttr section. + unsigned version = getVersion(); + assert(version >= 1 && version <= 3 && "unexpected wmma version"); + LinearLayout tileLayout = + version == 1 + ? LinearLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]}) + : LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 4}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}}, + {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]}); + + if (hasBatchDim) { + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[0]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[0]); + } + tileLayout = tileLayout.transposeOuts(outDimNames); + return tileLayout; +} + +LinearLayout +AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto mnkDim = getInstrShape(); + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + unsigned mDim = mnkDim[0], nDim = mnkDim[1]; + (void)mDim, (void)nDim; + + assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) && + (shape[nIndex] == 1 || shape[nIndex] >= nDim)) && + "Unsupported tensor shape for given wmma layout"); + + auto tileLayout = getTileLayout(rank); + auto ctaLayout = getCtaLayout(); + auto wmmaLayout = tileLayout * ctaLayout; + + // This output-dimension transposition is no longer required, as the + // generalized WMMA lowering makes the repetition order irrelevant. It is + // retained solely to preserve compatibility with legacy tests. + MLIRContext *ctx = getContext(); + auto defaultRepOrder = getMatrixOrder(rank, true); + SmallVector repDimNames = + permuteDimNames(standardOutDimNames(ctx, rank), defaultRepOrder); + + wmmaLayout = wmmaLayout.transposeOuts(repDimNames); + return combineCtaCgaWithShape(wmmaLayout, getCGALayout(), shape); +} + +static LinearLayout musaPH1WMMAToOperandLinearLayout(DotOperandEncodingAttr dot, + ArrayRef shape) { + auto mma = cast(dot.getParent()); + unsigned rank = shape.size(); + bool hasBatch = rank == 3; + MLIRContext *ctx = mma.getContext(); + auto outDimNames = standardOutDimNames(ctx, rank); + + // Operand A: [M, K], Operand B: [K, N]. + int mIndex = hasBatch ? 1 : 0; + int nIndex = hasBatch ? 2 : 1; + int kIndexA = hasBatch ? 2 : 1; + int kIndexB = hasBatch ? 1 : 0; + + unsigned instM = mma.getInstrShape()[0]; + unsigned instN = mma.getInstrShape()[1]; + unsigned instK = mma.getInstrShape()[2]; + auto warpsPerCTA = mma.getWarpsPerCTA(); + unsigned tileM = warpsPerCTA[0]; + unsigned tileN = warpsPerCTA[1]; + + StringAttr dimM = outDimNames[hasBatch ? 1 : 0]; + StringAttr dimN = outDimNames[hasBatch ? 2 : 1]; + StringAttr dimK = outDimNames[hasBatch ? 2 : 1]; + + if (dot.getOpIdx() == 0) { + // A operand: MxK + dimK = outDimNames[kIndexA]; + LinearLayout ctaLayout( + {{S("register"), {}}, + {S("lane"), {{0, 1}, {0, 2}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {dimM, dimK}); + + ctaLayout *= LinearLayout::identity1D(instK / 4, S("register"), dimK); + ctaLayout *= LinearLayout::identity1D(instM / 8, S("register"), dimM); + ctaLayout *= + LinearLayout::identity1D(shape[kIndexA] / instK, S("register"), dimK); + // Keep warp bit ordering aligned with the C layout ([N bits][M bits]). + // A consumes M-warp bits, so first absorb N-warp bits as broadcast. + ctaLayout *= LinearLayout::zeros1D(tileN, S("warp"), dimM); + ctaLayout *= LinearLayout::identity1D(tileM, S("warp"), dimM); + ctaLayout *= LinearLayout::identity1D(shape[mIndex] / instM / tileM, + S("register"), dimM); + + if (hasBatch) { + ctaLayout *= LinearLayout::identity1D(1, S("register"), outDimNames[0]); + ctaLayout *= LinearLayout::identity1D(1, S("lane"), outDimNames[0]); + } + + return combineCtaCgaWithShape(ctaLayout, mma.getCGALayout(), shape); + } + + // B operand: KxN + dimK = outDimNames[kIndexB]; + LinearLayout ctaLayout({{S("register"), {}}, + {S("lane"), {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {dimK, dimN}); + + ctaLayout *= LinearLayout::identity1D(instK / 4, S("register"), dimK); + ctaLayout *= LinearLayout::identity1D(instN / 8, S("register"), dimN); + ctaLayout *= LinearLayout::identity1D(tileN, S("warp"), dimN); + // Explicitly consume M-warp bits as broadcast to keep warp-domain ordering + // consistent across A/B/C operand mappings. + ctaLayout *= LinearLayout::zeros1D(tileM, S("warp"), dimN); + ctaLayout *= LinearLayout::identity1D(shape[nIndex] / instN / tileN, + S("register"), dimN); + ctaLayout *= + LinearLayout::identity1D(shape[kIndexB] / instK, S("register"), dimK); + + if (hasBatch) { + ctaLayout *= LinearLayout::identity1D(1, S("register"), outDimNames[0]); + ctaLayout *= LinearLayout::identity1D(1, S("lane"), outDimNames[0]); + } + + return combineCtaCgaWithShape(ctaLayout, mma.getCGALayout(), shape); +} + +static LinearLayout +musaPH1SQMMAToOperandLinearLayout(DotOperandEncodingAttr dot, + ArrayRef shape) { + auto mma = cast(dot.getParent()); + unsigned rank = shape.size(); + bool hasBatch = rank == 3; + assert(rank == 2 || rank == 3); + + MLIRContext *ctx = mma.getContext(); + auto outDimNames = standardOutDimNames(ctx, rank); + + int mIndex = hasBatch ? 1 : 0; + int nIndex = hasBatch ? 2 : 1; + int kIndexA = hasBatch ? 2 : 1; + int kIndexB = hasBatch ? 1 : 0; + + unsigned instM = mma.getInstrShape()[0]; + unsigned instN = mma.getInstrShape()[1]; + unsigned instK = mma.getInstrShape()[2]; + auto warpsPerCTA = mma.getWarpsPerCTA(); + unsigned totalWarps = std::max(1u, product(warpsPerCTA)); + + StringAttr dimM = outDimNames[mIndex]; + StringAttr dimN = outDimNames[nIndex]; + StringAttr dimK = outDimNames[hasBatch ? 2 : 1]; + + if (dot.getOpIdx() == 0) { + // A operand: MxK + dimK = outDimNames[kIndexA]; + LinearLayout ctaLayout( + {{S("register"), {}}, + {S("lane"), {{0, 1}, {0, 2}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {dimM, dimK}); + + ctaLayout *= LinearLayout::identity1D(instK / 4, S("register"), dimK); + ctaLayout *= LinearLayout::identity1D(instM / 8, S("register"), dimM); + ctaLayout *= LinearLayout::identity1D( + std::max(1, shape[kIndexA] / instK), S("register"), dimK); + ctaLayout *= LinearLayout::identity1D(totalWarps, S("warp"), dimM); + ctaLayout *= LinearLayout::identity1D( + std::max(1, shape[mIndex] / instM / totalWarps), S("register"), + dimM); + + if (hasBatch) { + ctaLayout *= LinearLayout::identity1D(1, S("register"), outDimNames[0]); + ctaLayout *= LinearLayout::identity1D(1, S("lane"), outDimNames[0]); + } + + return combineCtaCgaWithShape(ctaLayout, mma.getCGALayout(), shape); + } + + // B operand: KxN + dimK = outDimNames[kIndexB]; + LinearLayout ctaLayout({{S("register"), {}}, + {S("lane"), {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {dimK, dimN}); + + ctaLayout *= LinearLayout::identity1D(instK / 4, S("register"), dimK); + ctaLayout *= LinearLayout::identity1D(instN / 8, S("register"), dimN); + ctaLayout *= LinearLayout::identity1D(totalWarps, S("warp"), dimN); + ctaLayout *= LinearLayout::identity1D( + std::max(1, shape[nIndex] / instN / totalWarps), S("register"), + dimN); + ctaLayout *= LinearLayout::identity1D( + std::max(1, shape[kIndexB] / instK), S("register"), dimK); + + if (hasBatch) { + ctaLayout *= LinearLayout::identity1D(1, S("register"), outDimNames[0]); + ctaLayout *= LinearLayout::identity1D(1, S("lane"), outDimNames[0]); + } + + return combineCtaCgaWithShape(ctaLayout, mma.getCGALayout(), shape); +} + +static LinearLayout musaPH1WMMAToCLinearLayout(ArrayRef shape, + MUSAWmmaEncodingAttr mma) { + unsigned rank = shape.size(); + bool hasBatch = rank == 3; + MLIRContext *ctx = mma.getContext(); + auto outDimNames = standardOutDimNames(ctx, rank); + StringAttr dimM = outDimNames[hasBatch ? 1 : 0]; + StringAttr dimN = outDimNames[hasBatch ? 2 : 1]; + + unsigned blockM = shape[hasBatch ? 1 : 0]; + unsigned blockN = shape[hasBatch ? 2 : 1]; + unsigned instM = mma.getInstrShape()[0]; + unsigned instN = mma.getInstrShape()[1]; + auto warpsPerCTA = mma.getWarpsPerCTA(); + unsigned tileM = warpsPerCTA[0]; + unsigned tileN = warpsPerCTA[1]; + + LinearLayout ctaLayout({{S("register"), {}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {1, 0}, {2, 0}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {dimM, dimN}); + + ctaLayout *= LinearLayout::identity1D(instN / 8, S("register"), dimN); + ctaLayout *= LinearLayout::identity1D(instM / 4, S("register"), dimM); + ctaLayout *= LinearLayout::identity1D(tileN, S("warp"), dimN); + ctaLayout *= + LinearLayout::identity1D(blockN / instN / tileN, S("register"), dimN); + ctaLayout *= LinearLayout::identity1D(tileM, S("warp"), dimM); + ctaLayout *= + LinearLayout::identity1D(blockM / instM / tileM, S("register"), dimM); + + if (hasBatch) { + ctaLayout *= LinearLayout::identity1D(1, S("register"), outDimNames[0]); + ctaLayout *= LinearLayout::identity1D(1, S("lane"), outDimNames[0]); + } + + return combineCtaCgaWithShape(ctaLayout, mma.getCGALayout(), shape); +} + +static LinearLayout musaPH1SQMMAToCLinearLayout(ArrayRef shape, + MUSASqmmaEncodingAttr mma) { + unsigned rank = shape.size(); + bool hasBatch = rank == 3; + assert(rank == 2 || rank == 3); + + MLIRContext *ctx = mma.getContext(); + auto outDimNames = standardOutDimNames(ctx, rank); + StringAttr dimM = outDimNames[hasBatch ? 1 : 0]; + StringAttr dimN = outDimNames[hasBatch ? 2 : 1]; + + unsigned instM = mma.getInstrShape()[0]; + unsigned n = mma.getInstrShape()[1]; + + auto warpsPerCTA = mma.getWarpsPerCTA(); + unsigned staticDim0Size = 0; + unsigned staticDim1Size = 0; + + LinearLayout ctaLayout( + {{S("register"), {}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {0, 1}, {0, 2}}}}, + {dimN, dimM}); + + staticDim0Size = ctaLayout.getOutDimSize(dimM); + staticDim1Size = ctaLayout.getOutDimSize(dimN); + + ctaLayout *= + LinearLayout::identity1D(n / staticDim1Size, S("register"), dimN); + + if (warpsPerCTA[0] / 4 != 0) { + ctaLayout *= LinearLayout::identity1D(4, S("warp"), dimM); + } + + // Keep the explicit 4-row squad strip in the warp basis. With logical M in + // instrShape, only the register repetition factor shrinks relative to the + // old encoded M/4 representation. + ctaLayout *= LinearLayout::identity1D(instM / (4 * staticDim0Size), + S("register"), dimM); + + ctaLayout *= identityND(S("warp"), {warpsPerCTA[0] / 4, warpsPerCTA[1]}, + /*order=*/{0, 1}, {dimM, dimN}) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + if (hasBatch) { + ctaLayout *= LinearLayout::identity1D(1, S("register"), outDimNames[0]); + ctaLayout *= LinearLayout::identity1D(1, S("lane"), outDimNames[0]); + } + + return combineCtaCgaWithShape(ctaLayout, mma.getCGALayout(), shape); +} + +LinearLayout +MUSAWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + assert(isPH1() && "unsupported MUSA WMMA encoding version"); + return musaPH1WMMAToCLinearLayout(shape, *this); +} + +LinearLayout +MUSASqmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + assert(isPH1() && "unsupported MUSA SQMMA encoding version"); + return musaPH1SQMMAToCLinearLayout(shape, *this); +} + +LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, + ArrayRef shape) { + auto wmmaLayout = llvm::cast(dotWmmaLayout.getParent()); + unsigned version = wmmaLayout.getVersion(); + assert(version >= 1 && version <= 3 && "unexpected wmma version"); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + + MLIRContext *ctx = dotWmmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // lane order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + auto order = + getOrderForDotOperand(dotWmmaLayout.getOpIdx(), rank, /*kContig*/ true); + auto dimK = outDimNames[order[0]]; + auto dimNonK = outDimNames[order[1]]; + + auto mnkDim = wmmaLayout.getInstrShape(); + auto kDim = mnkDim[2]; + auto nonKDimIndex = dotWmmaLayout.getOpIdx() == 0 ? rank - 2 : rank - 1; + auto kDimIndex = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + unsigned kSize = shape[kDimIndex]; + + auto nonKDim = dotWmmaLayout.getOpIdx() == 0 ? mnkDim[0] : mnkDim[1]; + auto kWidth = dotWmmaLayout.getKWidth(); + constexpr int warpSize = 32; + + // The relative order of registers and lanes is given by: + // - k dim: kWidth registers + // - non-k dim: nonKDim lanes + // - k dim: depth = warpSize / nonKDim lanes + // version 1 duplicates these values across k dim + // version 2/3 offsets these values across k dim + // - k dim: repeat kDim / (kWidth * depth) times to fit k dim + LinearLayout tileLayout; + int depth = warpSize / nonKDim; + tileLayout = LinearLayout::identity1D(kWidth, kRegister, dimK) * + LinearLayout::identity1D(nonKDim, kLane, dimNonK); + tileLayout *= version == 1 ? LinearLayout::zeros1D(depth, kLane, dimK) + : LinearLayout::identity1D(depth, kLane, dimK); + + int kTileSize = depth * kWidth; + tileLayout *= LinearLayout::identity1D(std::max(kSize, kDim) / kTileSize, + kRegister, dimK); + + auto ctaLayout = wmmaLayout.getCtaLayout(); + // Zero out M or N dim based on opIdx + ctaLayout = projectAwayOutDim(ctaLayout, dimK); + // If repetition (aka register basis) iz 0 in all out dims we need to remove + // it since this repetition doesn't make sense for dotOp layout. + ctaLayout = actionRemoveBroadcastedRegs(ctaLayout).apply(ctaLayout); + + LinearLayout dotOperanLayout = tileLayout * ctaLayout; + + SmallVector repDimNames = + permuteDimNames(standardOutDimNames(ctx, rank), order); + dotOperanLayout = dotOperanLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(dotOperanLayout, wmmaLayout.getCGALayout(), + shape); +} + +LinearLayout +BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); + auto order = getOrder(); + LinearLayout ctaLayout = + identityStandardND(S("register"), getSizePerThread(), order) * + identityStandardND(S("lane"), getThreadsPerWarp(), order) * + identityStandardND(S("warp"), getWarpsPerCTA(), order); + + return combineCtaCgaWithShape(ctaLayout, getCGALayout(), shape); +} + +LinearLayout fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, + ArrayRef shape) { + int rank = shape.size(); + auto blocked = cast(operandLayout.getParent()); + MLIRContext *ctx = operandLayout.getContext(); + + // TODO: introduce registerOrder or use getDefaultOrder(operandLayout) + // Currently this order is used in legacy converter, because we do not + // have access to full dot operand layout, only parent part. + auto regOrder = blocked.getOrder(); + auto threadOrder = blocked.getOrder(); + auto warpOrder = blocked.getOrder(); + auto repOrder = blocked.getRepOrder(); + + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + auto threadSize = llvm::to_vector(blocked.getSizePerThread()); + auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + threadSize[kDimIdx] = shape[kDimIdx]; + auto threadShape = blocked.getThreadsPerWarp(); + auto warpShape = blocked.getWarpsPerCTA(); + + SmallVector repDimNames = + permuteDimNames(standardOutDimNames(ctx, rank), repOrder); + + auto registersLayout = identityStandardND(kReg, threadSize, regOrder); + auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder, + kDimIdx, kLane); + auto warpsLayout = + broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp); + + LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) * + lanesLayout.transposeOuts(repDimNames) * + warpsLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCGALayout(operandLayout), shape); +} + +LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, + unsigned kWidth, ArrayRef order, + ArrayRef repOrder) { + // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder + // Like LinearLayout::empty() but with a rank and an order + int rank = repOrder.size(); + auto dimNames = standardOutDimNames(ctx, rank); + auto trivialShape = SmallVector(rank, 1); + LinearLayout ctaLayout = + identityStandardND(S("register"), trivialShape, repOrder); + + assert(rank >= 2); + auto inner = order[0]; + auto outer = order[1]; + + assert(tileShape.size() == rank); + int m = tileShape[outer]; + int n = tileShape[inner]; + + // The relative order of registers and lanes is given by: + // - Inner dim: kWidth registers + // - Inner dim: 4 lanes + // - Outer dim: 8 lanes + // - Outer dim: repeat m / 8 times + // - Inner dim: repeat n / (kWidth * 4) times + assert(m % 8 == 0); + assert(n % (kWidth * 4) == 0); + // There is at least one subtile on the inner-most dimension + // FIXME. We should implement operator* in terms of operator*= + // and chain *= instead of using * + auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames()); + ctaLayout = ctaLayout * + LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) * + LinearLayout::identity1D(4, S("lane"), dimNames[inner]) * + LinearLayout::identity1D(8, S("lane"), dimNames[outer]) * + LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) * + LinearLayout::identity1D(n / (kWidth * 4), S("register"), + dimNames[inner]); + return ctaLayout; +} + +LinearLayout +NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ctx = getContext(); + int rank = shape.size(); + assert(rank == getRank()); + + SmallVector tileShape; + if (isAmpere()) { + // Ampere.getInstrShape() returns the tile shape + tileShape = SmallVector(getInstrShape()); + } else { + assert(isHopper()); + auto instrShapeMNK = getInstrShape(); + tileShape = SmallVector({instrShapeMNK[0], instrShapeMNK[1]}); + } + // nvidiamma layout always assumes kWidth = 2 + constexpr auto kWidth = 2; + auto order = getDefaultMmaOrder(*this); + auto ctaLayout = nvidiaMmaTile(ctx, tileShape, kWidth, order, getRepOrder()); + + auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ !isHopper()); + ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCGALayout(), shape); +} + +LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + MLIRContext *ctx = mma.getContext(); + + SmallVector tileShape(rank, 1); + if (isA) { + tileShape[rank - 2] = 16; + tileShape[rank - 1] = kWidth * 8; + } else { + // Hopper takes the rhs via shared memory + assert(mma.isAmpere()); + tileShape[rank - 2] = kWidth * 8; + tileShape[rank - 1] = 8; + } + auto order = getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig*/ true); + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, order, dot.getRepOrder()); + auto kDim = isA ? rank - 1 : rank - 2; + auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ !mma.isHopper()); + ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(), warpOrder, + kDim, S("warp")) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCGALayout(dot), shape); +} + +LinearLayout +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto parent = getParent(); + if (auto blockedLayout = mlir::dyn_cast(parent)) { + return fmaDotToLinearLayout(*this, shape); + } else if (auto mfmaLayout = mlir::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto wmmaLayout = mlir::dyn_cast(parent)) { + return wmmaDotOperandToLinearLayout(*this, shape); + } else if (auto musaWmma = mlir::dyn_cast(parent)) { + return musaPH1WMMAToOperandLinearLayout(*this, shape); + } else if (auto musaSqmma = mlir::dyn_cast(parent)) { + return musaPH1SQMMAToOperandLinearLayout(*this, shape); + } else { + auto mma = mlir::cast(parent); + return nvidiaDotToLinearLayout(shape, *this); + } +} + +LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); + + // First compute the linear layout for this layout's parent. + SmallVector parentShape(shape); + parentShape.insert(parentShape.begin() + getDim(), 1); + LinearLayout parentLL = triton::gpu::toLinearLayout(parentShape, getParent()); + + auto sliceLL = removeStandardDim(parentLL, getDim()); + + // Step 3: Along the "register" dim, remove any all-zero bases. + auto bases = sliceLL.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + return LinearLayout(std::move(bases), + llvm::to_vector(sliceLL.getOutDimNames())); +} + +LinearLayout tensorMemoryToLinearLayout(ArrayRef shape, + TensorMemoryEncodingAttr encoding) { + // [Zeros in TMEM LinearLayouts] + // If there is a zero in bases rows=32,64 this means that there is + // broadcasting, i.e. the same tensor element is duplicated in different + // addressable blocks If the zero is in any other row/col (i.e. within a given + // warp-addressable tmem space) it means it is not defined + + // We model packed layouts as having the rows/cols dimensions of bitWidth=16 + // This means that a layout with unpacked=True is the same as one with + // unpacked=False + assert(shape.size() == 2); + auto *ctx = encoding.getContext(); + auto kRow = S("row"); + auto kCol = S("col"); + auto dims = standardOutDimNames(ctx, 2); + // The CTAOrder = [0, 1] so se start by N so that it ends up as + // ((tile * splitM) * splitN) + if (encoding.getCTASplitN() > 1) { + auto split = + LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]); + auto newEncoding = TensorMemoryEncodingAttr::get( + ctx, encoding.getBlockM(), encoding.getBlockN(), + encoding.getColStride(), encoding.getCTASplitM(), 1, + encoding.getTwoCTAs()); + return tensorMemoryToLinearLayout( + {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) * + split; + } + if (encoding.getCTASplitM() > 1) { + auto splitM = encoding.getCTASplitM(); + auto blockM = encoding.getBlockM(); + bool isM64TwoCTA = blockM == 64 && encoding.getTwoCTAs(); + if (isM64TwoCTA) { + // blockM == 64 and twoCTAs is laid out as the transpose of 128xblockN + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-b + blockM *= 2; + splitM /= 2; + } + auto newEncoding = TensorMemoryEncodingAttr::get( + ctx, blockM, encoding.getBlockN(), encoding.getColStride(), 1, + encoding.getCTASplitN(), encoding.getTwoCTAs()); + auto ret = + tensorMemoryToLinearLayout({shape[0] / splitM, shape[1]}, newEncoding); + // In this case, we swap the basis of the last row and last column + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-bny + if (isM64TwoCTA) { + auto bases = ret.getBases(); + auto basisCTA1 = + llvm::Log2_32(encoding.getBlockN() * encoding.getColStride()) - 1; + std::swap(bases[kRow].back(), bases[kCol][basisCTA1]); + ret = + LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective()); + } + auto split = LinearLayout::identity1D(splitM, kCol, dims[0]); + return ret * split; + } + assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1); + + auto blockM = encoding.getBlockM(); + auto blockN = std::min(encoding.getBlockN(), shape[1]); + assert(blockM == 64 || blockM == 128); + LinearLayout tile = + LinearLayout::zeros1D(encoding.getColStride(), kCol, dims[1]); + if (blockM == 64) { + tile *= LinearLayout::identity1D(16, kRow, dims[0]) * + LinearLayout::identity1D(blockN, kCol, dims[1]); + auto bases = tile.getBases(); + if (shape[0] > blockM) { + bases[kRow].push_back({64, 0}); + } else if (shape[1] > blockN) { + bases[kRow].push_back({0, blockN}); + } else { + // Empty, meaning the element is not defined + bases[kRow].push_back({0, 0}); + } + bases[kRow].push_back({16, 0}); + bases[kRow].push_back({32, 0}); + tile = LinearLayout(std::move(bases), dims); + } else { + tile *= LinearLayout::identity1D(blockM, kRow, dims[0]) * + LinearLayout::identity1D(blockN, kCol, dims[1]); + } + auto repsM = shape[0] / tile.getOutDimSize(dims[0]); + auto repsN = shape[1] / tile.getOutDimSize(dims[1]); + assert(repsM >= 1 && repsN >= 1); + // Broadcast the remaining dimensions in order [0, 1] + tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) * + LinearLayout::identity1D(repsN, kCol, dims[1]); + return tile; +} + +LinearLayout +tensorMemoryScalesToLinearLayout(ArrayRef shape, + TensorMemoryScalesEncodingAttr encoding) { + assert(shape.size() == 2); + auto *ctx = encoding.getContext(); + auto kRow = S("row"); + auto kCol = S("col"); + auto dims = standardOutDimNames(ctx, 2); + + // The CTAOrder = [0, 1] so se start by N so that it ends up as + // ((tile * splitM) * splitN) + if (encoding.getCTASplitN() > 1) { + auto split = + LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]); + auto newEncoding = + TensorMemoryScalesEncodingAttr::get(ctx, encoding.getCTASplitM(), 1); + return tensorMemoryScalesToLinearLayout( + {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) * + split; + } + if (encoding.getCTASplitM() > 1) { + auto split = + LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]); + auto newEncoding = + TensorMemoryScalesEncodingAttr::get(ctx, 1, encoding.getCTASplitN()); + return tensorMemoryScalesToLinearLayout( + {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) * + split; + } + assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1); + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + auto tile = LinearLayout::identity1D(32, kRow, dims[0]) * + // Broadcasting along 'warps' + LinearLayout::zeros1D(4, kRow, dims[0]) * + LinearLayout::identity1D(4, kCol, dims[1]) * + LinearLayout::identity1D(2, kCol, dims[0]); + // We choose repOrder = [0, 1] + tile *= LinearLayout::identity1D( + llvm::divideCeil(shape[0], tile.getOutDimSize(dims[0])), kCol, + dims[0]) * + LinearLayout::identity1D( + llvm::divideCeil(shape[1], tile.getOutDimSize(dims[1])), kCol, + dims[1]); + // See [Zeros in TMEM LinearLayouts] + // Set some rows/cols to 0 if shape is smaller than 64 x 4 + llvm::SmallDenseMap shapeMap; + for (auto [dim, size] : llvm::zip(dims, shape)) { + shapeMap[dim] = size; + } + return ensureLayoutNotLargerThan(tile, shapeMap); +} + +LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef shape, + Attribute layout) { + CacheKey key{std::vector(shape.begin(), shape.end()), layout}; + if (auto result = llCache.get(key)) { + return *result; + } + + // Layouts are distributed or shared in triton core + // To add a new layout add an else-if clause + LinearLayout result = LinearLayout::empty(); + if (auto distributed = dyn_cast(layout)) { + result = distributed.toLinearLayout(shape); + } else { + assert(llvm::all_of(shape, + [](int64_t dim) { + return llvm::isPowerOf2_32(dim) && dim >= 1; + }) && + "shape must be a postive power of 2"); + if (auto shared = dyn_cast(layout)) { + result = swizzledSharedToLinearLayout(shape, shared); + } else if (auto shared = dyn_cast(layout)) { + result = shared.toLinearLayout(shape); + } else if (auto shared = dyn_cast(layout)) { + result = nvmmaSharedToLinearLayout(shape, shared); + } else if (auto sbl = dyn_cast(layout)) { + result = sharedToLinearLayoutAMDRotating(shape, sbl); + } else if (auto tensorMemoryEncoding = + dyn_cast(layout)) { + result = tensorMemoryToLinearLayout(shape, tensorMemoryEncoding); + } else if (auto tensorMemoryScalesEncoding = + dyn_cast(layout)) { + result = + tensorMemoryScalesToLinearLayout(shape, tensorMemoryScalesEncoding); + } else { + assert(0 && "unknown layout"); + } + } + + llCache.set(std::move(key), result); + return result; +} + +LinearLayout toLinearLayout(RankedTensorType type) { + return toLinearLayout(type.getShape(), type.getEncoding()); +} + +LinearLayout toLinearLayout(MemDescType type) { + // Pass in the allocation shape. Then when using invertAndCompose it will + // trim the allocationShape to the shape if they are different. + // We also remove the first dimension of the allocationShape if there was a + // call to memdesc_index + auto shape = type.getAllocShape().take_back(type.getRank()); + return toLinearLayout(shape, type.getEncoding()); +} + +LinearLayout toLinearLayout(TensorOrMemDesc type) { + if (auto ranked = dyn_cast(type)) { + return toLinearLayout(ranked); + } else { + auto memDesc = cast(type); + return toLinearLayout(memDesc); + } +} + +// UNSAFE OVERLOAD! +// If you call this with a SharedMemoryEncodingAttr, you should call it +// with the allocShape as the shape, otherwise the layout will be incorrect! +LinearLayout toLinearLayout(ArrayRef shape, Attribute layout) { + auto *ctx = layout.getContext(); + return ctx->getLoadedDialect()->toLinearLayout(shape, + layout); +} + +LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { + assert(!layout.getInDimNames().empty()); + MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); + + StringAttr kBlock = S("block"); + assert(layout.hasInDim(kBlock)); + auto bases = layout.getBases(); + bases[kBlock] = {}; + return LinearLayout(std::move(bases), + llvm::to_vector<4>(layout.getOutDimNames())); +} + +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CGAEncodingAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) { + labeledShape[dim] = size; + } + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(cgaLayoutAttr.getLinearLayout(), labeledShape) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + // Calculate the shape of the ctaLayout, which is `shape` divided by the + // cgaLayout's size. + llvm::SmallDenseMap ctaShape; + assert(llvm::to_vector(ctaLayout.getOutDimNames()) == + llvm::to_vector(cgaLayout.getOutDimNames())); + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames); + for (auto dim : ret.getOutDimNames()) { + assert(ret.getOutDimSize(dim) == labeledShape[dim]); + } + return ret; +} + +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order) { + auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); + LinearLayout layout = LinearLayout::empty(); + SmallVector kRepDims; + SmallVector kOffsetDims; + auto totalIters = 1; + auto totalOffsets = 1; + for (int i = 0; i < tensorShape.size(); i++) { + int dim = order[i]; + StringAttr kIteration = S("iteration" + std::to_string(dim)); + StringAttr kOffset = S("offset" + std::to_string(dim)); + kRepDims.push_back(kIteration); + kOffsetDims.push_back(kOffset); + assert(llvm::isPowerOf2_32(repShape[dim])); + assert(llvm::isPowerOf2_32(tensorShape[dim])); + auto numIters = tensorShape[dim] / repShape[dim]; + layout *= + LinearLayout::identity1D(repShape[dim], kOffset, outDimNames[dim]); + layout *= LinearLayout::identity1D(numIters, kIteration, outDimNames[dim]); + totalIters *= numIters; + totalOffsets *= repShape[dim]; + } + StringAttr kOffset = S("offset"); + StringAttr kIteration = S("iteration"); + StringAttr kBlock = S("block"); + SmallVector newDims; + newDims.append(kOffsetDims.begin(), kOffsetDims.end()); + newDims.append(kRepDims.begin(), kRepDims.end()); + // Transpose layout from [offset0, rep0, offset1, rep1, ...] to + // [offset0, offset1, ..., rep0, rep1, ...] + auto ret = layout.transposeIns(newDims); + // Reshape layout from [offset0, offset1, ..., rep0, rep1, ...] to + // [offset, rep, block] + return ret.reshapeIns( + {{kOffset, totalOffsets}, {kIteration, totalIters}, {kBlock, 1}}); +} + +std::optional +chooseDsReadTrLayout(Attribute enc, ArrayRef shape, + int32_t elemBitWidth, unsigned instBitWidth, + unsigned numLanesInShuffleGroup) { + assert(elemBitWidth == 4); + auto dot = cast(enc); + return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth, + numLanesInShuffleGroup); +} + +LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef dotOperandShape, + unsigned wmmaMDim, + LinearLayout ctaLayout) { + using basisT = std::vector>; + unsigned rank = dotOperandShape.size(); + SmallVector order; + if (rank == 3) { + order = {1, 0, 2}; + } else { + order = {1, 0}; + } + auto outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + + // In scaled dot, the shapes of operands(without batch dimension) are, + // respectively: + // - A: [M, K] + // - B: [K, N] + // - aScale: [M, K / 32 or 16] + // - bScale: [N, K / 32 or 16] + auto dimK = outDimNames[order[0]]; + auto dimNonK = outDimNames[order[1]]; + + // Each lane holds kWidth=4 consecutive values along the K dim. + // The first 16 lanes are distributed along the nonK dim. + unsigned scaleKWidth = 4; + auto kSize = dotOperandShape[1]; + LinearLayout tileLayout = + LinearLayout::identity1D(scaleKWidth, kRegister, dimK) * + LinearLayout::identity1D(16, kLane, dimNonK) * + LinearLayout::zeros1D(2, kLane, dimNonK); + + unsigned mnDim = dotOperandIdx == 0 ? rank - 2 : rank - 1; + + // If the shape along the K dim is larger than kWidth, repeat this + // pattern to fill the K dim. + tileLayout *= LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK); + + if (dotOperandIdx == 1) { + ctaLayout = transposeLinearLayout(ctaLayout, order); + } + + // Zero out M or N dim based on opIdx + ctaLayout = projectAwayOutDim(ctaLayout, dimK); + // If repetition (aka register basis) iz 0 in all out dims we need to remove + // it since this repetition doesn't make sense for dotOp layout. + ctaLayout = actionRemoveBroadcastedRegs(ctaLayout).apply(ctaLayout); + + ctaLayout = tileLayout.transposeOuts(outDimNames) * ctaLayout; + auto nonOpSelLayout = combineCtaCgaWithShape( + ctaLayout, CGAEncodingAttr::get1CTALayout(ctx, /*rank=*/2), + dotOperandShape); + + // This is the tricky part. For a single tile, only 16 threads + // hold scale values, 4 for each thread. Other 16 thread in a warp + // broadcast these values. This is a waste of memory. In order to deal with + // that we can assignd other 16 threads (thread 15-31), to hold scales of the + // next tile computed by the same warp (aka it's first repetition in non-k + // dim), if there is one. So register base that naturally represents first + // repetition needs to be moved to lane base that represents lane 16. Since + // for a single tile thread holds 4 vals, we move register base 2, to lane + // base 4. + + // No repetitions in m/n dim. + auto firstRepInNonK = tileLayout.getInDimSizeLog2(kRegister); + if (nonOpSelLayout.getInDimSizeLog2(kRegister) <= firstRepInNonK) { + return nonOpSelLayout; + } + + // We want to "move" the register basis (index firstRepInNonK) + // into the fifth lane basis slot (index 4), if present. + constexpr int kLaneInsertIndex = 4; + auto bases = nonOpSelLayout.getBases(); + std::swap(bases[kRegister][firstRepInNonK], bases[kLane][kLaneInsertIndex]); + bases[kRegister].erase(bases[kRegister].begin() + firstRepInNonK); + + return LinearLayout(std::move(bases), outDimNames); +} + +// PTX ISA - Warp-level MMA Block Scaling +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling +// This function generates layouts for scale tensors used in scaled dot +// operations. +// Implementation notes: +// - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b = +// 0) +// - We choose a fixed byte selector for A (byte-id-a = 0) and B (byte-id-b = +// 0) +// - Each lane in a quad has the same scale factor. +LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, + ArrayRef shape, int opIdx, + ArrayRef warpsPerCTA, + CGAEncodingAttr cgaLayout) { + unsigned rank = shape.size(); + auto outDims = standardOutDimNames(ctx, rank); + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + // - A: [M, K] + // - B: [K, N] + // - aScale: [M, K / K_GROUP_SIZE] + // - bScale: [N, K / K_GROUP_SIZE] + const unsigned kIdx = 1; + const unsigned mnIdx = 0; + + std::vector> laneBase; + SmallVector order; + SmallVector mmaWarpsPerCTA; + if (opIdx == 0) { + laneBase = {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}; + order = SmallVector{1u, 0u}; + mmaWarpsPerCTA = SmallVector{warpsPerCTA[0], warpsPerCTA[1]}; + } else { + laneBase = {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}; + order = SmallVector{0u, 1u}; + mmaWarpsPerCTA = SmallVector{warpsPerCTA[1], warpsPerCTA[0]}; + } + LinearLayout LL = + LinearLayout::identity1D(shape[1], kRegister, outDims[kIdx]) * + LinearLayout({{kLane, laneBase}}, {outDims[mnIdx], outDims[kIdx]}) * + broadcastedDotOperandLayout(ctx, mmaWarpsPerCTA, order, 1u, kWarp); + return combineCtaCgaWithShape(LL, cgaLayout, shape); +} + +LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef dotOperandShape, + unsigned mfmaMDim, + ArrayRef tilesPerWarp, + ArrayRef warpsPerCTA) { + using basisT = std::vector>; + unsigned rank = dotOperandShape.size(); + auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true); + auto standardOutDims = standardOutDimNames(ctx, rank); + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + + // Fetch the tilesPerWarp value in the M dimension for operand A, or in the N + // dimension for operand B. + unsigned mnDim = dotOperandIdx == 0 ? rank - 2 : rank - 1; + unsigned tilePerWarpMN = tilesPerWarp[mnDim]; + + // In scaled dot, the shapes of operands(without batch dimension) are, + // respectively: + // - A: [M, K] + // - B: [K, N] + // - aScale: [M, K / 32] + // - bScale: [N, K / 32] + // + // In general, for both 32x32 and 16x16 scaled mfma, and no matter what + // data type the A/B operand is, each lane takes 32 elements from A/B + // alone K dim, and 1 or 2 elements from scale accordingly. The number of + // scale's elements in a lane varies because the 32 elements from A/B may + // not be consecutive. + // + // For mxfp4, these 32 elements are consecutive, so only 1 scale element + // is required. But for mxfp6/mxfp8, there are 2 16-consecutive elements + // blocks, so 2 scale elements are required. + int32_t kSize = dotOperandShape[1]; + + std::vector> registerBase; + std::vector> laneBase; + + auto threadsInKDim = mfmaMDim == 32 ? 2 : 4; + for (int32_t elem = threadsInKDim; elem < kSize; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + for (int32_t elem = mfmaMDim; elem < tilePerWarpMN * mfmaMDim; elem *= 2) + registerBase.emplace_back(std::vector{0, elem}); + + if (mfmaMDim == 32) { + // For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane + // takes 32 consecutive elements from A alone K dimension. The first + // 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes + // collectively handle A[0:32][32:64]. Each lane take 1 scale element + // accordingly. Similar to B and bScale. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}; + } else { + assert(mfmaMDim == 16); + // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane + // takes 32 consecutive elements from A alone K dimension. The first + // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes + // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale + // element accordingly. Similar to B and bScale. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}; + } + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + SmallVector warpsPerCTANew = + (dotOperandIdx == 1) + ? SmallVector{warpsPerCTA[1], warpsPerCTA[0]} + : SmallVector{warpsPerCTA[0], warpsPerCTA[1]}; + + SmallVector warpOrder = (dotOperandIdx == 1) + ? SmallVector{0, 1} + : SmallVector{1, 0}; + + LinearLayout warpLayout = + identityStandardND(kWarp, warpsPerCTANew, warpOrder); + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + + auto cgaLayout = CGAEncodingAttr::get1CTALayout(ctx, 2); + auto finalLay = combineCtaCgaWithShape(ctaLayout, cgaLayout, dotOperandShape); + return finalLay; +} + +std::optional +chooseMfmaLikeStoreLayout(RankedTensorType valType) { + // TODO: WMMA Support on RDNA + if (!isa(valType.getEncoding())) + return {}; + auto mfmaLayout = cast(valType.getEncoding()); + + // We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on + // CDNA4. + auto mnkDim = mfmaLayout.getInstrShape(); + bool isMfma32 = mnkDim[0] == 32 && mnkDim[1] == 32; + bool isMfma16 = mnkDim[0] == 16 && mnkDim[1] == 16; + + auto valShape = valType.getShape(); + // For mfma16x16, to use in-wavefront swap, we need to make sure the tiles + // used are in one wavefront if there are multiple tiles, which means + // warpsPerCTA = [numWarps, 1] and at least two tiles along the N dim. For + // now, it is only possible for FA-like kernels since during mfma generation, + // the WarpsPerCTA of the head dot in the chain will be reshaped to [numWaprs, + // 1]. + // TODO: For gemm-like kernel, the transformation here cannot be applied for + // now and will support it. + bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 && + mfmaLayout.getWarpsPerCTA().back() == 1; + + Type elemType = valType.getElementType(); + if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) && + mfmaLayout.getVersion() == 4 && mfmaLayout.getIsTransposed() && + (isMfma32 || validForMfma16))) + return {}; + + LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape); + auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames()); + StringAttr dimM = mfmaOutDims[0]; + StringAttr dimN = mfmaOutDims[1]; + auto swapLL = LinearLayout::empty(); + // The rows are kept as is with an identity linear layout. + swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM); + /* + clang-format off + In transposed mfma32 layout, Each thread holds 4 consecutive values along N + dim. We want to exchange column 4-7 (owned by thread 32-63, BLK0) and column + 8-11 (owned by thread 0-31, BLK1) every 16 columns to make each thread holds 8 + elements. This would mean exchange the 2nd and 3rd basis vector from an + identity linear layout on tensor elements. + + Correspondingly, the transposed mfma16 layout, the output of + transposed of mfma16x16 is: + + N/register + M/Lane v0 v1 v2 v3 v4 v5 v6 v7 + ------------------------------------------------------------------------- + row0: 0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row1: 16-31 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row3: 48-63 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + which means: + The columns from v0 to v3 are in the one output of mfma16x16 and + the columns from v4 to v7 are in the one output of mfma16x16, + + The following graph is the same as the one above, execept the tile number is replaced with coordinates in the tenor, + N/register + ----------------------------------------------- + M/lane |(0, 0) ... (0, 3) | (0, 16) ... (0, 19) | + |.... | sub-tensor-0 | + |(15, 0) ... (15, 3) | (15, 16) ... (15, 19) | + ----------------------------------------------- + |(0, 4) ... (0, 7) | (0, 20) ... (0, 23) | + |sub-tensor-1 | .... | + |(15, 0) ... (15, 3) | (15, 20) ... (15, 23) | + ----------------------------------------------- + |(0, 8) ... (0, 11)| (0, 24) ... (0, 27) | + |.... | sub-tensor-2 | + |(15, 8) ... (15, 11)| (15, 24) ... (15, 27) | + ----------------------------------------------- + |(0, 12) ... (0, 15)| (0, 28) ... (0, 31) | + |sub-tensor-3 | .... | + |(15, 12) ... (15, 15)| (15, 28) ... (15, 31) | + ----------------------------------------------- + The basis vector for lane and register are: + Register = {{0, 1}, {0, 2}} + Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}} + With this layout, only 4xfp16 can be packed in the final global store. + + To use 128-bits global store, we need to pack 8 elements, which means the layout looks like: + N/register + M/Lane v0 v1 v2 v3 v4 v5 v6 v7 + ------------------------------------------------------------------------- + row0: 0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | + ------------------------------------------------------------------------- + row1: 16-31 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | + ------------------------------------------------------------------------- + row3: 48-63 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + + The following graph is the same as the one above, execept the tile number is replaced with coordinates in the tenor: + N/register + ----------------------------------------------- + |(0, 0) ... (0, 3) | (0, 4) ... (0, 7) | + |.... | sub-tensor-1 | + |(15, 0) ... (15, 3) | (15, 16) ... (15, 19) | + ----------------------------------------------- + |(0, 16) ... (0, 19) | (0, 20) ... (0, 23) | + |sub-tensor-0 | .... | + |(15, 16) ... (15, 19)| (15, 20) ... (15, 23) | + ----------------------------------------------- + |(0, 8) ... (0, 11)| (0, 12) ... (0, 15) | + |.... | sub-tensor-3 | + |(15, 8) ... (15, 11)| (15, 12) ... (15, 15) | + ----------------------------------------------- + |(0, 24) ... (0, 27)| (0, 28) ... (0, 31) | + |sub-tensor-2 | .... | + |(15, 24) ... (15, 27)| (15, 28) ... (15, 31) | + ----------------------------------------------- + which means we need to exchange sub-tensor-0 with sub-tensor-1 and sub-tensor-2 and sub-tensor-3. + And basis vector for lane and register are: + Register = {{0, 1}, {0, 2}, {0, 4}} + Lane = {{1, 0}, {2, 0, [4, 0}, {8, 0}, {0, 16}, {0, 8}} + + The steps to get this layout are, firstly we check the last dim of WarpsPerCTA is 1, so we can use v_permlane16. + Then, we exchange the 2nd and 4th elements in the basis vector of an identity linear and then it will be composed with + the original mfma16 LL. + clang-format on + */ + auto destIdxInBases = isMfma32 ? 3 : 4; + std::vector> dimNBases(mfmaLL.getOutDimSizeLog2(dimN)); + std::generate(dimNBases.begin(), dimNBases.end(), + [i = 0]() mutable { return std::vector{1 << i++}; }); + std::swap(dimNBases[2], dimNBases[destIdxInBases]); + swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN}); + + return mfmaLL.compose(swapLL); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp new file mode 100644 index 0000000000..926a456c40 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -0,0 +1,1412 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/DebugStringHelper.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" + +// Provide custom directive handlers for declarative assemblyFormat. +// They must be visible before including the generated op classes. +static mlir::ParseResult parseOffsets(mlir::OpAsmParser &p, + mlir::DenseI32ArrayAttr &attr) { + llvm::SmallVector values; + if (p.parseCommaSeparatedList([&]() { + int32_t v; + if (p.parseInteger(v)) + return mlir::failure(); + values.push_back(v); + return mlir::success(); + })) + return mlir::failure(); + attr = p.getBuilder().getDenseI32ArrayAttr(values); + return mlir::success(); +} + +static void printOffsets(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::DenseI32ArrayAttr attr) { + auto vals = attr.asArrayRef(); + llvm::interleaveComma(vals, p, [&](int32_t v) { p << v; }); +} + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" + +namespace mlir::triton::gpu { + +namespace { + +template bool hasEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) { + auto encoding = tensorType.getEncoding(); + return encoding && isa(encoding); + } + return false; +} + +bool hasDotOperandEncoding(Value value) { + return hasEncoding(value); +} + +bool isConvertTrivial(ConvertLayoutOp op) { + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + auto srcEncoding = srcType.getEncoding(); + auto dstEncoding = dstType.getEncoding(); + return cast(&srcEncoding.getDialect()) + ->verifyLayoutsAreEqual(srcType.getShape(), srcEncoding, dstEncoding, {}) + .succeeded(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Canonicalizer +//===----------------------------------------------------------------------===// + +// tmem_store(cvt) -> tmem_store +struct CanonicalizeConvertFromTMEMStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(nvidia_gpu::TMEMStoreOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + + // bail for incompatible layouts + auto cvtSrcType = convert.getSrc().getType(); + if (!nvidia_gpu::isDistributedLayoutTMemCompatible( + op.getOperation(), cvtSrcType, op.getDst().getType())) { + return failure(); + } + + rewriter.modifyOpInPlace( + op, [&]() { op.getSrcMutable().assign(convert.getSrc()); }); + return mlir::success(); + } +}; + +// reshape(cvt) -> reshape +struct CanonicalizeConvertFromReshape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + // If the layouts are structurally the same, the convert is trivial + if (isConvertTrivial(convert)) { + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder(), + op.getEfficientLayout()); + return success(); + } + + if (isExpensiveView(convert.getSrc().getType(), op.getType())) + return failure(); + if (!op.getAllowReorder()) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder(), + op.getEfficientLayout()); + return mlir::success(); + } +}; + +// TODO We should do this generically for op(cvt) -> op +// We have similar patterns for reshape and split... +// See https://github.com/triton-lang/triton/pull/5403#discussion_r1920091671 + +// trans(cvt) -> trans +struct CanonicalizeConvertFromTranspose + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::TransOp op, + PatternRewriter &rewriter) const override { + // transpose(x, order=[0, 1, ...]) -> x + // We turn it into a (trivial) convert_layout that may be folded away + if (isIota(op.getOrder())) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getSrc()); + return success(); + } + + // If the layouts are structurally the same, the convert is trivial + auto convert = op.getSrc().getDefiningOp(); + if (!convert || !isConvertTrivial(convert)) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getOrder()); + return success(); + } +}; + +// histogram(cvt) -> histogram +struct CanonicalizeConvertFromHistogram + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::HistogramOp op, + PatternRewriter &rewriter) const override { + auto src = op.getSrc(); + auto convert = src.getDefiningOp(); + if (!convert) { + return failure(); + } + src = convert.getSrc(); + + // If mask is present, convert the layout of mask to match new src layout + auto mask = op.getMask(); + if (mask) { + auto sharedType = getI1SameShape(src.getType()); + rewriter.setInsertionPoint(op); + mask = ConvertLayoutOp::create(rewriter, op.getLoc(), sharedType, mask); + } + + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), src, mask); + return success(); + } +}; + +// If the gather does not have an optimized layout attached, then the source +// layout does not matter since the gather will be codegen'd by storing the +// source tensor into shared memory. Thus, we can fold conversions into the +// source operand. +// +// gather(cvt(src), idx) -> gather(src, idx) +struct CanonicalizeConvertFromGatherSource : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { + // Don't do this if the compiler picked an optimized layout. + if (op.getEfficientLayout()) + return failure(); + + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + + rewriter.replaceOpWithNewOp(op, convert.getSrc(), op.getIndices(), + op.getAxis()); + return success(); + } +}; + +// alloc(cvt) -> alloc +struct CanonicalizeConvertFromAlloc + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, + PatternRewriter &rewriter) const override { + if (!op.getSrc()) + return failure(); + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + auto srcTy = dyn_cast(convert.getSrc().getType()); + auto dstTy = dyn_cast(convert.getType()); + if (srcTy && dstTy && isa(srcTy.getEncoding()) && + !isa(dstTy.getEncoding())) { + // Chained SQMMA operands must preserve the explicit mma -> logical-tensor + // boundary. Folding alloc(convert_layout(%mma)) into alloc(%mma) breaks + // the required shared-memory restaging contract for the next SQMMA. + return failure(); + } + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// local_store(cvt) -> local_store +struct CanonicalizeConvertFromLocalStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromSplit + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::SplitOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + auto srcEncoding = convert.getSrc().getType().getEncoding(); + // Multiple source layout can give the same output layout, if the source + // layout of the convert gives the same destination layout we can skip the + // convert. + auto dstEncoding = inferDstEncoding(op, srcEncoding); + if (dstEncoding != op.getOutLHS().getType().getEncoding()) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + PatternRewriter &rewriter) const override { + // Convert to the same layout is redundant. + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + + // We don't handle conversions to DotOperandEncodingAttr. This is a + // heuristic to accommodate fused attention. + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding())) + return failure(); + + Operation *arg = op.getSrc().getDefiningOp(); + if (!arg) + return failure(); + + // cvt(reshape) -> reshape + if (auto reshape = dyn_cast(arg)) { + if (!reshape.getAllowReorder() || reshape.getEfficientLayout() || + isExpensiveView(reshape.getSrc().getType(), op.getType())) + return failure(); + + // In TritonGPUToLLVM phase, ViewOp is converted to unpacking and packing + // operations, which requires the element type to match between unpacking + // and packing. However, part of values with dot operand encoding will be + // packed/unpacked as i32 elements instead of the underlying element type. + // To avoid errors, skip this folding when either the operand or result + // of view has a dot operand encoding. + if (hasDotOperandEncoding(op->getOperand(0)) || + hasDotOperandEncoding(op->getResult(0))) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + reshape.getResult(), + reshape.getAllowReorder()); + return success(); + } + + // cvt(histogram) -> histogram + if (auto histogram = dyn_cast(arg)) { + // For histogram ops the input and output layouts are independent, so we + // can always fold convert into the histogram op. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + histogram.getSrc(), + histogram.getMask()); + return success(); + } + + // cvt(local_load) -> local_load. + if (auto sharedLoad = dyn_cast(arg)) { + // Shared_load can load to any layout so we can always fold convert into + // it. + // We insert at the point of the original op as there could be ops with + // memory side-effects between the LocalLoad op and the ConvertLayout op + rewriter.setInsertionPoint(arg); + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + sharedLoad.getSrc(), + sharedLoad.getToken()); + + return success(); + } + + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + if (isExpensiveCat(cat, op.getType().getEncoding())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (auto cvt = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), cvt.getSrc()); + return success(); + } + + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return success(); + } + + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return success(); + } + + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = dyn_cast(cst.getValue())) { + auto ty = cast(op->getResultTypes().front()); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return success(); + } + return failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); +} + +LogicalResult Fp4ToFpOp::verify() { + auto srcTy = cast(getSrc().getType()); + auto resTy = cast(getResult().getType()); + auto axis = getAxis(); + + auto elemType = resTy.getElementType(); + if (!(elemType.isBF16() || elemType.isF16())) + return emitError() << "only bf16 or f16 is supported for now, got " + << elemType; + + return verifyFp4ToFp(*this, srcTy, resTy, axis); +} + +LogicalResult Fp4ToFpOp::verifyFp4ToFp(mlir::Operation *op, + RankedTensorType srcTy, + RankedTensorType resTy, unsigned axis) { + auto rank = srcTy.getRank(); + + if (rank != resTy.getRank()) + return op->emitError() << "source rank " << rank << " != result rank " + << resTy.getRank(); + + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + + if (!(0 <= axis && axis < rank)) + return op->emitError() << "axis " << axis << " out of range for rank " + << rank; + + for (int i = 0; i < rank; ++i) { + if (i == axis) { + if (resShape[i] != srcShape[i] * 2) + return op->emitError() + << "axis " << axis + << " dimension must be 2x source dimension (src=" << srcShape[i] + << ", dst=" << resShape[i] << ")"; + } else { + if (resShape[i] != srcShape[i]) + return op->emitError() + << "dimension " << i << " mismatch (src=" << srcShape[i] + << ", dst=" << resShape[i] << ", axis=" << axis << ")"; + } + } + if (bool(resTy.getEncoding()) != bool(srcTy.getEncoding())) + return op->emitError() + << "source and result must both have an encoding, or neither"; + if (!resTy.getEncoding()) { + return success(); + } + auto srcLl = toLinearLayout(srcTy); + auto resLl = toLinearLayout(resTy); + auto *ctx = srcTy.getContext(); + auto regDim = StringAttr::get(ctx, "register"); + auto outDims = standardOutDimNames(ctx, rank); + + // We use backward inference here as it is striclty more general + Attribute inferSrc; + auto dialect = + resTy.getEncoding() + .getDialect() + .getRegisteredInterface(); + assert(dialect); + if (failed(dialect->inferFp4ToFpOpEncoding( + resTy.getShape(), axis, resTy.getEncoding(), inferSrc, + /*fwdInference*/ false, std::nullopt))) { + return op->emitError() << "failed to infer encoding"; + } + if (!areLayoutsEquivalent(srcTy.getShape(), + cast(inferSrc), + cast(srcTy.getEncoding()))) + return op->emitError() + << "Src and Dst encodings are not compatible:\n" + << toLinearLayout(srcTy.getShape(), inferSrc).toString() << "\n" + << srcLl.toString(); + return success(); +} + +void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state, + TypedValue src, Type elemType, + int32_t axis) { + auto srcTy = src.getType(); + auto shape = llvm::to_vector(srcTy.getShape()); + auto rank = srcTy.getRank(); + assert(0 <= axis && axis < rank); + shape[axis] *= 2; + + Attribute inEnc = srcTy.getEncoding(); + Attribute outEnc; + auto result = + inEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc, + /*fwdInference=*/true, state.location); + assert(succeeded(result)); + + auto resultTy = RankedTensorType::get(shape, elemType, outEnc); + build(builder, state, resultTy, src, axis); +} + +OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult +MemDescTransOp::inferReturnTypes(MLIRContext *context, + std::optional loc, + MemDescTransOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + + // type is the same as the input + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferTransOpEncoding( + argEncoding, shape, order, retEncoding, loc))) { + return failure(); + } + } + + // Permute the last `rank` dims of the source alloc shape. + SmallVector allocShape = + applyPermutation(argTy.getAllocShape().take_back(order.size()), order); + allocShape.insert(allocShape.begin(), argTy.getAllocShape().begin(), + argTy.getAllocShape().end() - order.size()); + + inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(), + argTy.getMutableMemory(), allocShape)); + return success(); +} + +// MemDescReshapeOp +LogicalResult MemDescReshapeOp::verify() { + MemDescType dstType = getResult().getType(); + MemDescType srcType = getSrc().getType(); + if (product(dstType.getShape()) != product(srcType.getShape())) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + if (dstType.getElementType() != srcType.getElementType()) { + return emitError("result element type must match src element type"); + } + auto srcShape = srcType.getShape(); + if (srcType.getAllocShape().take_back(srcShape.size()) != srcShape) { + return emitError("NYI: memdesc_reshape of memdesc_subslice"); + } + + MemDescType expectedTy; + if (failed(inferReturnTypes(getContext(), getLoc(), srcType, + dstType.getShape(), expectedTy))) + return failure(); + return OpTrait::impl::verifyEquivalentType(expectedTy, dstType); +} + +static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + Attribute &dstEnc) { + auto *ctx = srcEnc.getContext(); + // TODO Delete this once SharedLinearEncodingAttr is more widely supported. + if (auto mmaEncoding = dyn_cast(srcEnc)) { + if (getNumCTAs(mmaEncoding) == 1) { + int innerDimDst = + mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back(); + int innerDimSrc = + mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back(); + // We can keep an NVMMAShared encoding only if the innermost dimension is + // preserved. Otherwise fall back to the generic shared-linear encoding + // logic below. + if (innerDimDst == innerDimSrc) { + auto CGALayout = CGAEncodingAttr::get1CTALayout(ctx, dstShape.size()); + auto candidateEncoding = NVMMASharedEncodingAttr::get( + ctx, mmaEncoding.getSwizzlingByteWidth(), + mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(), + mmaEncoding.getFp4Padded(), CGALayout); + auto srcLL = toLinearLayout(srcShape, srcEnc); + auto dstLL = toLinearLayout(dstShape, candidateEncoding); + if (reshapeLayout(ctx, srcLL, dstShape) == dstLL) { + dstEnc = candidateEncoding; + return success(); + } + } + } + } else if (auto padded = dyn_cast(srcEnc)) { + LinearLayout ll = padded.getLinearComponent(); + LinearLayout dst = reshapeLayout(ctx, ll, dstShape); + SmallVector> intervalPads; + auto intervals = padded.getIntervals(); + auto paddings = padded.getPaddings(); + for (auto [interval, padding] : llvm::zip(intervals, paddings)) { + intervalPads.emplace_back(interval, padding); + } + dstEnc = PaddedSharedEncodingAttr::get(ctx, intervalPads, std::move(dst)); + return success(); + } + + // Generic LL case + auto sharedEnc = cast(srcEnc); + auto srcLL = toLinearLayout(srcShape, srcEnc); + auto dstLL = reshapeLayout(ctx, srcLL, dstShape); + dstEnc = SharedLinearEncodingAttr::get(ctx, std::move(dstLL), + sharedEnc.getAlignment()); + return success(); +} + +LogicalResult MemDescReshapeOp::inferReturnTypes( + MLIRContext *context, std::optional loc, MemDescType srcTy, + ArrayRef dstShape, MemDescType &inferredReturnType) { + if (product(dstShape) != product(srcTy.getShape())) + return emitOptionalError( + loc, "dst shape has different number of elements than src"); + + Attribute dstEncoding; + if (Attribute srcEnc = srcTy.getEncoding()) { + if (failed(inferMemDescReshapeOpEncoding(srcTy.getShape(), srcEnc, dstShape, + dstEncoding))) + return failure(); + } + + SmallVector dstAllocShape = + to_vector(srcTy.getAllocShape().take_front(srcTy.getAllocShape().size() - + srcTy.getShape().size())); + dstAllocShape.append(dstShape.begin(), dstShape.end()); + + inferredReturnType = MemDescType::get( + dstShape, srcTy.getElementType(), dstEncoding, srcTy.getMemorySpace(), + srcTy.getMutableMemory(), dstAllocShape); + return success(); +} + +OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) + return getSrc(); + return {}; +} + +// LocalAllocOp +void LocalAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("allocation.offset")) + return; + OpResult alloc = getOperation()->getOpResult(0); + effects.emplace_back(MemoryEffects::Allocate::get(), alloc, + SharedMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), alloc, + SharedMemory::get()); +} + +OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) { + if (getType().getMutableMemory()) + return {}; + auto src = getSrc(); + if (!src) + return {}; + auto localLoadOp = src.getDefiningOp(); + if (!localLoadOp) + return {}; + auto loadSrc = localLoadOp.getSrc(); + if (loadSrc.getType() != getType()) + return {}; + return loadSrc; +} + +int32_t LocalAllocOp::getAlignmentOrDefault() { + auto align = getAlignment(); + if (align) { + return *align; + } + + auto ty = getType(); + auto enc = dyn_cast(ty.getEncoding()); + return enc ? enc.getAlignment() : 16; +} + +LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy, + ShapedType dstTy) { + if (srcTy.getElementType() != dstTy.getElementType()) { + return op->emitOpError("source element type ") + << srcTy << " must match " + << "destination element type " << dstTy.getElementType(); + } + if (srcTy.getShape() != dstTy.getShape()) { + return op->emitOpError("source shape [") + << srcTy.getShape() << "] must match [" + << "destination shape " << dstTy.getShape() << "]"; + } + return success(); +} + +LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy) { + if (dstTy.getShape() != dstTy.getAllocShape()) + return op->emitOpError("result shape and its alloc shape must match"); + + if (!src) { + if (!dstTy.getMutableMemory()) { + return op->emitOpError( + "uninitialized alloc must have a mutable memdesc type"); + } + return success(); + } + + return verifyMemoryOpTypes(op, cast(src.getType()), dstTy); +} + +static LogicalResult verifySharedMemoryRank(Operation *op, + RankedTensorType type, + MemDescType memdesc, + StringRef regName) { + auto enc = dyn_cast(memdesc.getEncoding()); + if (!enc) + return op->emitOpError("expected memdesc to have a shared memory encoding"); + if (type.getRank() != enc.getRank()) { + return op->emitOpError(regName) + << " has rank " << type.getRank() + << " but memdesc encoding has rank " << enc.getRank(); + } + return success(); +} + +LogicalResult LocalAllocOp::verify() { + if (!isa(getType().getMemorySpace())) + return emitOpError("should create a buffer of shared memory"); + if (getSrc() && failed(verifySharedMemoryRank(*this, getSrc().getType(), + getType(), "source"))) + return failure(); + return verifyAllocOp(*this, getSrc(), getType()); +} + +// LocalStoreOp +LogicalResult LocalStoreOp::verify() { + if (!getDst().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + if (failed(verifySharedMemoryRank(*this, getSrc().getType(), + getDst().getType(), "source"))) + return failure(); + return verifyMemoryOpTypes(*this, getSrc().getType(), getDst().getType()); +} + +// LocalLoadOp +LogicalResult LocalLoadOp::verify() { + if (failed(verifySharedMemoryRank(*this, getType(), getSrc().getType(), + "result"))) + return failure(); + return verifyMemoryOpTypes(*this, getSrc().getType(), getType()); +} + +// LocalGatherOp +LogicalResult LocalGatherOp::verify() { + auto srcTy = getSrc().getType(); + auto indicesTy = cast(getIndices().getType()); + auto dstTy = cast(getType()); + unsigned axis = getAxis(); + + // Verify source has shared memory encoding + auto srcEnc = srcTy.getEncoding(); + if (!isa(srcEnc)) { + return emitError("source must have shared memory encoding"); + } + + // Verify indices tensor has integer element type + if (!indicesTy.getElementType().isInteger()) { + return emitError("indices must have integer element type"); + } + + // Verify result has the same shape as indices + if (dstTy.getShape() != indicesTy.getShape()) { + return emitError("result shape must match indices shape"); + } + + // Verify src and indices have the same rank + if (srcTy.getRank() != indicesTy.getRank()) { + return emitError("source and indices must have the same rank"); + } + + // Verify axis is valid + if (axis >= srcTy.getRank()) { + return emitError("axis ") + << axis << " is out of bounds for source rank " << srcTy.getRank(); + } + + // Verify element types match + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match source element type"); + } + + // Verify indices and result have the same layout + if (indicesTy.getEncoding() != dstTy.getEncoding()) { + return emitError("indices and result must have the same layout"); + } + + return success(); +} + +// LocalScatterOp +LogicalResult LocalScatterOp::verify() { + auto dstTy = getDst().getType(); + auto valuesTy = cast(getValues().getType()); + auto indicesTy = cast(getIndices().getType()); + unsigned axis = getAxis(); + + // Verify destination has shared memory encoding + auto dstEnc = dstTy.getEncoding(); + if (!isa(dstEnc)) { + return emitError("destination must have shared memory encoding"); + } + + // Verify indices tensor has integer element type + if (!indicesTy.getElementType().isInteger()) { + return emitError("indices must have integer element type"); + } + + // Verify values and indices have the same shape + if (valuesTy.getShape() != indicesTy.getShape()) { + return emitError("values shape must match indices shape"); + } + + // Verify dst and indices have the same rank + if (dstTy.getRank() != indicesTy.getRank()) { + return emitError("destination and indices must have the same rank"); + } + + // Verify axis is valid + if (axis >= dstTy.getRank()) { + return emitError("axis ") + << axis << " is out of bounds for destination rank " + << dstTy.getRank(); + } + + // Verify values and indices have the same layout + if (valuesTy.getEncoding() != indicesTy.getEncoding()) { + return emitError("values must have the same layout as indices"); + } + + // Verify element types match + if (dstTy.getElementType() != valuesTy.getElementType()) { + return emitError("values element type must match destination element type"); + } + + return success(); +} + +// AsyncCopyGlobalToLocalOp +LogicalResult AsyncCopyGlobalToLocalOp::verify() { + if (!getResult().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + return success(); +} + +LogicalResult MemDescIndexOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + // memdesc_index reduces rank by 1 and preserves the trailing shape. + bool correctRank = srcTy.getRank() == dstTy.getRank() + 1; + if (!correctRank) { + return emitError("result rank must be input rank - 1"); + } + if (srcTy.getAllocShape().size() != srcTy.getRank()) { + return emitError("We don't allow taking memdesc_index of a memdesc_index"); + } + + if (ArrayRef(srcTy.getShape()).take_back(dstTy.getRank()) != + dstTy.getShape()) { + return emitError("result shape must equal to srcShape[1:]"); + } + + bool isSubview = srcTy.getAllocShape() != srcTy.getShape(); + if (isSubview) { + return emitError("We don't support memdesc_index of a subview"); + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (bool(srcEnc) != bool(dstEnc)) { + return emitError("src and result must both have or not have an encoding"); + } + + if (isa(srcEnc) != isa(dstEnc)) { + return emitError("src and dst must have the same type of encoding"); + } + + if (dstTy.getAllocShape() != dstTy.getShape() || + srcTy.getAllocShape() != srcTy.getShape()) { + return emitError("alloc shape must match shape for both result and src"); + } + + if (isa(srcEnc)) { + // We support only 3D -> 2D subviews with only first offset being non-zero. + if (srcTy.getRank() != 3 || dstTy.getRank() != 2) { + return emitError("only 3D -> 2D subviews are supported for " + "TensorMemoryEncodingAttr"); + } + return success(); + } + return success(); +} + +OpFoldResult MemDescSubsliceOp::fold(FoldAdaptor adaptor) { + // Fold subslice(subslice(x, off1), off2) -> subslice(x, off1 + off2) + if (auto srcSubslice = getSrc().getDefiningOp()) { + auto srcOffsets = srcSubslice.getOffsets(); + auto currOffsets = getOffsets(); + + // Compute combined offsets + SmallVector combinedOffsets; + for (size_t i = 0; i < currOffsets.size(); ++i) { + combinedOffsets.push_back(srcOffsets[i] + currOffsets[i]); + } + + // Update this operation to point directly to the original source with + // combined offsets + setOperand(srcSubslice.getSrc()); + setOffsetsAttr(DenseI32ArrayAttr::get(getContext(), combinedOffsets)); + return getResult(); + } + + return {}; +} + +LogicalResult MemDescSubsliceOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + if (getOffsets().size() != srcTy.getRank()) { + return emitError("offsets must have the same rank as input"); + } + if (srcTy.getRank() != dstTy.getRank()) { + return emitError("result rank must equal to input rank"); + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (bool(srcEnc) != bool(dstEnc)) { + return emitError("src and result must both have or not have an encoding"); + } + if (!isa(srcEnc) || !isa(dstEnc)) { + return emitError("src and dst must both be of shared memory encoding"); + } + + SetVector splitDims{}; + for (int i = 0; i < srcTy.getRank(); i++) { + if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) { + splitDims.insert(i); + } + } + SmallVector offsets(getOffsets().begin(), getOffsets().end()); + // Identity subview + if (splitDims.empty()) { + return success(); + } + + for (auto [dim, offset] : llvm::enumerate(offsets)) { + if (!splitDims.contains(dim)) { + if (offset != 0) { + return emitError("A non zero offset found in a dimension that is " + "not being split"); + } + } else { + if (offset & (dstTy.getDimSize(dim) - 1)) { + return emitError("The split offset may not touch the tile"); + } + } + } + + auto ctx = getContext(); + LinearLayout ll; + if (auto paddedEncoding = dyn_cast(srcEnc)) { + if (paddedEncoding.getRank() < srcTy.getRank()) { + return emitError("SubSlice of low rank PaddedSharedEncoding from higher " + "rank tensors is not supported yet"); + } + ll = paddedEncoding.getLinearComponent(); + } else { + ll = triton::gpu::toLinearLayout(srcTy); + } + // NYI: We don't support non-trivial block dimension for now. + auto kBlock = mlir::StringAttr::get(getContext(), "block"); + if (ll.getInDimSize(kBlock) != 1) { + return emitError("non-trivial block dimension not supported"); + } + + auto llInv = ll.invert(); + for (auto dim : splitDims) { + auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim)); + llvm::SmallVector> namedOffsets; + for (auto d : standardOutDimNames(ctx, srcTy.getRank())) { + namedOffsets.push_back({d, 0}); + } + for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim); + dimSize *= 2) { + namedOffsets[dim] = {kDim, dimSize}; + if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) { + return emitError( + "We don't support splitting along the swizzling pattern"); + } + } + } + return success(); +} + +// -- WarpSpecializeOp -- + +RegionRange WarpSpecializeOp::getPartitionRegions() { + return getPartitionOp().getPartitionRegions(); +} + +WarpSpecializePartitionsOp WarpSpecializeOp::getPartitionOp() { + return cast( + getPartitionOpHolder().front().front()); +} + +void WarpSpecializeOp::getSuccessorRegions( + RegionBranchPoint src, SmallVectorImpl &successors) { + // The parent branches into the default region and the partition regions. + if (src.isParent()) { + successors.emplace_back(&getDefaultRegion()); + successors.emplace_back(&getPartitionOpHolder()); + return; + } + // And the default region branches transparently back to the parent. + if (src.getTerminatorPredecessorOrNull()->getParentRegion() == + &getDefaultRegion()) + successors.push_back(RegionSuccessor(getOperation(), getResults())); +} + +void WarpSpecializePartitionsOp::getSuccessorRegions( + RegionBranchPoint src, SmallVectorImpl &successors) { + // The parent branches to each of the partition regions, but nothing flows out + // of the partition regions. + if (src.isParent()) + for (Region ®ion : getPartitionRegions()) + successors.emplace_back(®ion, region.getArguments()); +} + +OperandRange +WarpSpecializePartitionsOp::getEntrySuccessorOperands(RegionSuccessor) { + // Pass through the explicit captures from the enclosing WarpSpecializeOp. + return getExplicitCaptures(); +} + +LogicalResult WarpSpecializeOp::verify() { + // The default region is not isolated from above but the partition regions + // have to be. MLIR does not support this, so we hide an op inside another + // region that contains the isolated regions. Check that it is there. + if (!isa( + getPartitionOpHolder().front().front())) { + return emitOpError( + "expected to find only a `ttg.warp_specialize.partitions` op inside " + "its second region"); + } + + // Verify the partitions. + if (getPartitionRegions().size() != getPartitionNumWarps().size()) { + return emitOpError("has ") << getPartitionRegions().size() + << " partitions but `partitionNumWarps` has " + << getPartitionNumWarps().size() << " elements"; + } + for (auto [i, numWarps] : llvm::enumerate(getPartitionNumWarps())) { + if (llvm::isPowerOf2_32(numWarps)) + continue; + return emitOpError("partition #") + << i << " number of warps (" << numWarps << ") must be a power of 2"; + } + if (std::optional> startIds = getWarpGroupStartIds()) { + if (startIds->size() != getPartitionNumWarps().size()) { + return emitOpError("has ") + << startIds->size() << " warp group start IDs but expected " + << getPartitionNumWarps().size(); + } + } + + // This op cannot be nested inside itself. + if ((*this)->getParentOfType()) { + return emitOpError( + "cannot be nested inside another `ttg.warp_specialize` op"); + } + + std::optional numWarps = maybeLookupNumWarps(*this); + if (numWarps && *numWarps % 4 != 0) { + return mlir::emitError(getLoc()) << "warp-specialized kernels requires " + "num_warps to be a multiple of 4"; + } + + return success(); +} + +LogicalResult WarpSpecializeOp::canonicalize(WarpSpecializeOp op, + PatternRewriter &b) { + // Propagate unused results and captures by removing them from the op. + llvm::BitVector unusedResults(op.getNumResults()); + for (auto [i, result] : llvm::enumerate(op.getResults())) { + if (result.use_empty()) + unusedResults.set(i); + } + + if (unusedResults.none()) + return failure(); + + for (Block &block : op.getDefaultRegion()) { + if (auto yield = dyn_cast(block.getTerminator())) { + b.modifyOpInPlace(yield, [&] { yield->eraseOperands(unusedResults); }); + } + } + + SmallVector newTypes; + for (auto [i, type] : llvm::enumerate(op.getResultTypes())) { + if (!unusedResults.test(i)) + newTypes.push_back(type); + } + OperationState state(op.getLoc(), op->getName(), {}, newTypes, + op->getAttrs()); + state.addRegion()->takeBody(op.getDefaultRegion()); + state.addRegion()->takeBody(op.getPartitionOpHolder()); + auto newOp = cast(b.create(state)); + unsigned newResultIdx = 0; + for (auto [i, result] : llvm::enumerate(op.getResults())) { + if (!unusedResults.test(i)) + result.replaceAllUsesWith(newOp.getResult(newResultIdx++)); + } + assert(newResultIdx == newOp.getNumResults()); + b.eraseOp(op); + + return success(); +} + +void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state, + TypeRange resultTypes, + ArrayRef partitionNumWarps, + unsigned partitionNumRegions) { + build(builder, state, resultTypes, partitionNumWarps, {}, {}, {}); + OpBuilder::InsertionGuard guard(builder); + Block *container = builder.createBlock(state.regions.back().get()); + WarpSpecializePartitionsOp::create(builder, state.location, + /*explicitCaptures=*/ValueRange(), + partitionNumRegions); +} + +void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state, + TypeRange resultTypes, + ArrayRef partitionNumWarps) { + build(builder, state, resultTypes, partitionNumWarps, {}, {}, {}); +} + +ParseResult WarpSpecializeOp::parse(OpAsmParser &p, OperationState &result) { + SmallVector operands; + SMLoc operandLoc = p.getCurrentLocation(); + if (p.parseOperandList(operands, AsmParser::Delimiter::Paren) || + p.parseOptionalAttrDictWithKeyword(result.attributes) || + p.parseKeyword("default") || p.parseRegion(*result.addRegion())) + return failure(); + + OperationState partitionOpState( + p.getEncodedSourceLoc(p.getCurrentLocation()), + WarpSpecializePartitionsOp::getOperationName()); + + SmallVector partitionNumWarps; + SmallVector partitionArgs; + while (succeeded(p.parseOptionalKeyword( + ("partition" + Twine(partitionNumWarps.size()).str())))) { + partitionArgs.clear(); + SMLoc regionLoc = p.getCurrentLocation(); + if (p.parseArgumentList(partitionArgs, AsmParser::Delimiter::Paren, + /*allowType=*/true) || + p.parseKeyword("num_warps") || p.parseLParen() || + p.parseInteger(partitionNumWarps.emplace_back()) || p.parseRParen() || + p.parseRegion(*partitionOpState.addRegion(), partitionArgs)) + return failure(); + } + + FunctionType types; + if (p.parseColon() || p.parseType(types) || + p.resolveOperands(operands, types.getInputs(), operandLoc, + partitionOpState.operands)) + return failure(); + + result.addTypes(types.getResults()); + result.addAttribute(getPartitionNumWarpsAttrName(result.name), + p.getBuilder().getDenseI32ArrayAttr(partitionNumWarps)); + + Block &holder = result.addRegion()->emplaceBlock(); + OpBuilder b(p.getContext()); + b.setInsertionPointToStart(&holder); + b.create(partitionOpState); + return success(); +} + +void WarpSpecializeOp::print(OpAsmPrinter &p) { + p << '('; + p.printOperands(getPartitionOp().getOperands()); + p << ')'; + p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(), + {getPartitionNumWarpsAttrName()}); + + p.printNewline(); + p << "default "; + p.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/false); + + for (auto [i, region, numWarps] : + llvm::enumerate(getPartitionRegions(), getPartitionNumWarps())) { + p.printNewline(); + p << "partition" << i << '('; + llvm::interleaveComma(region->getArguments(), p, [&](BlockArgument arg) { + p.printRegionArgument(arg); + }); + p << ") num_warps(" << numWarps << ") "; + p.printRegion(*region, /*printEntryBlockArgs=*/false); + } + p << " : "; + SmallVector captureTypes; + for (auto val : getPartitionOp().getExplicitCaptures()) + captureTypes.push_back(val.getType()); + p.printFunctionalType(captureTypes, getResultTypes()); +} + +LogicalResult WarpSpecializePartitionsOp::verify() { + for (auto [i, region] : llvm::enumerate(getPartitionRegions())) { + if (region.getNumArguments() != getNumOperands()) { + return emitOpError("partition region #") + << i << " has " << region.getNumArguments() + << " arguments but expected " << getNumOperands(); + } + for (auto [argIdx, argType, capType] : llvm::enumerate( + region.getArgumentTypes(), getExplicitCaptures().getTypes())) { + if (argType == capType) + continue; + return emitOpError("partition region #") + << i << " argument #" << argIdx << " has type " << argType + << " but corresponding capture has type " << capType; + } + } + return success(); +} + +LogicalResult +WarpSpecializePartitionsOp::canonicalize(WarpSpecializePartitionsOp op, + PatternRewriter &b) { + llvm::BitVector unusedArgs(op.getNumOperands()); + + // Remove duplicate captures. + DenseMap uniqueCaptures; + for (auto [i, capture] : llvm::enumerate(op.getExplicitCaptures())) { + auto noUseInRegion = [i = i](Region ®ion) { + return region.getArgument(i).use_empty(); + }; + if (llvm::all_of(op.getPartitionRegions(), noUseInRegion)) { + unusedArgs.set(i); + continue; + } + + auto [it, inserted] = uniqueCaptures.try_emplace(capture, i); + if (!inserted) { + unsigned duplicateIdx = it->second; + b.modifyOpInPlace(op, [&, i = i] { + for (Region ®ion : op.getPartitionRegions()) { + b.replaceAllUsesWith(region.getArgument(i), + region.getArgument(duplicateIdx)); + } + }); + unusedArgs.set(i); + } + } + + if (unusedArgs.none()) + return failure(); + + b.modifyOpInPlace(op, [&] { + for (Region ®ion : op.getPartitionRegions()) + region.front().eraseArguments(unusedArgs); + op->eraseOperands(unusedArgs); + }); + return success(); +} + +LogicalResult WarpYieldOp::verify() { + if (getNumOperands() != getParentOp().getNumResults()) { + return emitOpError("has ") + << getNumOperands() << " operands but parent op expected " + << getParentOp().getNumResults(); + } + for (auto [i, result, type] : + llvm::enumerate(getParentOp().getResultTypes(), getOperandTypes())) { + if (result != type) { + return emitOpError("operand #") << i << " has type " << type + << " but parent op expected " << result; + } + } + return success(); +} + +// Get the size of a scalar type when stored in shared memory. +// TODO: Generalize this as needed. +static size_t getSharedMemorySize(Type type) { + if (isa(type)) + return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8); + if (isa(type)) + return 8; + if (auto desc = dyn_cast(type)) { + if (!isa(desc.getMemorySpace())) + return 8; + return 8 + desc.getRank() * 4; + } + llvm::report_fatal_error( + Twine("shared memory size for scalar type is unspecified: ") + + mlir::debugString(type)); +} + +std::pair WarpSpecializeOp::getCaptureSizeAlign() { + uint64_t captureSize = 0; + // Tightly pack the captures in memory. + for (Type type : getPartitionOp().getOperandTypes()) { + captureSize += getSharedMemorySize(type); + } + // Align the captures to 8 bytes. + return {captureSize, 8}; +} + +unsigned WarpSpecializeOp::getTotalPartitionWarps() { + ArrayRef numWarps = getPartitionNumWarps(); + return std::accumulate(numWarps.begin(), numWarps.end(), 0); +} + +//===----------------------------------------------------------------------===// +// BarrierOp +//===----------------------------------------------------------------------===// + +void BarrierOp::print(OpAsmPrinter &p) { + // print "all" instead of "local|global_read|global_write|tensor|all" + if (getAddrSpace() == AddrSpace::All) { + p << " all"; + } else { + p << ' ' << stringifyAddrSpace(getAddrSpace()); + } +} + +ParseResult BarrierOp::parse(OpAsmParser &parser, OperationState &result) { + auto parseAddrSpace = [&]() -> FailureOr { + std::string keyword; + if (parser.parseKeywordOrString(&keyword)) + return failure(); + + auto addrSpace = symbolizeAddrSpace(keyword); + if (!addrSpace) + return parser.emitError(parser.getCurrentLocation()) + << "unknown addrSpace '" << keyword << "'"; + + return *addrSpace; + }; + + auto addrSpace = parseAddrSpace(); + if (failed(addrSpace)) + return failure(); + + AddrSpace addrSpaceRet = *addrSpace; + + while (succeeded(parser.parseOptionalVerticalBar())) { + addrSpace = parseAddrSpace(); + if (failed(addrSpace)) + return failure(); + + addrSpaceRet = bitEnumSet(addrSpaceRet, *addrSpace); + } + + result.addAttribute("addrSpace", + AddrSpaceAttr::get(parser.getContext(), addrSpaceRet)); + + return success(); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Types.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Types.cpp new file mode 100644 index 0000000000..5754fe7a6a --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Types.cpp @@ -0,0 +1,226 @@ +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + +static constexpr llvm::StringRef kMutableMemory = "mutable"; + +Type MemDescType::parse(AsmParser &parser) { + Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + if (failed(parser.parseLess())) + return Type(); + + SmallVector dimensions; // required + if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false))) + return Type(); + + Type elementType; // required + if (failed(parser.parseType(elementType))) + return Type(); + + Attribute encoding; // required + if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding))) + return Type(); + + Attribute memorySpace; // required + if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace))) + return Type(); + + bool mutableMemory = false; // optional + SmallVector allocShape; // optional + if (succeeded(parser.parseOptionalComma())) { + if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) { + mutableMemory = true; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false, + /*withTrailingX=*/false))) { + return Type(); + } + } + } else if (failed(parser.parseDimensionList(allocShape, + /*allowDynamic=*/false, + /*withTrailingX=*/false))) { + return Type(); + } + } + + if (parser.parseGreater()) + return Type(); + + if (!allocShape.empty()) + return MemDescType::getChecked(loc, parser.getContext(), dimensions, + elementType, encoding, memorySpace, + mutableMemory, allocShape); + + return MemDescType::getChecked(loc, parser.getContext(), dimensions, + elementType, encoding, memorySpace, + mutableMemory, dimensions); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + auto shape = getShape(); + for (auto dim : shape) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + auto allocShape = getAllocShape(); + if (allocShape != shape) { + printer << ", " << allocShape[0]; + for (auto dim : allocShape.drop_front(1)) { + printer << "x" << dim; + } + } + printer << ">"; +} + +LogicalResult MemDescType::verify(function_ref emitError, + ArrayRef shape, Type elementType, + Attribute encoding, Attribute memorySpace, + bool mutableMemory, + ArrayRef allocShape) { + if (shape.empty()) { + return emitError() << "rank 0 memdesc is not allowed"; + } + // Every dimension but the first (to allow for pipelining) must be a power of + // 2 + if (!llvm::all_of(shape.drop_front(1), [](int64_t dim) { + return llvm::isPowerOf2_64(dim) && dim > 0; + })) + return emitError() + << "shape must have power-of-2 and non-zero dimensions; got " + << shape; + if (shape.front() == 0) + return emitError() << "shape has 0 dimension"; + if (allocShape.size() < shape.size()) + return emitError() + << "alloc shape must have at least as many dimensions as shape"; + if (llvm::any_of( + llvm::zip(shape, allocShape.take_back(shape.size())), + [](auto pair) { return std::get<0>(pair) > std::get<1>(pair); })) + return emitError() << "shape must be less than or equal to allocShape. " + << "shape = " << shape + << ", allocShape = " << allocShape; + auto ctx = encoding.getContext(); + if (auto enc = dyn_cast(encoding)) { + if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) { + return emitError() << "memorySpace must be TensorMemorySpace"; + } + if (shape.size() != 2 && shape.size() != 3) { + return emitError() << "rank must be 2 or 3"; + } + unsigned bitwidth = elementType.getIntOrFloatBitWidth(); + if (bitwidth * enc.getColStride() > 32) { + return emitError() + << "bitwidth * colStride must be less than or equal to 32. Got " + << bitwidth << " and " << enc.getColStride(); + } + shape = shape.take_back(2); + allocShape = allocShape.take_back(2); + if (allocShape[0] < enc.getBlockM() * enc.getCTASplitM() || + allocShape[1] < enc.getBlockN() * enc.getCTASplitN()) { + return emitError() << "the allocation shape must be at least " + << enc.getBlockM() * enc.getCTASplitM() << "x" + << enc.getBlockN() * enc.getCTASplitN() << ". Got " + << allocShape; + } + auto ll = toLinearLayout(allocShape, enc); + auto dims = standardOutDimNames(ctx, 2); + if (ll.getOutDimSize(dims[0]) != allocShape[0] || + ll.getOutDimSize(dims[1]) != allocShape[1]) { + return emitError() << "allocation shape must be equal to " + << ll.getOutDimSize(dims[0]) << "x" + << ll.getOutDimSize(dims[1]); + } + } else if (auto enc = dyn_cast(encoding)) { + if (memorySpace != SharedMemorySpaceAttr::get(ctx)) { + return emitError() + << "memorySpace must be SharedMemorySpace for shared encoding. " + << "Got " << memorySpace; + } + auto rank = cast(enc).getRank(); + if (!(rank == shape.size() || rank == shape.size() - 1)) { + return emitError() << "rank must be equal to or one less than " + << "the shape size. Got " << rank << " and " + << shape.size(); + } + } else if (auto enc = dyn_cast( + encoding)) { + if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) { + return emitError() << "memorySpace must be TensorMemorySpace"; + } + if (allocShape.size() != 2) { + return emitError() << "Scales don't currently support multibuffering"; + } + auto bitwidth = elementType.getIntOrFloatBitWidth(); + if (bitwidth != 8) { + return emitError() << "bitwidth must be 8"; + } + } else { + return emitError() << encoding << " is not a valid encoding"; + } + + // PaddedSharedEncodingAttr is also a SharedEncodingTrait but we have some + // additional rules to verify. + if (auto enc = dyn_cast(encoding)) { + auto rank = enc.getRank(); + // Ensure linear component's outDims match the alloc size ignoring + // pipelining dimension + auto outDims = standardOutDimNames(ctx, rank); + const auto &ll = enc.getLinearComponent(); + auto expectedShape = allocShape; + if (rank == allocShape.size() - 1) + expectedShape = expectedShape.drop_front(1); + + for (auto d = 0; d < rank; d++) { + if (ll.getOutDimSize(outDims[d]) != expectedShape[d]) { + return emitError() << "Mismatch in expected shape for dimension " << d + << ". Expected: " << expectedShape[d] + << ", got: " << ll.getOutDimSize(outDims[d]); + } + } + } else if (auto enc = dyn_cast(encoding)) { + SmallVector shapePerCTA(getShapePerCTA(enc, allocShape)); + auto blockShape = ArrayRef(shapePerCTA).take_back(enc.getRank()); + if (failed(getTMABlockShape(blockShape, enc.getElementBitWidth(), + enc.getSwizzlingByteWidth(), enc.getFp4Padded(), + enc.getTransposed(), /*packedSize=*/false, + emitError))) + return failure(); + } else if (auto enc = dyn_cast(encoding)) { + auto blockShape = ArrayRef(allocShape).take_back(enc.getRank()); + const LinearLayout &ll = enc.getLinearLayout(); + for (auto [dim, size, llSize] : + llvm::enumerate(blockShape, ll.getOutDimSizes())) { + if (size == llSize) + continue; + return emitError() << "Mismatch in expected shape for dimension " << dim + << ". Expected: " << size << ", got: " << llSize; + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::gpu::TritonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp new file mode 100644 index 0000000000..ebc995f145 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -0,0 +1,1027 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Get the highest version supported for the hardware and the dot. +static int getMMAVersionSafe(int computeCapability, DotOp op) { + // List supported mma version in order of preference. + SmallVector versionsSupported; + if (computeCapability < 75) { + versionsSupported = {1}; + } else if (computeCapability < 90) { + versionsSupported = {2}; + } else if (computeCapability < 100) { + versionsSupported = {3, 2}; + } else if (computeCapability < 120) { + // Exclude consumer Blackwell (sm120) + versionsSupported = {5, 2}; + } else if (computeCapability < 130) { + versionsSupported = {2}; + } else { + assert(false && "computeCapability not supported"); + } + for (int baseVersion : versionsSupported) { + if (supportMMA(op, baseVersion)) + return baseVersion; + if (baseVersion == 3) { + auto remark = op.emitRemark() + << "MMA version 3 acceleration not applied due to " + "unsupported shapes or data types."; + remark.attachNote() << "Target compute capability (" << computeCapability + << ") supports MMA v3."; + } + + if (baseVersion == 5) { + auto remark = op.emitRemark() + << "MMA version 5 acceleration not applied due to " + "unsupported shapes or data types."; + remark.attachNote() << "Target compute capability (" << computeCapability + << ") supports MMA v5."; + } + } + return 0; +} + +SmallVector warpsPerTileV2(DotOpInterface dotOp, + const ArrayRef shape, + int numWarps) { + auto rank = shape.size(); + // Early exit for batched matmul + if (rank == 3) + return {(unsigned)numWarps, 1, 1}; + + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion() && + !isa(op); + }; + auto slices = mlir::getSlice(dotOp, {filter}, {filter}); + bool hasChainedDot = false; + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) { + auto resTy = cast(op->getResult(0).getType()); + if (resTy.getRank() != rank) { + continue; + } + if (auto mmaEncoding = + dyn_cast(resTy.getEncoding())) { + return to_vector(mmaEncoding.getWarpsPerCTA()); + } + hasChainedDot = true; + } + } + if (hasChainedDot) { + if (shape[0] >= shape[1]) { + return {(unsigned)numWarps, 1}; + } else { + return {1, (unsigned)numWarps}; + } + } + + assert(rank == 2); + SmallVector shapePerWarp = {16, 8}; + SmallVector warps = {1, 1}; + // Compute repM and repN + SmallVector reps = {ceil(shape[0], shapePerWarp[0]), + ceil(shape[1], shapePerWarp[1])}; + // The formula for the number of registers given the reps is + // repM * 4 * repK + repN * 2 * repK + regsC + // where regsC = repM * repN * 4, which does not depend on the warp shape + // + // As such, to minimize the register pressure, we need to balance + // repM and repN. We then untie towards M, as the lhs tile has 4 elements, + // and the rhs tile has just 2. + while (product(warps) < numWarps) { + if (reps[0] >= reps[1]) { + warps[0] *= 2; + // Too many warps for this mma (repM == repN == 1). + // We allocate the remaining warps to the left (arbitrary choice) + if (reps[0] != 1) { + reps[0] /= 2; + } + } else { + warps[1] *= 2; + reps[1] /= 2; + } + } + return {(unsigned)warps[0], (unsigned)warps[1]}; +} +SmallVector +warpsPerTileV3(DotOpInterface dotOp, const ArrayRef shape, + int numWarps, const SmallVector &instrShape) { + SetVector slices; + mlir::getForwardSlice(dotOp.getD(), &slices); + // Contains a chained dot. We prefer to assign warps to one axis + // to facilitate use cases like flash attention, allowing reductions within + // the same warp. + if (llvm::find_if(slices, [](Operation *op) { + return isa(op); + }) != slices.end()) + return {(unsigned)numWarps, 1}; + + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). + SmallVector ret = {4, 1}; + SmallVector shapePerWarp = {16, instrShape[1]}; + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] > shapePerWarp[0] * ret[0]) { + ret[0] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +// Returns a shared memory allocation that can be used by a dotMMA op for the +// given value. +static Value +getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx, + bool allowTranspose, bool isMMAv5Fp4Padded = false, + bool forceTranspose = false, + Operation *op = nullptr /*only for diagnostic*/) { + OpBuilder::InsertionGuard g(rewriter); + Value arg = v; + while (auto cvtOp = arg.getDefiningOp()) + arg = cvtOp.getSrc(); + auto argType = cast(arg.getType()); + assert(argType.getEncoding() && "unexpected tensor type"); + auto order = getOrderForMemory(argType); + + // If the MMA op doesn't support transpose pick the layout expected by the MMA + // op. + llvm::SmallVector newOrder = order; + if (!allowTranspose) { + if (opIdx == 1) { + newOrder = {0, 1}; + } else { + newOrder = {1, 0}; + } + if (forceTranspose) + std::swap(newOrder[0], newOrder[1]); + } + + if (newOrder != order && op) { + op->emitWarning("Warning: Forcing a different order [") + << newOrder[0] << ", " << newOrder[1] + << "] on SMEM than the register order for the operand " << opIdx + << ". Registers will be transposed before SMEM store and the pipelined " + "load for this operand will be disabled, so poor performance is " + "expected. Recommendation: consider transposing the operand in " + "global " + "memory to remove the need to transpose the tensor in registers."; + } + + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); + auto CGALayout = getCGALayout(argType.getEncoding()); + auto newLayout = NVMMASharedEncodingAttr::get( + argType.getContext(), argType.getShape(), newOrder, CGALayout, + argType.getElementType(), isMMAv5Fp4Padded); + auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); + rewriter.setInsertionPointAfterValue(arg); + return LocalAllocOp::create(rewriter, arg.getLoc(), newType, arg); +} + +static LocalAllocOp +getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) { + OpBuilder::InsertionGuard g(rewriter); + auto argType = cast(arg.getType()); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = getOrderForMemory(argType); + + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); + auto CGALayout = getCGALayout(argType.getEncoding()); + // No swizzling for scale for now + auto newLayout = NVMMASharedEncodingAttr::get( + argType.getContext(), /*swizzlingByteWidth=*/0, + /*transposed=*/false, + /*elementBitWidth=*/argType.getElementType().getIntOrFloatBitWidth(), + /*fp4Padded=*/false, CGALayout); + auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); + rewriter.setInsertionPointAfterValue(arg); + return LocalAllocOp::create(rewriter, loc, newType, arg); +} + +SmallVector +getWarpsPerTile(DotOpInterface dotOp, const ArrayRef shape, + int version, int numWarps, + const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + case 3: + return warpsPerTileV3(dotOp, shape, numWarps, instrShape); + default: + assert(false && "not supported version"); + return {0, 0}; + } +} + +static bool bwdFilter(Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isView(op) || + isa( + op); +} + +// Finds the bitwidth with which the value x is loaded +static int computeOrigBitWidth(Value x) { + SetVector slice; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = bwdFilter; + (void)getBackwardSlice(x, &slice, opt); + + // TODO: This heuristic may be a bit too coarse and may need improving + // If the chain contains a fp4 to fp16/bf16 conversion, then the original + // bitwidth is 4. + if (llvm::any_of(slice, [](Operation *op) { return isa(op); })) + return 4; + + int origBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); + for (auto op : slice) { + if (isa(op)) { + if (auto tensorTy = + dyn_cast(op->getResultTypes().front())) { + origBitWidth = + std::min(origBitWidth, tensorTy.getElementTypeBitWidth()); + } + } + } + + // If JoinOp occurred at least once, in backward layout propagation, + // the kWidth will be split in half as we pass through the JoinOp. + // Hence we divide origBitWidth by 2 here to compensate for that and + // improve our load width. + // This won't be optimal if there is a tree of multiple JoinOps, which + // would require counting the max number of JoinOp's along any path. + // + // In the future we might want to do something like trying a large kWidth, + // run layout backpropagation and see what's the contiguity that you + // get at the loads that feed into it. + if (llvm::any_of(slice, [](Operation *op) { return isa(op); })) + origBitWidth /= 2; + + return origBitWidth; +} + +namespace { + +// Common MMA encoding creation +struct MMAEncodingResult { + NvidiaMmaEncodingAttr mmaEnc; + RankedTensorType newRetType; + Value newAcc; + int versionMajor; + int versionMinor; +}; + +// Unified implementation for DotOpInterface +static MMAEncodingResult createMMAEncodingForDot(DotOpInterface dotOp, + PatternRewriter &rewriter, + int computeCapability, + int versionMajor) { + auto oldRetType = cast(dotOp.getD().getType()); + auto oldAType = cast(dotOp.getA().getType()); + + int numWarps = lookupNumWarps(dotOp); + + int versionMinor = computeCapability == 75 ? 1 : 0; + // Only MMAv2 and MMAv3 rely on computing instrShape/warpsPerTile here. + if (!(versionMajor == 2 || versionMajor == 3)) { + return {nullptr, RankedTensorType(), Value(), versionMajor, versionMinor}; + } + + auto CGALayout = getCGALayout(oldRetType.getEncoding()); + auto retShapePerCTA = getShapePerCTA(oldRetType); + auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, + oldAType.getElementType(), numWarps); + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + + auto mmaEnc = NvidiaMmaEncodingAttr::get(oldRetType.getContext(), + versionMajor, versionMinor, + warpsPerTile, CGALayout, instrShape); + auto newRetType = oldRetType.cloneWithEncoding(mmaEnc); + + auto oldAcc = dotOp->getOperand(2); + auto newAcc = + ConvertLayoutOp::create(rewriter, oldAcc.getLoc(), newRetType, oldAcc); + + return {mmaEnc, newRetType, newAcc, versionMajor, versionMinor}; +} + +// Common operand conversion +static Value convertDotOperandForMMA(Value v, int opIdx, int bitwidth, + RankedTensorType newRetType, + PatternRewriter &rewriter) { + auto minType = bitwidth > 0 ? rewriter.getIntegerType(bitwidth) : v.getType(); + auto vType = cast(v.getType()); + auto newVEncoding = DotOperandEncodingAttr::get( + v.getContext(), opIdx, newRetType.getEncoding(), minType); + auto newVType = vType.cloneWithEncoding(newVEncoding); + return ConvertLayoutOp::create(rewriter, v.getLoc(), newVType, v); +} + +} // namespace + +class BlockedToMMA : public mlir::OpRewritePattern { + int computeCapability; + mutable llvm::DenseMap dotOpInstNs; + +public: + BlockedToMMA(mlir::MLIRContext *context, int computeCapability, int benefit) + : OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotOp dotOp, + mlir::PatternRewriter &rewriter) const override { + if (computeCapability < 70) + return failure(); + if (computeCapability < 80) { + dotOp.emitRemark() + << "Dot op using MMA for compute capability " << computeCapability + << " has been deprecated. It falls back to the FMA path."; + return failure(); + } + // TODO: Check data-types and SM compatibility + auto retType = dotOp.getType(); + if (!retType.getEncoding() || + mlir::isa(retType.getEncoding())) + return failure(); + + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = cast(a.getType()); + auto oldBType = cast(b.getType()); + auto oldRetType = cast(dotOp.getType()); + + // Enable F64 MMA only on SM80/SM90 with high performance F64 tensorcore. + // Otherwise, fallback to F64 FMA for better performance. + if ((oldAType.getElementType().isF64() || + oldBType.getElementType().isF64() || + oldRetType.getElementType().isF64()) && + !(computeCapability == 80 || computeCapability == 90)) { + return failure(); + } + + auto mmaVersion = getMMAVersionSafe(computeCapability, dotOp); + auto mmaResult = + createMMAEncodingForDot(dotOp, rewriter, computeCapability, mmaVersion); + if (!(mmaResult.versionMajor >= 1 && mmaResult.versionMajor <= 3)) + return failure(); + + Operation *newDot = nullptr; + bool aFromLoad = comesFromLoadOrBlockArg(a); + bool bFromLoad = comesFromLoadOrBlockArg(b); + + if (mmaResult.versionMajor == 3) { + auto eltType = cast(a.getType()).getElementType(); + bool allowTranspose = eltType.isF16() || eltType.isBF16(); + if (!aFromLoad) { + int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth(); + a = convertDotOperandForMMA(a, 0, bitwidth, mmaResult.newRetType, + rewriter); + } else { + a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose, + /*isMMAv5Fp4Padded=*/false, + /*forceTranspose=*/false, dotOp); + } + b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose, + /*isMMAv5Fp4Padded=*/false, + /*forceTranspose=*/false, dotOp); + + newDot = triton::nvidia_gpu::WarpGroupDotOp::create( + rewriter, dotOp.getLoc(), mmaResult.newRetType, a, b, + mmaResult.newAcc, nullptr, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc(), false); + } else { + int minBitwidth = + std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + a = convertDotOperandForMMA(a, 0, minBitwidth, mmaResult.newRetType, + rewriter); + b = convertDotOperandForMMA(b, 1, minBitwidth, mmaResult.newRetType, + rewriter); + newDot = DotOp::create(rewriter, dotOp.getLoc(), mmaResult.newRetType, a, + b, mmaResult.newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); + } + + rewriter.replaceOpWithNewOp(dotOp, dotOp.getType(), + newDot->getResult(0)); + return success(); + } +}; + +static bool canUseTwoCTAs(triton::DotOp dotOp) { + RankedTensorType retType = dotOp.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + // TODO: we could support 2 CTAs matmul with numCTAs > 2. + SmallVector splitNum = getCTASplitNum(retType.getEncoding()); + if (splitNum.size() != 2 || splitNum[0] != 2 || splitNum[1] != 1) + return false; + int m = retShapePerCTA[0]; + int n = retShapePerCTA[1]; + // minimum size supported by 2CTAs mmav5. + if (m < 64 || n < 32) + return false; + Value b = dotOp.getB(); + // Skip convert layouts. + while (auto cvtOp = b.getDefiningOp()) + b = cvtOp.getSrc(); + return llvm::isa_and_nonnull(b.getDefiningOp()); +} + +static DistributedEncodingTrait +replaceCGALayout(DistributedEncodingTrait layout, + const triton::gpu::CGAEncodingAttr &newCGALayout) { + if (auto blockedLayout = mlir::dyn_cast(layout)) { + return BlockedEncodingAttr::get( + layout.getContext(), blockedLayout.getSizePerThread(), + blockedLayout.getThreadsPerWarp(), blockedLayout.getWarpsPerCTA(), + blockedLayout.getOrder(), newCGALayout); + } else if (auto sliceLayout = mlir::dyn_cast(layout)) { + return SliceEncodingAttr::get( + layout.getContext(), sliceLayout.getDim(), + replaceCGALayout(sliceLayout.getParent(), newCGALayout)); + } else { + llvm::report_fatal_error("not implemented"); + return layout; + } +} + +static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + MLIRContext *ctx = b.getContext(); + while (auto cvtOp = b.getDefiningOp()) + b = cvtOp.getSrc(); + auto loadOp = b.getDefiningOp(); + assert((isa(loadOp)) && + "expected LoadOp"); + RankedTensorType bType = cast(b.getType()); + auto currentLayout = cast(bType.getEncoding()); + auto kBlock = StringAttr::get(ctx, "block"); + auto dims = standardOutDimNames(ctx, 2); + auto newCGALayout = + CGAEncodingAttr::get(ctx, LinearLayout({{kBlock, {{0, 1}}}}, dims)); + Attribute newLayout = replaceCGALayout(currentLayout, newCGALayout); + rewriter.setInsertionPoint(loadOp); + for (OpOperand &operand : loadOp->getOpOperands()) { + auto tensorType = dyn_cast(operand.get().getType()); + if (!tensorType) + continue; + Value newOperand = ConvertLayoutOp::create( + rewriter, operand.get().getLoc(), + tensorType.cloneWithEncoding(newLayout), operand.get()); + loadOp->setOperand(operand.getOperandNumber(), newOperand); + } + loadOp->getResult(0).setType(bType.cloneWithEncoding(newLayout)); + Value newB = loadOp->getResult(0); + rewriter.setInsertionPointAfter(loadOp); + auto cvt = ConvertLayoutOp::create(rewriter, b.getLoc(), bType, newB); + rewriter.replaceAllUsesExcept(newB, cvt.getResult(), cvt); + return newB; +} + +class BlockedToMMAv5 : public mlir::OpRewritePattern { + int computeCapability; + +public: + BlockedToMMAv5(mlir::MLIRContext *context, int computeCapability, int benefit) + : OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotOp dotOp, + mlir::PatternRewriter &rewriter) const override { + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + + // get MMA encoding for the given number of warps + auto retShapePerCTA = getShapePerCTA(oldRetType); + int numWarps = lookupNumWarps(dotOp); + auto CGALayout = getCGALayout(oldRetType.getEncoding()); + + int versionMajor = getMMAVersionSafe(computeCapability, dotOp); + if (versionMajor != 5) + return failure(); + Location loc = dotOp.getLoc(); + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + if (std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)) >= 32 && + dotOp.getInputPrecision() != InputPrecision::TF32) + return failure(); + auto oldAType = dotOp.getA().getType(); + auto oldBType = dotOp.getB().getType(); + // NYI: PTX 13+ requires all tcgen instructions in a kernel to have a + // consistent CTA mode, disabling 2CTA mode for now. To re-enable, + // change the line below to: bool useTwoCTAs = canUseTwoCTAs(dotOp); + bool useTwoCTAs = false; + if (useTwoCTAs) { + b = splitBOperand(b, rewriter); + } + // TF32 transpose is only supported with 128 swizzle mode with 32B + // atomicity. As we currently don't support this layout we disallow + // transpose for TF32 inputs. + bool allowTranspose = !dotOp.getA().getType().getElementType().isF32(); + a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); + b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); + MLIRContext *context = dotOp->getContext(); + auto instrShape = mmaVersionToInstrShape( + versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps); + auto CTASplitNum = CGALayout.getCTASplitNum(); + auto bitwidth = oldRetType.getElementType().getIntOrFloatBitWidth(); + unsigned colStride = 32 / bitwidth; + Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + context, instrShape[0], instrShape[1], colStride, CTASplitNum[0], + CTASplitNum[1], useTwoCTAs); + Attribute tensorMemorySpace = + triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); + MemDescType accMemDescType = + MemDescType::get(oldRetType.getShape(), oldRetType.getElementType(), + accEncoding, tensorMemorySpace, + /*mutableMemory=*/true); + auto newDistributedEncoding = nvidia_gpu::getDefaultLayoutForTmemLdSt( + accMemDescType, numWarps, CGALayout); + auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding); + Value cvtAcc = + ConvertLayoutOp::create(rewriter, loc, newAccType, dotOp.getOperand(2)); + auto tokType = rewriter.getType(); + auto acc = triton::nvidia_gpu::TMEMAllocOp::create( + rewriter, loc, accMemDescType, tokType, cvtAcc); + auto vTrue = arith::ConstantIntOp::create(rewriter, dotOp.getLoc(), 1, 1); + auto mma = triton::nvidia_gpu::TCGen5MMAOp::create( + rewriter, loc, tokType, a, b, acc, acc.getToken(), /*useD=*/vTrue, + /*pred=*/vTrue); + mma.setTwoCtas(useTwoCTAs); + + auto ld = triton::nvidia_gpu::TMEMLoadOp::create( + rewriter, loc, newAccType, tokType, acc, /*dep=*/mma.getToken()); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, ld); + return success(); + } +}; + +Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) { + /* + Rewrite load(scale) -> local_load(local_alloc(load(scale))). + This function does not add anything to the final IR when num_stages > 1, + but it makes it easy to apply TMEM copy rewriting later. + + Since scales are stored in TMEM for MMAv5 scaled dot, loading of scales do + not needs to be put into SMEM. But in practice, the software pipeliner puts + loading of scales into multi-buffered SMEM. At that point, the SMEM + allocation created here is eliminated. + */ + OpBuilder::InsertionGuard g(rewriter); + auto op = scale.getDefiningOp(); + Operation *loadConsumer = nullptr; + + if (!op) + return scale; + + while (!isa(op)) { + if (auto reshape = dyn_cast(op)) { + op = reshape.getSrc().getDefiningOp(); + loadConsumer = reshape; + } else if (auto trans = dyn_cast(op)) { + op = trans.getSrc().getDefiningOp(); + loadConsumer = trans; + } else if (auto cvt = dyn_cast(op)) { + op = cvt.getSrc().getDefiningOp(); + loadConsumer = cvt; + } else { + // Unrecognized pattern, bail out. In practice, this implies that MMA + // pipelining will not apply to the scaled dot op, since scales will not + // be in passed through SMEM to tc_gen5_mma_scaled. + return scale; + } + } + + auto scaleAfterLoad = op->getResult(0); + auto scaleSmemAlloc = + getSharedMemoryScale(scaleAfterLoad, rewriter, op->getLoc()); + + rewriter.setInsertionPointAfterValue(scaleSmemAlloc); + auto localLoad = LocalLoadOp::create( + rewriter, op->getLoc(), scaleAfterLoad.getType(), scaleSmemAlloc); + + rewriter.replaceAllUsesExcept(scaleAfterLoad, localLoad.getResult(), + scaleSmemAlloc); + + if (loadConsumer) { + return scale; + } else { + return localLoad; + } +} + +class ScaledBlockedToMMA : public mlir::OpRewritePattern { + int computeCapability; + +public: + ScaledBlockedToMMA(mlir::MLIRContext *context, int computeCapability, + int benefit) + : mlir::OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotScaledOp dotOp, + mlir::PatternRewriter &rewriter) const override { + if (computeCapability != 120) + return failure(); + + auto numCTAs = lookupNumCTAs(rewriter); + if (numCTAs != 1) { + return failure(); + } + // Skip if any scale is missing. This pattern requires both scales. + if (!dotOp.getAScale() || !dotOp.getBScale()) + return failure(); + + auto aScaleType = dotOp.getAScale().getType(); + auto bScaleType = dotOp.getBScale().getType(); + + if (mlir::isa(aScaleType.getEncoding()) || + mlir::isa(bScaleType.getEncoding())) { + return failure(); + } + auto aElemType = dotOp.getAElemType(); + auto bElemType = dotOp.getBElemType(); + auto isFP8 = [&](ScaleDotElemType elemType) -> bool { + return elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2; + }; + auto isFP4 = [&](ScaleDotElemType elemType) -> bool { + return elemType == ScaleDotElemType::E2M1; + }; + // mixed precision is not supported + if (isFP8(aElemType) && isFP4(bElemType) || + isFP4(aElemType) && isFP8(bElemType)) { + return failure(); + } + + auto scaleElemType = dotOp.getAScale().getType().getElementType(); + if (scaleElemType != dotOp.getBScale().getType().getElementType()) { + return failure(); + } + + // Common MMA encoding creation + auto mmaResult = + createMMAEncodingForDot(dotOp, rewriter, computeCapability, 2); + + // Operand processing + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = cast(a.getType()); + auto oldBType = cast(b.getType()); + + Operation *newDot = nullptr; + + // ScaledBlockedToMMA logic + int bitwidthA = oldAType.getElementType().getIntOrFloatBitWidth(); + int bitwidthB = oldBType.getElementType().getIntOrFloatBitWidth(); + int minBitwidth = std::min(bitwidthA, bitwidthB); + + Value newA = convertDotOperandForMMA(a, 0, minBitwidth, + mmaResult.newRetType, rewriter); + Value newB = convertDotOperandForMMA(b, 1, minBitwidth, + mmaResult.newRetType, rewriter); + const auto mmaWarps = mmaResult.mmaEnc.getWarpsPerCTA(); // [wM, wN] + // Convert scales to Linear layout + auto convertScale = [&](Value scale, int opIdx) -> Value { + auto ty = cast(scale.getType()); + SmallVector shape = llvm::to_vector(ty.getShape()); + MLIRContext *ctx = ty.getContext(); + auto blocked = cast(ty.getEncoding()); + + auto ll = triton::gpu::getSM120DotScaledScaleLayout( + ctx, shape, opIdx, mmaWarps, blocked.getCGALayout()); + auto newEnc = triton::gpu::LinearEncodingAttr::get(ctx, std::move(ll)); + auto newTy = RankedTensorType::get(shape, ty.getElementType(), newEnc); + return ConvertLayoutOp::create(rewriter, scale.getLoc(), newTy, scale); + }; + Value aScale = convertScale(dotOp.getAScale(), /*opIdx=*/0); + Value bScale = convertScale(dotOp.getBScale(), /*opIdx=*/1); + + newDot = triton::DotScaledOp::create( + rewriter, dotOp.getLoc(), mmaResult.newRetType, newA, newB, + mmaResult.newAcc, aScale, bScale, dotOp.getAElemType(), + dotOp.getBElemType(), dotOp.getFastMath(), dotOp.getLhsKPack(), + dotOp.getRhsKPack()); + rewriter.replaceOpWithNewOp(dotOp, dotOp.getType(), + newDot->getResult(0)); + return success(); + } +}; + +class ScaledBlockedToMMAv5 + : public mlir::OpRewritePattern { + int computeCapability; + +public: + ScaledBlockedToMMAv5(mlir::MLIRContext *context, int computeCapability, + int benefit) + : mlir::OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotScaledOp dotOp, + mlir::PatternRewriter &rewriter) const override { + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + + if (dotOp.getAScale() == nullptr || dotOp.getBScale() == nullptr) { + return failure(); + } + + // get MMA encoding for the given number of warps + auto retShapePerCTA = getShapePerCTA(oldRetType); + int numWarps = lookupNumWarps(dotOp); + auto CGALayout = getCGALayout(oldRetType.getEncoding()); + if (computeCapability < 100 || computeCapability >= 120) + return failure(); + if (numWarps != 4 && numWarps != 8) + return failure(); + if (retShapePerCTA[0] < 128 || retShapePerCTA[1] < 16) + return failure(); + Location loc = dotOp.getLoc(); + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + + bool IsAMixedPrecFp4 = false; + bool IsBMixedPrecFp4 = false; + bool isAFP4 = dotOp.getAElemType() == ScaleDotElemType::E2M1; + bool isBFP4 = dotOp.getBElemType() == ScaleDotElemType::E2M1; + + if (dotOp.getAElemType() != dotOp.getBElemType()) { + if (isAFP4) + IsAMixedPrecFp4 = true; + else if (isBFP4) + IsBMixedPrecFp4 = true; + } + // If we use txgen05.mma.kind.mxf864 we need to padd the fp4 operands: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-packing-formats-mxf8f6f4-smem + bool isMMAv5Fp4PaddedLhs = IsAMixedPrecFp4 || !dotOp.getLhsKPack(); + bool isMMAv5Fp4PaddedRhs = IsBMixedPrecFp4 || !dotOp.getRhsKPack(); + // For mixed-precision fp4 operands, set allowTranspose = false, to force + // the packed axis, K, to be contiguous in SMEM + a = getSharedMemoryMMAOperand(a, rewriter, 0, + /*allowTranspose=*/!isAFP4, + /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedLhs, + /*forceTranspose=*/!dotOp.getLhsKPack(), + dotOp); + b = getSharedMemoryMMAOperand(b, rewriter, 1, + /*allowTranspose=*/!isBFP4, + /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedRhs, + /*forceTranspose=*/!dotOp.getRhsKPack(), + dotOp); + + MLIRContext *context = dotOp->getContext(); + unsigned m = 128; + unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1]; + + auto CTASplitNum = CGALayout.getCTASplitNum(); + auto bitwidth = oldRetType.getElementType().getIntOrFloatBitWidth(); + unsigned colStride = 32 / bitwidth; + Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + context, m, n, colStride, CTASplitNum[0], CTASplitNum[1], false); + Attribute tensorMemorySpace = + triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); + MemDescType accMemDescType = + MemDescType::get(oldRetType.getShape(), oldRetType.getElementType(), + accEncoding, tensorMemorySpace, + /*mutableMemory=*/true); + auto newDistributedEncoding = nvidia_gpu::getDefaultLayoutForTmemLdSt( + accMemDescType, numWarps, CGALayout); + auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding); + Value cvtAcc = + ConvertLayoutOp::create(rewriter, loc, newAccType, dotOp.getOperand(2)); + auto tokType = rewriter.getType(); + auto acc = triton::nvidia_gpu::TMEMAllocOp::create( + rewriter, loc, accMemDescType, tokType, cvtAcc); + + RankedTensorType oldScaleAType = dotOp.getAScale().getType(); + RankedTensorType oldScaleBType = dotOp.getBScale().getType(); + + Attribute scaleEncoding = + triton::nvidia_gpu::TensorMemoryScalesEncodingAttr::get( + context, CTASplitNum[0], CTASplitNum[1]); + MemDescType scaleAType = triton::gpu::MemDescType::get( + oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleEncoding, + tensorMemorySpace, + /*mutableMemory=*/false); + MemDescType scaleBType = triton::gpu::MemDescType::get( + oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleEncoding, + tensorMemorySpace, + /*mutableMemory=*/false); + Attribute scaleALayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + scaleAType, numWarps, getCGALayout(oldScaleAType.getEncoding())); + Attribute scaleBLayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + scaleBType, numWarps, getCGALayout(oldScaleBType.getEncoding())); + RankedTensorType newScaleAType = + oldScaleAType.cloneWithEncoding(scaleALayout); + RankedTensorType newScaleBType = + oldScaleBType.cloneWithEncoding(scaleBLayout); + + auto lhsScale = addSmemStageToScaleLoad(dotOp.getAScale(), rewriter); + auto rhsScale = addSmemStageToScaleLoad(dotOp.getBScale(), rewriter); + + Value newScaleA = + ConvertLayoutOp::create(rewriter, loc, newScaleAType, lhsScale); + Value newScaleB = + ConvertLayoutOp::create(rewriter, loc, newScaleBType, rhsScale); + + // We don't need to track memory dependencies for the scale operands since + // they are not pipelined. + auto scaleA = triton::nvidia_gpu::TMEMAllocOp::create( + rewriter, loc, scaleAType, /*token=*/Type(), newScaleA); + auto scaleB = triton::nvidia_gpu::TMEMAllocOp::create( + rewriter, loc, scaleBType, /*token=*/Type(), newScaleB); + + auto vTrue = arith::ConstantIntOp::create(rewriter, dotOp.getLoc(), 1, 1); + auto mmaOp = triton::nvidia_gpu::TCGen5MMAScaledOp::create( + rewriter, loc, tokType, a, b, acc.getResult(), acc.getToken(), + scaleA.getResult(), scaleB.getResult(), dotOp.getAElemType(), + dotOp.getBElemType(), + /*useD=*/vTrue, /*pred=*/vTrue); + + auto ld = triton::nvidia_gpu::TMEMLoadOp::create( + rewriter, loc, newAccType, tokType, acc, mmaOp.getToken()); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, ld); + return success(); + } +}; +} // namespace + +static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = cast(operand.getType()) + .cloneWith(std::nullopt, promotedType); + Type operandElType = + cast(operand.getType()).getElementType(); + if (type::isFloat8(operandElType)) { + return FpToFpOp::create(builder, loc, tensorPromotedType, operand); + } + return arith::ExtFOp::create(builder, loc, tensorPromotedType, operand); +} + +static bool mmav2SupportsFp8Operands(int computeCapability) { + // promote operands for sm < 89 since fp8 mma is not natively supported + // although PTX instructions for mma v2 w/ fp8 operands exist for sm90 and + // sm100, they are emulated as fp16 upcasts + fp16 HMMA in SASS. sm120 has + // hardware support for fp8 operands w/ mmav2. + return computeCapability == 89 || computeCapability == 120; +} + +// promote operands of dot op if the existing combination is not natively +// supported. +static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { + mod.walk([=](DotOp dotOp) -> void { + auto D = dotOp.getD(); + OpBuilder builder(dotOp); + Type AElType = dotOp.getA().getType().getElementType(); + Type promoteType; + NvidiaMmaEncodingAttr mmaLayout = + dyn_cast(D.getType().getEncoding()); + if (mmaLayout) { + bool isNativeFP8 = llvm::isa(AElType); + // promote to f16 unless there's hardware support for fp8 operands + if (!isNativeFP8 || + (isNativeFP8 && (mmav2SupportsFp8Operands(computeCapability) || + mmaLayout.isHopper()))) + return; + promoteType = builder.getF16Type(); + } else { + // FMA case. + Type AElType = dotOp.getA().getType().getElementType(); + Type DElType = D.getType().getElementType(); + if (AElType == DElType) + return; + promoteType = DElType; + } + Location loc = dotOp.getLoc(); + Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); + Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); + dotOp.setOperand(0, promotedA); + dotOp.setOperand(1, promotedB); + }); +} + +// Transpose scaled_dot ops that have a scale on lhs. +static void transposeDotOp(DotScaledOp dotOp) { + OpBuilder builder(dotOp); + Value lhs = dotOp.getA(); + std::array transOrder = {1, 0}; + Value lhsTransposed = TransOp::create(builder, lhs.getLoc(), lhs, transOrder); + Value rhs = dotOp.getB(); + Value rhsTransposed = TransOp::create(builder, rhs.getLoc(), rhs, transOrder); + Value c = dotOp.getC(); + Value cTransposed = TransOp::create(builder, c.getLoc(), c, transOrder); + Value result = DotScaledOp::create( + builder, dotOp.getLoc(), cTransposed.getType(), rhsTransposed, + lhsTransposed, cTransposed, dotOp.getBScale(), dotOp.getAScale(), + dotOp.getBElemType(), dotOp.getAElemType(), dotOp.getFastMath()); + Operation *transposedResult = + TransOp::create(builder, result.getLoc(), result, transOrder); + dotOp.replaceAllUsesWith(transposedResult); + dotOp.erase(); +} + +static void transposeDots(ModuleOp m) { + // TODO: extend to regular dot when it is profitable. For instance when we may + // want to use rhs from register for mmav3. + SmallVector toTranspose; + m.walk([&](DotScaledOp dotOp) -> void { + if (dotOp.getAScale() == nullptr && dotOp.getBScale() != nullptr) + toTranspose.push_back(dotOp); + }); + for (DotScaledOp dotOp : toTranspose) { + transposeDotOp(dotOp); + } +} + +#define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUAccelerateMatmulPass + : public impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass> { +public: + using impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass>::TritonGPUAccelerateMatmulBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + auto computeCapability = getNVIDIAComputeCapability(m); + // We could do this generically if we manage to improve the heuristics + // reverted in these two PRs https://github.com/triton-lang/triton/pull/5834 + // https://github.com/triton-lang/triton/pull/5837 + transposeDots(m); + + mlir::RewritePatternSet patterns(context); + constexpr int benefitDefault = 1; + constexpr int benefitMMAv5 = 10; + constexpr int benefitSM120 = 10; + + patterns.add(context, computeCapability, benefitDefault); + patterns.add(context, computeCapability, benefitSM120); + populateDecomposeScaledBlockedPatterns(patterns, benefitDefault); + patterns.add( + context, computeCapability, benefitMMAv5); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + // Now that we have picked the mma type, decompose dot that are not natively + // supported. + decomposeMixedModeDotOp(m, computeCapability); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..962ef6a1fe --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,56 @@ +add_triton_library(TritonGPUTransforms + AccelerateMatmul.cpp + Coalesce.cpp + F32DotTC.cpp + FuseNestedLoops.cpp + CombineTensorSelectAndIf.cpp + DecomposeScaledBlocked.cpp + HoistTMEMAlloc.cpp + ReduceDataDuplication.cpp + OptimizeAccumulatorInit.cpp + OptimizeDotOperands.cpp + OptimizeThreadLocality.cpp + Pipeliner/AssignLatencies.cpp + Pipeliner/LowerLoops.cpp + Pipeliner/MMAv5PipelineUtility.cpp + Pipeliner/ScheduleLoops.cpp + Pipeliner/WGMMAPipeline.cpp + Pipeliner/PipelineExpander.cpp + Pipeliner/TestPipelineLowerLoop.cpp + Pipeliner/SoftwarePipeliner.cpp + Pipeliner/TMAStoresPipeline.cpp + Pipeliner/MMAv5PipelineUtility.cpp + Pipeliner/PipeliningUtility.cpp + Pipeliner/Schedule.cpp + Prefetch.cpp + RemoveLayoutConversions.cpp + ReorderInstructions.cpp + CoalesceAsyncCopy.cpp + Utility.cpp + CoalesceUtils.cpp + LayoutPropagationUtility.cpp + WarpSpecialization/AutomaticWarpSpecialization.cpp + WarpSpecialization/Partition.cpp + WarpSpecialization/OptimizePartitionWarps.cpp + WarpSpecialization/PartitionBuilder.cpp + WarpSpecialization/PartitionLoops.cpp + WarpSpecialization/PartitionScheduling.cpp + WarpSpecialization/PartitionSchedulingUtility.cpp + + DEPENDS + TritonGPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis + TritonIR + TritonTransforms + TritonGPUIR + TritonNvidiaGPUIR + NVWSIR + NVWSTransforms + TritonToTritonGPU + TritonInstrumentIR + MLIRTransformUtils +) diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp new file mode 100644 index 0000000000..6ba84784df --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -0,0 +1,125 @@ +#include +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Descriptor load/stores don't need to consider L1 coalescing but the +// destination layout will affect the shared memory load/store generated. So we +// still want to allow vectorization for the src/destination layout up to +// 16bytes. +static Attribute pickDescriptorLoadStoreLayout(int numWarps, int threadsPerWarp, + RankedTensorType type) { + auto shapePerCTA = triton::gpu::getShapePerCTA(type); + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + int numElemsPerThread = std::max(numElems / numThreads, 1); + + int maxVectorSize = 128 / type.getElementTypeBitWidth(); + + int vectorSize = std::min(numElemsPerThread, maxVectorSize); + SmallVector sizePerThread(type.getRank(), 1); + sizePerThread.back() = vectorSize; + + SmallVector order = + getMatrixOrder(type.getRank(), /*rowMajor*/ true); + auto cgaLayout = triton::gpu::getCGALayout(type.getEncoding()); + + Attribute layout = triton::gpu::BlockedEncodingAttr::get( + type.getContext(), type.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, cgaLayout); + return layout; +} + +static void pickDescriptorLoadStoreLayout( + ModuleOp moduleOp, llvm::MapVector &layoutMap) { + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(moduleOp); + moduleOp.walk([&](Operation *op) { + int numWarps = lookupNumWarps(op); + if (auto load = dyn_cast(op)) { + if (load->getNumResults() == 1) + layoutMap[op] = pickDescriptorLoadStoreLayout( + numWarps, threadsPerWarp, + cast(load->getResult(0).getType())); + } + if (auto store = dyn_cast(op)) { + layoutMap[op] = pickDescriptorLoadStoreLayout(numWarps, threadsPerWarp, + store.getSrc().getType()); + } + }); +} + +struct CoalescePass : public impl::TritonGPUCoalesceBase { + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return tensorType.cloneWithEncoding(encoding); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(moduleOp); + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + int numWarps = lookupNumWarps(curr); + + auto tensorType = cast(ptr.getType()); + CGAEncodingAttr cgaLayout = getCGALayout(tensorType.getEncoding()); + SmallVector shapePerCTA = getShapePerCTA(tensorType); + auto layout = + buildCoalescedEncoding(axisInfoAnalysis, curr, numWarps, + threadsPerWarp, cgaLayout, shapePerCTA); + layoutMap[curr] = layout; + }); + + // Also pick a layout for descriptor load/store ops. + pickDescriptorLoadStoreLayout(moduleOp, layoutMap); + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. Replace all the uses of the original memory op by the new one + for (auto &kv : layoutMap) { + convertDistributedOpEncoding(kv.second, kv.first); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp new file mode 100644 index 0000000000..a641f6bd5b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp @@ -0,0 +1,214 @@ +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static Value convertValueLayout(Value src, Attribute enc, + PatternRewriter &rewriter) { + auto ty = cast(src.getType()); + auto newTy = ty.cloneWithEncoding(enc); + auto cvt = ConvertLayoutOp::create(rewriter, src.getLoc(), newTy, src); + return cvt.getResult(); +} + +static void retargetCopyOperandsToEncoding( + AsyncCopyGlobalToLocalOp copyOp, Attribute newEncoding, + ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternRewriter &rewriter) { + Value src = copyOp.getSrc(); + Value mask = copyOp.getMask(); + Value other = copyOp.getOther(); + + // insert cvt's after src, mask, and other + src = convertValueLayout(src, newEncoding, rewriter); + if (mask) + mask = convertValueLayout(mask, newEncoding, rewriter); + if (other) + other = convertValueLayout(other, newEncoding, rewriter); + + unsigned contiguity = axisInfoAnalysis.getContiguity(src); + if (mask) + contiguity = + std::min(contiguity, axisInfoAnalysis.getMaskAlignment(mask)); + + rewriter.modifyOpInPlace(copyOp, [&]() { + copyOp.getSrcMutable().assign(src); + if (mask) + copyOp.getMaskMutable().assign(mask); + if (other) + copyOp.getOtherMutable().assign(other); + copyOp.setContiguity(contiguity); + }); +} + +// This pass currently only applies if the following are all true... +// 1) Operand A for WGMMA is to be loaded in registers +// 2) We upcast operand A in registers before the WGMMA +// (downcasting is not yet supported) +// 3) Pipelining is enabled for loading A +// +// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding +// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if +// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread +// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two +// 8-byte-cp.async's for each contiguous 16B global data owned by each +// thread. This breaks coalescing (i.e. results 2x the minimum required +// transactions). +// +// This issue occurs for cp.async because it combines load and store into one +// instruction. The fix is to clip each dim of sizePerThread by shared vec, so +// that the vectorization of load and store are equal along the contiguous +// dimension. In the above example, each thread will then only own 8B contiguous +// global data. +struct ClipAsyncCopySizePerThread + : public OpRewritePattern { + ModuleAxisInfoAnalysis &axisInfoAnalysis; + using OpRewritePattern::OpRewritePattern; + ClipAsyncCopySizePerThread(ModuleAxisInfoAnalysis &axisInfoAnalysis, + MLIRContext *context) + : OpRewritePattern(context), axisInfoAnalysis(axisInfoAnalysis) {} + + LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp, + PatternRewriter &rewriter) const override { + Value src = copyOp.getSrc(); + Value mask = copyOp.getMask(); + Value other = copyOp.getOther(); + auto srcTy = cast(src.getType()); + auto dstTy = cast(copyOp.getResult().getType()); + auto blockedEnc = dyn_cast(srcTy.getEncoding()); + if (!blockedEnc) + return rewriter.notifyMatchFailure(copyOp, + "src must be of blocked encoding"); + auto sharedEnc = dyn_cast(dstTy.getEncoding()); + if (!sharedEnc) + return failure(); + auto sharedVec = sharedEnc.getVec(); + + // obtain max contiguous copy size + // Note this can be further optimized, as copyContigSize can be even + // smaller when lowering, depending on contiguity and mask alignment + // (see AsyncCopyGlobalToLocalOpConversion) + LinearLayout regLayout = triton::gpu::toLinearLayout(srcTy); + LinearLayout sharedLayout = triton::gpu::toLinearLayout(dstTy); + auto copyContigSize = + regLayout.invertAndCompose(sharedLayout).getNumConsecutiveInOut(); + + // obtain block sizePerThread along contig dim + auto contigPerThread = getContigPerThread(srcTy); + auto blockContigSize = contigPerThread[blockedEnc.getOrder()[0]]; + + if (blockContigSize <= copyContigSize) + return rewriter.notifyMatchFailure( + copyOp, + "blocked sizePerThread along contiguous dim must be greater than the " + "max contiguous copy size "); + + contigPerThread[blockedEnc.getOrder()[0]] = copyContigSize; + + // obtain new blockedEnc based on clipped sizePerThread + auto mod = copyOp->getParentOfType(); + int numWarps = lookupNumWarps(copyOp); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto newBlockEnc = BlockedEncodingAttr::get( + copyOp.getContext(), srcTy.getShape(), contigPerThread, + blockedEnc.getOrder(), numWarps, threadsPerWarp, + blockedEnc.getCGALayout()); + + retargetCopyOperandsToEncoding(copyOp, newBlockEnc, axisInfoAnalysis, + rewriter); + + return success(); + } +}; + +// For cheap loads we usually pick the layout based on users but when converting +// to async_cp the layout of the copy is independent of the layout of the users +// so picking a coalesced layout is better. +struct CoalesceCheapAsyncCopyGlobalToLocal + : public OpRewritePattern { + ModuleAxisInfoAnalysis &axisInfoAnalysis; + DenseMap &coalescedAsyncCopyMap; + using OpRewritePattern::OpRewritePattern; + CoalesceCheapAsyncCopyGlobalToLocal( + ModuleAxisInfoAnalysis &axisInfoAnalysis, + DenseMap &coalescedAsyncCopyMap, + MLIRContext *context) + : OpRewritePattern(context), axisInfoAnalysis(axisInfoAnalysis), + coalescedAsyncCopyMap(coalescedAsyncCopyMap) {} + + LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp, + PatternRewriter &rewriter) const override { + Value src = copyOp.getSrc(); + Value mask = copyOp.getMask(); + Value other = copyOp.getOther(); + RankedTensorType srcTy = cast(src.getType()); + auto dstTy = cast(copyOp.getResult().getType()); + int numWarps = triton::gpu::lookupNumWarps(copyOp); + auto mod = copyOp->getParentOfType(); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = triton::gpu::getNumCTAs(dstTy.getEncoding()); + int64_t size = srcTy.getNumElements(); + // Assume the expensive copies are already coalesced. + // Skip dtype smaller than 32 bits to avoid problems with contiguity. + if (size >= numWarps * threadsPerWarp || + dstTy.getElementTypeBitWidth() < 32) + return failure(); + auto shapePerCTA = triton::gpu::getShapePerCTA(dstTy); + auto cgaLayout = triton::gpu::getCGALayout(dstTy.getEncoding()); + + auto newEnc = coalescedAsyncCopyMap[copyOp]; + if (newEnc == nullptr || newEnc == srcTy.getEncoding()) + return failure(); + + retargetCopyOperandsToEncoding(copyOp, newEnc, axisInfoAnalysis, rewriter); + + return success(); + } +}; + +struct CoalesceAsyncCopyPass + : impl::TritonGPUCoalesceAsyncCopyBase { + using Base::Base; + + void runOnOperation() override { + ModuleOp m = getOperation(); + triton::ModuleAxisInfoAnalysis axisInfoAnalysis(m); + // Collect the coalesced encoding first as changing the IR invalidates the + // axis analysis. + DenseMap coalescedAsyncCopyMap; + m.walk([&](AsyncCopyGlobalToLocalOp copyOp) { + auto dstTy = cast(copyOp.getResult().getType()); + int numWarps = triton::gpu::lookupNumWarps(copyOp); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(m); + int numCTAs = triton::gpu::getNumCTAs(dstTy.getEncoding()); + auto cgaLayout = triton::gpu::getCGALayout(dstTy.getEncoding()); + auto shapePerCTA = triton::gpu::getShapePerCTA(dstTy); + coalescedAsyncCopyMap[copyOp] = + buildCoalescedEncoding(axisInfoAnalysis, copyOp, numWarps, + threadsPerWarp, cgaLayout, shapePerCTA); + }); + + MLIRContext *context = &getContext(); + + mlir::RewritePatternSet patterns(context); + patterns.add(axisInfoAnalysis, context); + patterns.add( + axisInfoAnalysis, coalescedAsyncCopyMap, context); + + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp new file mode 100644 index 0000000000..4705fb18d8 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp @@ -0,0 +1,96 @@ + + +#include "triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton::gpu { +BlockedEncodingAttr +buildCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + triton::gpu::CGAEncodingAttr cgaLayout, + SmallVector shapePerCTA) { + Value ptr = getMemAccessPtr(op); + auto refTensorType = cast(ptr.getType()); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = getOrderFromContiguity(contiguity); + LDBG("order=[" << triton::join(order, ", ") << "]"); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::getSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = getOrderFromContiguity( + axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + + unsigned perThread = + getNumElementsPerThread(op, order, axisInfoAnalysis, shapePerCTA); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = getNumElementsPerThread( + opSameOrder, order, axisInfoAnalysis, shapePerCTA); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + perThread = std::min( + perThread, + getNumElementsPerThread(op, order, axisInfoAnalysis, shapePerCTA)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + return BlockedEncodingAttr::get(op->getContext(), refTensorType.getShape(), + sizePerThread, order, numWarps, + threadsPerWarp, cgaLayout); +} +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp new file mode 100644 index 0000000000..608e65a153 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -0,0 +1,176 @@ +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// The user of select maybe inside either the ThenRegion or ElseRegion of +/// the scf.if. So, canonicalize user of select in scf.if first. +static void canonicalizeSelectUsersInSCFIf(ModuleOp input) { + llvm::MapVector, SmallVector> + usersNeedreplaced; + input.walk([&](arith::SelectOp selectOp) { + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + Value trueVal = selectOp.getOperand(1); + Value falseVal = selectOp.getOperand(2); + Value resVal = selectOp.getResult(); + for (auto *condUser : condition.getUsers()) { + if (!llvm::isa(condUser)) + continue; + scf::IfOp ifOp = llvm::cast(condUser); + for (auto *resUser : resVal.getUsers()) { + if (ifOp->isProperAncestor(resUser)) { + if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) != + nullptr) { + // The user is inside the ThenRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back( + resUser); + } else { + // The user is inside the ElseRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back( + resUser); + } + } + } + } + }); + + // Replace the operand of user. + for (auto [replacedSrcAndDst, users] : + llvm::make_early_inc_range(usersNeedreplaced)) { + Value srcVal = replacedSrcAndDst.first; + Value dstVal = replacedSrcAndDst.second; + for (Operation *user : llvm::make_early_inc_range(users)) { + srcVal.replaceUsesWithIf( + dstVal, [&](OpOperand &use) { return use.getOwner() == user; }); + } + } +} + +/// Return true if the select could be merged into the If without breaking SSA +/// rules. +static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, + DominanceInfo &dom) { + // If needs to be dominated by the select. + if (!dom.dominates(selectOp.getOperation(), ifOp.getOperation())) { + return false; + } + // If needs to dominate all the select's users. + for (auto user : selectOp.getResult().getUsers()) { + if (!dom.dominates(ifOp, user)) { + return false; + } + } + return true; +} + +class CombineTensorSelectAndIfPass + : public impl::TritonGPUCombineTensorSelectAndIfBase< + CombineTensorSelectAndIfPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + canonicalizeSelectUsersInSCFIf(m); + + // Go over the arith.select ops, look if there is an if + // with the same condition. + DominanceInfo dom(m); + llvm::MapVector> selectToIf; + m.walk([&](arith::SelectOp selectOp) { + // Apply only to selects with a tensor result. Scalars are cheap enough to + // predicate. + if (!isa(selectOp.getResult().getType())) + return; + // Look if there is an if in the same block, with the same condition. + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + SetVector conditionUsers(condition.getUsers().begin(), + condition.getUsers().end()); + // sort the users in topological order. + conditionUsers = mlir::topologicalSort(conditionUsers); + // Get condition's users + for (Operation *user : conditionUsers) { + auto ifOp = dyn_cast(user); + if (!ifOp || ifOp->getBlock() != parentBlock) + continue; + if (canMergeIntoIf(selectOp, ifOp, dom)) { + selectToIf[ifOp].push_back(selectOp); + break; + } + } + }); + + for (auto [ifOp, selectOps] : selectToIf) { + // Add new return value to the if (and create else block if necessary), + // then yield the select value in the then block and the else block. + OpBuilder builder(ifOp); + auto loc = ifOp.getLoc(); + // Create an scf::IfOp with extra return value. + SmallVector newResultTypes = {ifOp.getResultTypes().begin(), + ifOp.getResultTypes().end()}; + for (arith::SelectOp selectOp : selectOps) { + newResultTypes.push_back(selectOp.getResult().getType()); + } + auto newIfOp = scf::IfOp::create(builder, loc, newResultTypes, + ifOp.getCondition(), /*hasElse*/ true); + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + if (ifOp.elseBlock()) { + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + } else { + // Create an empty yield + auto builder = newIfOp.getElseBodyBuilder(); + auto yieldOp = scf::YieldOp::create(builder, loc); + } + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + for (arith::SelectOp selectOp : selectOps) { + Value thenValue = selectOp.getTrueValue(); + Value elseValue = selectOp.getFalseValue(); + ifYieldOperands.push_back(thenValue); + elseYieldOperands.push_back(elseValue); + } + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + builder.setInsertionPoint(yield); + scf::YieldOp::create(builder, loc, operands); + yield.erase(); + }; + updateYield(newIfOp.thenYield(), ifYieldOperands); + updateYield(newIfOp.elseYield(), elseYieldOperands); + + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + // Replace the select with the new return value. + for (arith::SelectOp selectOp : selectOps) { + selectOp.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + selectOp.erase(); + } + + ifOp.erase(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp new file mode 100644 index 0000000000..509b815eb3 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp @@ -0,0 +1,261 @@ +#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h" + +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +SmallVector DecomposeScaledBlocked::getTransposeOrder(int rank) { + assert(rank >= 2); + auto transOrder = llvm::to_vector<2>(llvm::seq(rank - 2)); + transOrder.push_back(rank - 1); + transOrder.push_back(rank - 2); + return transOrder; +} + +LogicalResult +DecomposeScaledBlocked::matchAndRewrite(DotScaledOp scaledDotOp, + PatternRewriter &rewriter) const { + if (isa_and_nonnull( + scaledDotOp.getResult().getType().getEncoding())) + return failure(); + + // TODO: add support for m/n packed formats. + if (!scaledDotOp.getLhsKPack() || !scaledDotOp.getRhsKPack()) + return failure(); + // Types + auto computeType = getComputeType(scaledDotOp.getAElemType(), + scaledDotOp.getBElemType(), rewriter); + + auto scaledA = scaleArg(rewriter, scaledDotOp, 0, computeType); + scaledA = cvtDotOperand(rewriter, scaledDotOp, 0, scaledA); + auto scaledB = scaleArg(rewriter, scaledDotOp, 1, computeType); + scaledB = cvtDotOperand(rewriter, scaledDotOp, 1, scaledB); + auto newDot = DotOp::create(rewriter, scaledDotOp.getLoc(), scaledA, scaledB, + scaledDotOp.getC()); + + rewriter.replaceOpWithNewOp(scaledDotOp, + scaledDotOp.getType(), newDot); + return success(); +} + +FloatType +DecomposeScaledBlocked::getComputeType(ScaleDotElemType aType, + ScaleDotElemType bType, + PatternRewriter &rewriter) const { + if (aType == ScaleDotElemType::FP16 || bType == ScaleDotElemType::FP16) + return rewriter.getF16Type(); + return rewriter.getBF16Type(); +} + +TypedValue +DecomposeScaledBlocked::scaleTo16(PatternRewriter &rewriter, + TypedValue scale, + FloatType computeType) const { + auto loc = scale.getLoc(); + auto scaleTy = scale.getType(); + assert(computeType == rewriter.getBF16Type() || + computeType == rewriter.getF16Type()); + + // Choose an fp type that can fit the scale value. + FloatType largeFpType = computeType == rewriter.getF16Type() + ? rewriter.getF32Type() + : computeType; + int intWidth = largeFpType.getIntOrFloatBitWidth(); + auto intType = rewriter.getIntegerType(intWidth); + + auto zexted = + arith::ExtUIOp::create(rewriter, loc, scaleTy.clone(intType), scale); + // getFpMantissaWidth() returns the number of bits in the mantissa plus the + // sign bit! + int shiftValue = largeFpType.getFPMantissaWidth() - 1; + auto shiftConst = + arith::ConstantIntOp::create(rewriter, loc, shiftValue, intWidth); + auto shift = + SplatOp::create(rewriter, loc, scaleTy.clone(intType), shiftConst); + auto shlRes = arith::ShLIOp::create(rewriter, loc, zexted, shift); + Value scaleFP = + BitcastOp::create(rewriter, loc, scaleTy.clone(largeFpType), shlRes); + if (largeFpType != computeType) { + scaleFP = arith::TruncFOp::create(rewriter, loc, scaleTy.clone(computeType), + scaleFP); + } + return cast>(scaleFP); +} + +TypedValue DecomposeScaledBlocked::broadcastScale( + PatternRewriter &rewriter, DotScaledOp scaledDotOp, ModuleOp mod, + TypedValue scale, int dim) const { + auto *ctx = rewriter.getContext(); + auto loc = scale.getLoc(); + auto scaleTy = scale.getType(); + auto rank = scaleTy.getRank(); + // 2.1) Expand dims along the last dimension + { + // 2.1.1) Find default encoding for ExpandDims + auto shape = to_vector(scaleTy.getShape()); + shape.insert(shape.end(), 1); + auto nWarps = lookupNumWarps(scaledDotOp); + auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto numCTAs = TritonGPUDialect::getNumCTAs(mod); + auto blockedEnc = + getDefaultBlockedEncoding(ctx, shape, nWarps, threadsPerWarp, numCTAs); + // 2.1.2) Cast scale16 to SliceEncoding + auto sliceEnc = SliceEncodingAttr::get(ctx, rank, blockedEnc); + auto sliceType = scaleTy.cloneWithEncoding(sliceEnc); + scale = ConvertLayoutOp::create(rewriter, loc, sliceType, scale); + } + auto expandScale = ExpandDimsOp::create(rewriter, loc, scale, rank); + // 2.2) Broadcast the dimension to size 32 + auto scaleShape = to_vector(scaleTy.getShape()); + scaleShape.push_back(32); + auto broadcastScale = BroadcastOp::create( + rewriter, loc, expandScale.getType().clone(scaleShape), expandScale); + // 2.3) Transpose the dimension to the scaled dimension + auto transposeOrder = llvm::to_vector(llvm::seq(rank)); + transposeOrder.insert(transposeOrder.begin() + dim + 1, rank); + auto transposedScale = + TransOp::create(rewriter, loc, broadcastScale, transposeOrder); + // 2.4) Reshape to the shape of v + scaleShape.pop_back(); + scaleShape[dim] *= 32; + auto reshapeScale = + ReshapeOp::create(rewriter, loc, scaleShape, transposedScale); + return reshapeScale; +} + +TypedValue DecomposeScaledBlocked::maskNan( + PatternRewriter &rewriter, DotScaledOp scaledDotOp, + TypedValue mxfp, TypedValue scale, + int dim) const { + // Skip NaN checks if fastMath + if (scaledDotOp.getFastMath()) + return mxfp; + + // Implement tl.where(scale == 0xFF, float("nan"), mxfp) + auto loc = scale.getLoc(); + auto mod = scaledDotOp->getParentOfType(); + + // Scale is NaN + auto scaleTy = scale.getType(); + auto constFF = arith::ConstantOp::create( + rewriter, loc, scaleTy, + DenseElementsAttr::get(scaleTy, + APInt(scaleTy.getElementTypeBitWidth(), 0xff))); + auto scaleIsNan = cast>( + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, scale, + constFF) + .getResult()); + auto cond = broadcastScale(rewriter, scaledDotOp, mod, scaleIsNan, dim); + // Make scale is NaN compatible with mxfp + auto condTy = cond.getType(); + condTy = condTy.cloneWithEncoding(mxfp.getType().getEncoding()); + cond = ConvertLayoutOp::create(rewriter, loc, condTy, cond); + + // Create NaN + auto mxfpTy = mxfp.getType(); + auto nan = APFloat::getNaN( + cast(mxfpTy.getElementType()).getFloatSemantics()); + auto constNan = arith::ConstantOp::create( + rewriter, loc, mxfpTy, DenseElementsAttr::get(mxfpTy, nan)); + + auto result = arith::SelectOp::create(rewriter, loc, cond, constNan, mxfp); + return cast>(result.getResult()); +} + +TypedValue +DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter, + DotScaledOp scaledDotOp, int opIdx, + FloatType computeType) const { + auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB(); + auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale(); + auto isFp4 = + ScaleDotElemType::E2M1 == + (opIdx == 0 ? scaledDotOp.getAElemType() : scaledDotOp.getBElemType()); + auto fastMath = scaledDotOp.getFastMath(); + + auto loc = v.getLoc(); + auto rank = v.getType().getRank(); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + + // 0) Upcast value to computeType (fp16/bf16) + if (isFp4) { + // We always pack along the fastest moving dimension, kDim + v = Fp4ToFpOp::create(rewriter, loc, v, computeType, kDim); + } else { + auto vType16 = v.getType().clone(computeType); + v = cast>( + FpToFpOp::create(rewriter, loc, vType16, v).getResult()); + } + if (!scale) + return v; + + // 1) Cast scale to fp16/bf16, broadcast it and convert its layout + auto reshapeScale = extendAndBroadcastScale(rewriter, scaledDotOp, scale, + computeType, v.getType(), opIdx); + + // 2) Multiply + auto mxfp = cast>( + arith::MulFOp::create(rewriter, loc, v, reshapeScale).getResult()); + + // 3) If the scale is NaN, return NaN, else return the scaled value. + return maskNan(rewriter, scaledDotOp, mxfp, scale, kDim); +} + +TypedValue DecomposeScaledBlocked::extendAndBroadcastScale( + PatternRewriter &rewriter, DotScaledOp scaledDotOp, + TypedValue &scale, FloatType computeType, + RankedTensorType dstType, int opIdx) const { + auto loc = scale.getLoc(); + auto mod = scaledDotOp->getParentOfType(); + auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB(); + auto rank = v.getType().getRank(); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + + // For some weird reason, we take the scale with shape as if it were coming + // from the lhs even when it's the rhs. In a normal world, we should accept + // this parameter transposed, as we do with the mxfp. + // + // Notice: this is an inplace change. + if (opIdx == 1) { + auto order = getTransposeOrder(rank); + scale = TransOp::create(rewriter, loc, scale, order); + } + + // 1) Cast scale to compute type (fp16/bf16) + auto scale16 = scaleTo16(rewriter, scale, computeType); + + // 2) Broadcast scale to the same shape as v and convert the layout + auto reshapeScale = broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim); + return ConvertLayoutOp::create(rewriter, loc, dstType, reshapeScale); +} + +TypedValue +DecomposeScaledBlocked::cvtDotOperand(PatternRewriter &rewriter, + DotScaledOp scaledDotOp, int opIdx, + TypedValue v) const { + auto *ctx = rewriter.getContext(); + auto retEnc = scaledDotOp.getType().getEncoding(); + auto vType = v.getType(); + auto encoding = + DotOperandEncodingAttr::get(ctx, opIdx, retEnc, vType.getElementType()); + auto retTy = vType.cloneWithEncoding(encoding); + return ConvertLayoutOp::create(rewriter, v.getLoc(), retTy, v); +} + +void populateDecomposeScaledBlockedPatterns(RewritePatternSet &patterns, + int benefit) { + patterns.add(patterns.getContext(), benefit); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp new file mode 100644 index 0000000000..9c0cfc3fba --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -0,0 +1,241 @@ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +namespace mlir::triton::gpu { + +#define GEN_PASS_DEF_TRITONGPUF32DOTTC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +template +auto convertValue(Value value, const FloatType &scalarToType, + PatternRewriter &rewriter) -> mlir::Value { + auto fromType = cast(value.getType()); + auto toType = fromType.cloneWith(std::nullopt, scalarToType); + return T::create(rewriter, value.getLoc(), toType, value).getResult(); +} + +auto splitF32(Value input, unsigned N, PatternRewriter &rewriter) + -> llvm::SmallVector { + llvm::SmallVector splitInputs; + for (unsigned i = 0; i < N; ++i) { + Value inputAsBF16 = + convertValue(input, rewriter.getBF16Type(), rewriter); + if (i != N - 1) { + Value inputAsF32 = convertValue( + inputAsBF16, rewriter.getF32Type(), rewriter); + input = + arith::SubFOp::create(rewriter, input.getLoc(), input, inputAsF32); + } + splitInputs.push_back(inputAsBF16); + } + return splitInputs; +} + +bool isF32(Value operand) { + return cast(operand.getType()).getElementType().isF32(); +}; + +Value zeroLike(Value c, PatternRewriter &rewriter) { + return SplatOp::create( + rewriter, c.getLoc(), c.getType(), + arith::ConstantOp::create(rewriter, c.getLoc(), + rewriter.getF32FloatAttr(0))); +}; + +Value dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, + InputPrecision precision = InputPrecision::IEEE, + uint32_t maxNumImpreciseAcc = 0) { + return DotOp::create(rewriter, lhs.getLoc(), lhs, rhs, acc, precision, + maxNumImpreciseAcc); +}; + +Value replaceNansWithZeros(Value value, PatternRewriter &rewriter) { + auto nans = arith::CmpFOp::create(rewriter, value.getLoc(), + arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value, rewriter); + return arith::SelectOp::create(rewriter, value.getLoc(), nans, zero, value); +}; + +unsigned getBF16Count(triton::InputPrecision precision) { + switch (precision) { + default: + return 0; + case InputPrecision::BF16x3: + // BF16x3 only needs the first 2 values derived from splitting an F32 + return 2; + case InputPrecision::BF16x6: + return 3; + } +} + +bool isMusaTarget(Operation *op) { + auto module = op ? op->getParentOfType() : nullptr; + auto targetAttr = + module ? module->getAttrOfType(AttrTargetName) : nullptr; + return targetAttr && targetAttr.getValue().starts_with("musa:"); +} + +// Implements 3xBF16 https://arxiv.org/abs/1904.06376 +// See also +// https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152 +// As well as +// https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330 +struct BF16xN : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + // BF16 indices and count + const unsigned hi = 0; + const unsigned mid = 1; + const unsigned lo = 2; + const unsigned N = getBF16Count(dotOp.getInputPrecision()); + + if (!isF32(dotOp.getA()) || !isF32(dotOp.getB()) || !N) + return failure(); + + // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator + const auto lhs_parts = splitF32(dotOp.getA(), N, rewriter); + const auto rhs_parts = splitF32(dotOp.getB(), N, rewriter); + auto result = zeroLike(dotOp.getC(), rewriter); + + switch (dotOp.getInputPrecision()) { + default: + assert(false && "BF16DotTCPass expects BF16x6 or BF16x3"); + return failure(); + + // clang-format off + // NOTE: 9 dots possible; handled like so if not for lack of speedup: + // case InputPrecision::BF16x9: + // result = dot(lhs_parts[lo], rhs_parts[lo], result, rewriter); + // result = dot(lhs_parts[mid], rhs_parts[lo], result, rewriter); + // result = dot(lhs_parts[lo], rhs_parts[mid], result, rewriter); + // clang-format on + + case InputPrecision::BF16x6: + result = dot(lhs_parts[mid], rhs_parts[mid], result, rewriter); + + result = dot(lhs_parts[lo], rhs_parts[hi], result, rewriter); + result = dot(lhs_parts[hi], rhs_parts[lo], result, rewriter); + + case InputPrecision::BF16x3: + result = dot(lhs_parts[mid], rhs_parts[hi], result, rewriter); + result = dot(lhs_parts[hi], rhs_parts[mid], result, rewriter); + result = replaceNansWithZeros(result, rewriter); + + // NOTE: For BF16x1 bail without replaceNansWithZeros + // case InputPrecision::BF16x1: break; + } + + result = dot(lhs_parts[hi], rhs_parts[hi], result, rewriter); + result = + arith::AddFOp::create(rewriter, dotOp.getLoc(), result, dotOp.getC()); + + rewriter.replaceOp(dotOp, result); + return success(); + } +}; + +// nb. We call the trick TF32x3 as C++ disallows variables starting with numbers +// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385 +// For a, b f32 +// dot(a, b, inputPrecision="tf32x3") -> +// let aBig = f32ToTF32(a), aSmall = a - aBig; +// let bBig = f32ToTF32(b), bSmall = b - bBig; +// let small = dot(aSmall, bBig, inputPrecision="tf32") + +// dot(aBig, bSmall, inputPrecision="tf32") +// let masked_nans = replaceNansWithZeros(small) +// let big = dot(aBig, bBig, inputPrecision="tf32") +// return big + masked_nans; +class TF32x3 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 && + isF32(dotOp.getA()) && isF32(dotOp.getB()))) { + return failure(); + } + + // Aux functions + auto f32ToTF32 = [&](Value value) -> Value { + if (isMusaTarget(dotOp)) { + return ExternElementwiseOp::create( + rewriter, dotOp.getLoc(), value.getType(), + ArrayRef{value}, + /*libname=*/"", /*libpath=*/"", + /*symbol=*/"llvm.musa.float.to.tf32.rna", + /*pure=*/true) + .getResult(); + } + return ElementwiseInlineAsmOp::create( + rewriter, dotOp.getLoc(), value.getType(), + "cvt.rna.tf32.f32 $0, $1;", "=r,r", + /*isPure=*/true, /*pack=*/1, ArrayRef{value}) + .getResult()[0]; + }; + auto add = [&](Value a, Value b) -> Value { + return arith::AddFOp::create(rewriter, dotOp.getLoc(), a, b); + }; + auto sub = [&](Value a, Value b) -> Value { + return arith::SubFOp::create(rewriter, dotOp.getLoc(), a, b); + }; + + auto aBig = f32ToTF32(dotOp.getA()); + auto aSmall = sub(dotOp.getA(), aBig); + + auto bBig = f32ToTF32(dotOp.getB()); + auto bSmall = sub(dotOp.getB(), bBig); + + auto zero = zeroLike(dotOp.getC(), rewriter); + + auto dot1 = dot(aSmall, bBig, zero, rewriter, InputPrecision::TF32, + dotOp.getMaxNumImpreciseAcc()); + auto dot2 = dot(aBig, bSmall, dot1, rewriter, InputPrecision::TF32, + dotOp.getMaxNumImpreciseAcc()); + + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, + // we must override any accumulated result if the last partial product is + // non-finite. + auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter); + auto dot3 = dot(aBig, bBig, dot2withZeroedNans, rewriter, + InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); + + auto sum = add(dot3, dotOp.getC()); + + rewriter.replaceOp(dotOp, sum); + return success(); + } +}; + +} // anonymous namespace + +struct F32DotTCPass : public impl::TritonGPUF32DotTCBase { + using impl::TritonGPUF32DotTCBase::TritonGPUF32DotTCBase; + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet decomposePatterns(context); + if (this->emuTF32) { + decomposePatterns.add(context); + } + decomposePatterns.add(context); + if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp new file mode 100644 index 0000000000..96e3752c6e --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -0,0 +1,1222 @@ +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This attribute is set by the front-end to control whether fusion is on. +static constexpr llvm::StringLiteral kFlattenAttr = "tt.flatten"; +// This attribute indicates the inner loop length has been speculated. +static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute"; +// This attribute is just used for testing the pass. +static constexpr llvm::StringLiteral kAlwaysFuseAttrName = "ttg.always-fuse"; + +namespace { +struct FuseNestedLoopsPass + : public impl::TritonGPUFuseNestedLoopsBase { + using TritonGPUFuseNestedLoopsBase::TritonGPUFuseNestedLoopsBase; + + void runOnOperation() override; +}; + +//===----------------------------------------------------------------------===// +// LoopNest +//===----------------------------------------------------------------------===// + +// A node in the loop nest represents a single for loop with a list of +// immediately nested loops. +struct LoopNestNode { + LoopNestNode(scf::ForOp loop) : loop(loop) {} + + // The for loop. + scf::ForOp loop; + // Loops nested immediately below this loop. + SmallVector children; +}; + +// A loop nest is a tree of loops. +struct LoopNest { + LoopNest(scf::ForOp outermost); + + // Print the loop nest. + void print(raw_ostream &os) const; + // Dump the loop nest for debugging. + LLVM_DUMP_METHOD void dump() const; + + // Owner of the memory of the nodes. + SmallVector> nodes; + + // The outermost loop in the nest, which has no preconditions. Even if the + // outermost loop is contained within an if, its preconditions relative to the + // loop nest are empty. + LoopNestNode *root; +}; +} // namespace + +LoopNest::LoopNest(scf::ForOp outermost) + : root( + nodes.emplace_back(std::make_unique(outermost)).get()) { +} + +void LoopNest::print(raw_ostream &os) const { + // Print just the first line of the loop's textual IR. + std::string buffer; + auto printLoopFirstLine = [&](scf::ForOp loop) { + buffer.clear(); + llvm::raw_string_ostream str(buffer); + loop.print(str); + os << buffer.substr(0, buffer.find('\n')); + }; + + os << "LoopNest:\n"; + SmallVector> stack; + stack.emplace_back(root, 0); + while (!stack.empty()) { + auto [node, indent] = stack.pop_back_val(); + + // Print the current loop. + os << std::string(indent * 2, ' '); + printLoopFirstLine(node->loop); + os << "\n"; + + // Push the children of the current loop. + for (LoopNestNode *child : node->children) + stack.emplace_back(child, indent + 1); + } + os << "\n"; +} + +void LoopNest::dump() const { print(llvm::dbgs()); } + +//===----------------------------------------------------------------------===// +// findLoopNests +//===----------------------------------------------------------------------===// + +// Forward declaration. +static void findLoopNests(Operation *container, + SmallVectorImpl &nests); + +// Recursively construct a loop nest. +static void constructLoopNest(LoopNestNode *parent, LoopNest &nest, + SmallVectorImpl &nests) { + parent->loop->walk([&](Operation *op) { + if (op == parent->loop) + return WalkResult::advance(); + + if (auto forOp = dyn_cast(op)) { + auto &child = + nest.nodes.emplace_back(std::make_unique(forOp)); + parent->children.push_back(child.get()); + // Recurse with the current loop nest. + constructLoopNest(child.get(), nest, nests); + return WalkResult::skip(); + } + + // If the traversal encounters any other operation with regions, restart the + // traversal and construct new loop nests. This means ops like `scf.while` + // divide the analysis domain, but it also means loop fusion won't "see" + // across `scf.if`, for example. + // TODO: Handle loop nests with preconditions. The traversal can keep a + // stack of `scf.if` preconditions while constructing the loop nest. + if (op->getNumRegions()) { + findLoopNests(op, nests); + return WalkResult::skip(); + } + + return WalkResult::advance(); + }); +} + +// Find all the loop nests in the operation. The only region operation that +// allows CFG regions is `tt.func`. That means we can just walk starting from +// the function body and can build loop nests directly off the region trees +// contained in the function -- we don't have to worry about CFGs inside the +// nested region trees. +static void findLoopNests(Operation *container, + SmallVectorImpl &nests) { + container->walk([&](scf::ForOp loop) { + LoopNest nest(loop); + constructLoopNest(nest.root, nest, nests); + nests.push_back(std::move(nest)); + return WalkResult::skip(); + }); +} + +//===----------------------------------------------------------------------===// +// Logue +//===----------------------------------------------------------------------===// + +namespace { +// A prologue or epilogue. +struct Logue { + // Move the ops in the logue before the iterator. + void moveBefore(Block *block, Block::iterator it) { + for (Operation *op : ops) + op->moveBefore(block, it); + } + + // Replace all uses of the logue results with the given values, where `logue` + // comprises all the ops in `containingRegion`. + void replaceAllUsesWith(ValueRange values, Region &containingRegion) { + for (auto [newOut, output] : llvm::zip(values, outputs)) { + // Replace uses of the prologue outputs that are not in the prologue, i.e. + // inside the `then` region where it got spliced. + output.replaceUsesWithIf(newOut, [&](OpOperand &use) { + return !containingRegion.isAncestor(use.getOwner()->getParentRegion()); + }); + } + } + + // Get the number of outputs. + unsigned getNumOutputs() const { return outputs.size(); } + // Get the outputs as a `ValueRange`. + ValueRange getOutputs() const { return outputs; } + // Get the types of the outputs. + TypeRange getOutputTypes() const { return getOutputs().getTypes(); } + + // A contiguous range of ops representing the prologue or epilogue. + SmallVector ops; + // The outputs of the logue. These are the SSA value results of `ops` that are + // used by ops outside of `ops`. + SmallVector outputs; +}; +} // namespace + +// Given a range of ops, form it into a logue by finding the outputs. +static Logue createLogueFrom(llvm::iterator_range ops, + mlir::DominanceInfo &domInfo) { + Logue logue; + for (Operation &op : ops) + logue.ops.push_back(&op); + + if (ops.empty()) + return logue; + + // An op result is an output of the logue if the last operation in the logue + // dominates any of its users. + Operation &lastOp = *std::prev(ops.end()); + auto isOutput = [&](OpResult result) { + for (Operation *user : result.getUsers()) { + if (domInfo.properlyDominates(&lastOp, user)) + return true; + } + return false; + }; + + // Find the outputs. + for (Operation &op : ops) { + for (OpResult result : op.getOpResults()) { + if (isOutput(result)) + logue.outputs.push_back(result); + } + } + + return logue; +} + +//===----------------------------------------------------------------------===// +// fuseOneLevel +//===----------------------------------------------------------------------===// + +// Only hoist operations that are side-effect free and "cheap" (i.e. only scalar +// operands). Importantly, we need to be able to hoist code generated by fusing +// children loops into their parents so the algorithm can be applied +// recursively. This includes integer division, which are not speculatable, but +// we know they will never divide by zero. +static bool canHoistLoopBoundComputation(Operation *op) { + auto isScalar = [](Type type) { + return type.isIntOrIndexOrFloat() || isa(type); + }; + return (isMemoryEffectFree(op) || hasSingleEffect(op)) && + llvm::all_of(op->getOperandTypes(), isScalar) && + llvm::all_of(op->getResultTypes(), isScalar); +} + +// Determine if all of `values` are or can be made invariant to the outer loop +// by hoisting operations. `toHoist` is shared across all child loop bounds. +static bool isOuterLoopInvariant(mlir::DominanceInfo &domInfo, scf::ForOp outer, + ArrayRef values, + llvm::SetVector &toHoist) { + return getDominatingValueSetOpsToHoist( + domInfo, outer, values, toHoist, canHoistLoopBoundComputation, + [&](BlockArgument arg) { + return isa(arg.getOwner()->getParentOp()); + }); +} + +static bool canSliceBounds(mlir::DominanceInfo &domInfo, scf::ForOp outer, + ArrayRef values, + llvm::SetVector &ops) { + return getDominatingValueSetOpsToHoist( + domInfo, outer, values, ops, canHoistLoopBoundComputation, + [&](BlockArgument arg) { + return arg == outer.getInductionVar() || + isa(arg.getOwner()->getParentOp()); + }); +} + +// Pessimistically assume the internal storage bitwidth for index types. +static unsigned getIntTypeWidth(Type type) { + if (isa(type)) + return IndexType::kInternalStorageBitWidth; + return cast(type).getWidth(); +} + +// Generate IR to compute the number of iterations of a loop. +static Value computeNumIters(ImplicitLocOpBuilder &b, Value lowerBound, + Value upperBound, Value step) { + // len(range(lb, ub, step)) = ceildiv(ub - lb, step) + // This works even if step is negative. + Value diff = arith::SubIOp::create(b, upperBound, lowerBound); + // Let someone else prove it can be unsigned. + return arith::CeilDivSIOp::create(b, diff, step); +} + +// Generate IR to compute the number of iterations of a loop. +static Value computeNumIters(ImplicitLocOpBuilder &b, scf::ForOp loop) { + return computeNumIters(b, loop.getLowerBound(), loop.getUpperBound(), + loop.getStep()); +} + +// Cast an integer or index value to an integer or index `type`, if necessary. +static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, + Type type) { + if (value.getType() == type) + return value; + if (isa(value.getType()) || isa(type)) + return arith::IndexCastOp::create(b, type, value); + if (cast(value.getType()).getWidth() > + cast(type).getWidth()) + return arith::TruncIOp::create(b, type, value); + return arith::ExtSIOp::create(b, type, value); +} + +// To model an "undef" value, i.e. a value that is known to never be read on +// live code paths, create a zero-valued constant where possible, otherwise use +// a poison value. PTXAS appears to generate better code with zeros compared to +// poison values. +static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) { + Type elTy = getElementTypeOrSelf(type); + if (!elTy.isIntOrIndexOrFloat() || + (!isa(type) && type != elTy)) + return ub::PoisonOp::create(b, type); + + TypedAttr attr = isa(elTy) ? TypedAttr(b.getFloatAttr(elTy, 0)) + : b.getIntegerAttr(elTy, 0); + if (auto tensor = dyn_cast(type)) + attr = SplatElementsAttr::get(tensor, attr); + return arith::ConstantOp::create(b, attr); +} + +static scf::YieldOp getYield(Region &body) { + return cast(body.front().back()); +} + +static scf::IfOp eraseIfResults(ImplicitLocOpBuilder &b, scf::IfOp ifOp, + llvm::BitVector indices, + SmallVector replaceWith) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(ifOp); + while (indices.size() < ifOp.getNumResults()) + indices.push_back(false); + + getYield(ifOp.getThenRegion())->eraseOperands(indices); + getYield(ifOp.getElseRegion())->eraseOperands(indices); + + TypeRange newTypes = getYield(ifOp.getThenRegion()).getOperandTypes(); + auto newIf = scf::IfOp::create(b, newTypes, ifOp.getCondition()); + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + + SmallVector replacements; + auto replIt = replaceWith.begin(); + auto resIt = newIf->result_begin(); + for (unsigned i : llvm::seq(ifOp.getNumResults())) + replacements.push_back(indices[i] ? *replIt++ : *resIt++); + assert(ValueRange(replacements).getTypes() == ifOp.getResultTypes()); + ifOp.replaceAllUsesWith(replacements); + ifOp.erase(); + return newIf; +} + +namespace { +struct InnerLoop { + InnerLoop(scf::ForOp op, llvm::SetVector slicedOps) + : op(op), slicedOps(std::move(slicedOps)) {} + + // Return true if the loop bounds are outer loop invariant. + bool isOuterLoopInvariant() const { return slicedOps.empty(); } + + // The actual loop op. + scf::ForOp op; + // Ops that must be sliced to compute the loop bounds + llvm::SetVector slicedOps; +}; +} // namespace + +// Given a one level loop nest in the form +// +// for i in range(lbi, ubi, stepi): +// prologue0(i) +// for j0 in range(lbj0, ubj0, stepj0): +// body0(i, j0) +// epilogue1(i) +// for j1 in range(lbj1, ubj1, stepj1): +// body1(i, j1) +// epilogue2(i) +// ... +// for jN in range(lbjN, ubjN, stepjN): +// bodyN(i, jN) +// epilogue(i) +// +// Rewrite this into a single loop in the form: +// +// len_i = len(range(lbi, ubi, stepi)) +// len_j0 = len(range(lbj0, ubj0, stepj0)) +// len_j1 = len(range(lbj1, ubj1, stepj1)) +// ... +// len_jN = len(range(lbjN, ubjN, stepjN)) +// inner_len = max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - N +// total_iters = len_i * inner_len +// +// T = 0 +// i = lbi - stepi +// for _ in range(total_iters): +// if T == 0: +// i += stepi +// prologue0(i) +// j0 = lbj0 +// if T >= 0 and T < len_j0: +// body0(i, j0) +// j0 += stepj0 +// +// if T == max(1, len_j0) - 1: +// prologue1(i) +// j1 = lbj1 +// if T >= max(1, len_j0) - 1 +// and T < max(1, len_j0) - 1 + len_j1: +// body1(i, j1) +// j1 += stepj1 +// +// if T == max(1, len_j0) + max(1, len_j1) - 2: +// prologue2(i) +// j2 = lbj2 +// if T >= max(1, len_j0) + max(1, len_j1) - 2 +// and T < max(1, len_j0) + max(1, len_j1) - 2 + len_j2: +// body2(i, j2) +// j2 += stepj2 +// +// ... +// +// if T == max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN-1) - N: +// prologueN(i) +// jN = lbjN +// if T >= max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN-1) - N +// and T < max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN-1) - N + +// len_jN: +// bodyN(i, jN) +// jN += stepjN +// +// if T == max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - (N + 1): +// epilogue(i) +// T = 0 if T == (inner_len - 1) else T + 1 +// +// This routine can be applied recursively on a loop nest tree, leaf-to-root, to +// flatten the loop nest into a single loop. However, this routine only fuses +// child loops whose loop bounds are invariant to the parent loop. For child +// loops where this is not the case, the function will ignore them. +// +// We could fuse loops with parent-loop-variant or even data-dependent bounds, +// but this will require generating `scf.while` in a form that is not friendly +// to the pipeliner. In order to effectively fuse and pipeline these kinds of +// loop nests, loop nest fusion and the pipeliner need to share a higher-level +// representation (or perhaps be the same pass). +// +// Note that there are many potential forms of the fused loop. This routine will +// attempt to minimize the number of fused loop iterations by overlapping the +// iteration spaces of the child loops and the epilogues. E.g. the last +// iteration of bodyjK will execute on the same fused loop iteration as +// epilogueK and the first iteration of bodyj(K+1). Hence the `- N` term in the +// total number of iterations. +// +// What the above Python-pseudo-code glosses over is SSA dependency management. +// To interpret the pseudocode as SSA IR, just imagine everything is put back +// into allocas and SSA formation re-runs after fusion, which one should note +// will introduce undefs. +// +// Handling dependencies will require turning implicit captures into +// loop-carried dependencies. Consider: +// +// scf.for %i = %lbi to %ubi step %stepi { +// %a = tt.call @func(%i) +// scf.for %j = %lbj to %ubj step %stepj { +// %b = tt.call @use(%a, %j) +// } +// } +// +// This needs to be rewritten into: +// +// %poison = ub.poison +// %Tlast, %ilast, %jlast, %alast = scf.for %unused = ... +// iter_args(%Tprev = %c-1_i32, +// %iprev = %lbi - %stepi, +// %jprev = %poison, +// %aprev = %poison) -> (i32, i32, i32, i32) { +// %T = (%Tprev + 1) mod (...) +// %a, %i, %j = scf.if %T == 0 { +// %inext = %iprev + 1 +// %jnext = %lbj - %stepj +// +// %anext = tt.call @func(%i) +// yield %inext, %jnext, %anext +// } else { +// yield %iprev, %jprev, %aprev +// } +// +// scf.if %T >= 0 and %T < ... { +// tt.call @use(%a, %j) +// } +// +// Note: the induction variables will be initialized to their lower bound to +// avoid underflow in lbjk - stepjk, with the exception of the outer loop +// induction variable, which needs to be incremented inside the prologue to +// avoid a dependency on the epilogue. This helps the scheduler behave. +// +// Any inputs and outputs of the loop bodies would also need to be handled +// similarly: initialized as undef if appropriate and carried through the fused +// loop. This is why fusion will increase liveranges. To minimize the number of +// additional loop-carried values, the routine will analyze the subblock of IR +// inside each `prologueK` and determine its "outputs" as intermediate SSA +// values that are used later in the loop nest. +static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { + scf::ForOp outer = parent->loop; + + SmallVector innerLoops; + llvm::SetVector toHoist; + for (LoopNestNode *child : parent->children) { + scf::ForOp inner = child->loop; + assert(child->children.empty() && "fuseOneLevel runs leaf-to-root"); + + // Check if the inner loop bounds are or can be made invariant to the outer + // loop. Check them all at once to avoid adding ops to `toHoist` if not + // necessary. + if (isOuterLoopInvariant( + domInfo, outer, + {inner.getLowerBound(), inner.getUpperBound(), inner.getStep()}, + toHoist)) { + // Add this child to the list of loops to fuse. + innerLoops.push_back({child->loop, {}}); + continue; + } + + // Check if the loop bounds can be sliced. + llvm::SetVector slicedOps; + if (canSliceBounds( + domInfo, outer, + {inner.getLowerBound(), inner.getUpperBound(), inner.getStep()}, + slicedOps)) { + innerLoops.push_back({child->loop, std::move(slicedOps)}); + continue; + } + } + + // From the perspective of the overall analysis, we can delete all the + // children of the current loop node. Child loops that cannot be fused are now + // treated opaquely by the rest of the analysis. This allows partial fusing of + // the constructed loop nest. + parent->children.clear(); + + // If there are no child loops to fuse, then there is nothing to do. + if (innerLoops.empty()) + return; + + // The transformation will definitely succeed on `childrenToFuse`. `toHoist` + // only contains the operations that must be hoisted for `childrenToFuse` to + // be fusible. + hoistOpsBefore(outer, toHoist); + + // Determine the integer type to use for the length computations. Use an + // integer bitwidth twice the size of the largest integer, up to 64 bits, to + // avoid overflow. + unsigned intTyWidth = getIntTypeWidth(outer.getInductionVar().getType()); + + // Generate the computations of the fused loop bounds. + Location loc = outer.getLoc(); + ImplicitLocOpBuilder b(loc, outer); + for (InnerLoop &loop : innerLoops) { + intTyWidth = std::max(intTyWidth, + getIntTypeWidth(loop.op.getInductionVar().getType())); + } + auto intTy = b.getIntegerType(intTyWidth); + bool allInvariant = llvm::all_of( + innerLoops, [](InnerLoop &loop) { return loop.isOuterLoopInvariant(); }); + + Value lenOuter = computeNumIters(b, outer); + SmallVector lenInners; + for (InnerLoop &loop : innerLoops) { + // len_jk = len(range(lbjk, ubjk, stepjk)) + Value lenInner; + if (loop.isOuterLoopInvariant()) + lenInner = castIntIfNecessary(b, computeNumIters(b, loop.op), intTy); + else + lenInner = createPoisonOrZero(b, intTy); + lenInners.push_back(lenInner); + } + + auto intTyCst = [&](int64_t v) { + return arith::ConstantOp::create(b, IntegerAttr::get(intTy, v)); + }; + + // inner_len = max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - N + unsigned N = innerLoops.size() - 1; + Value innerLen = intTyCst(0); + for (auto [loop, lenInner] : llvm::zip(innerLoops, lenInners)) { + if (!loop.isOuterLoopInvariant()) + continue; + innerLen = arith::AddIOp::create( + b, innerLen, arith::MaxSIOp::create(b, intTyCst(1), lenInner)); + } + innerLen = arith::SubIOp::create(b, innerLen, intTyCst(N)); + + // total_iters = len_i * inner_len + Value totalIters = arith::MulIOp::create( + b, castIntIfNecessary(b, lenOuter, intTy), innerLen); + + // Generate a loop to compute the total number of iterations for inner loops + // whose bounds are not outer loop invariant. + IRMapping mapping; + auto peeledLen = + scf::ForOp::create(b, outer.getLowerBound(), outer.getUpperBound(), + outer.getStep(), {totalIters}); + totalIters = peeledLen.getRegionIterArg(0); + mapping.map(outer.getInductionVar(), peeledLen.getInductionVar()); + b.setInsertionPointToStart(peeledLen.getBody()); + for (InnerLoop &loop : innerLoops) { + if (loop.isOuterLoopInvariant()) + continue; + // Cloned the sliced ops into the peeled loop. + for (Operation *op : topologicalSort(loop.slicedOps)) { + if (!mapping.contains(op)) + b.clone(*op, mapping); + } + Value numIters = + computeNumIters(b, mapping.lookupOrDefault(loop.op.getLowerBound()), + mapping.lookupOrDefault(loop.op.getUpperBound()), + mapping.lookupOrDefault(loop.op.getStep())); + numIters = castIntIfNecessary(b, numIters, intTy); + // Accumulate into the total number of iterations. + numIters = arith::MaxSIOp::create(b, intTyCst(1), numIters); + totalIters = arith::AddIOp::create(b, totalIters, numIters); + } + scf::YieldOp::create(b, totalIters); + totalIters = peeledLen.getResults().front(); + b.setInsertionPointAfter(peeledLen); + + // The outputs of the prologue, each epilogue, and all inner loop bodies need + // to carried through the fused loop. + SmallVector logues; + auto addLogue = [&](Block::iterator begin, Block::iterator end) { + logues.push_back(createLogueFrom({begin, end}, domInfo)); + }; + // prologue0 + addLogue(outer.getBody()->begin(), innerLoops.front().op->getIterator()); + // prologuek where 0 < k <= N + for (auto i : llvm::seq(0, innerLoops.size() - 1)) { + addLogue(std::next(innerLoops[i].op->getIterator()), + innerLoops[i + 1].op->getIterator()); + } + // epilogue + addLogue(std::next(innerLoops.back().op->getIterator()), + // Don't include the outer loop yield. + std::prev(outer.getBody()->end())); + + // We need iter args for: + // - The fused loop induction var + // - The outer loop induction var + // - The outer loop iter args + // - The induction vars for each inner loop + // - The outputs of each child loop + // - The outputs of each logue + SmallVector fusedInits; + + // T = 0 + fusedInits.push_back(intTyCst(0)); + // i = lbi - stepi + fusedInits.push_back( + arith::SubIOp::create(b, outer.getLowerBound(), outer.getStep())); + + unsigned outerArgsStartIdx = fusedInits.size(); + llvm::append_range(fusedInits, outer.getInits()); + unsigned lenInnersStartIdx = fusedInits.size(); + llvm::append_range(fusedInits, lenInners); + unsigned innerLenStartIdx = fusedInits.size(); + fusedInits.push_back(innerLen); + + // Everything else is initialized to undef. + unsigned ivarStartIdx = fusedInits.size(); + for (InnerLoop &loop : innerLoops) { + fusedInits.push_back( + createPoisonOrZero(b, loop.op.getInductionVar().getType())); + } + unsigned innerOutsStartIdx = fusedInits.size(); + for (InnerLoop &loop : innerLoops) { + for (Type resultType : loop.op.getResultTypes()) + fusedInits.push_back(createPoisonOrZero(b, resultType)); + } + unsigned logueOutsStartIdx = fusedInits.size(); + for (Logue &logue : llvm::drop_end(logues)) { + for (Type outputType : logue.getOutputTypes()) + fusedInits.push_back(createPoisonOrZero(b, outputType)); + } + + // for _ in range(total_iters): + auto fused = + scf::ForOp::create(b, intTyCst(0), totalIters, intTyCst(1), fusedInits); + // Replace the outer loop args with the args in the fused loop args. + for (auto [arg, fusedArg] : + llvm::zip(outer.getRegionIterArgs(), + fused.getRegionIterArgs().slice(outerArgsStartIdx))) { + arg.replaceAllUsesWith(fusedArg); + } + ValueRange lenInnersRange = + fused.getRegionIterArgs().slice(lenInnersStartIdx, lenInners.size()); + for (auto [lenInner, lenInnerArg] : llvm::zip(lenInners, lenInnersRange)) + lenInner = lenInnerArg; + b.setInsertionPointToStart(fused.getBody()); + + Value T = fused.getRegionIterArg(0); + // `i` is computed inside the first prologue. + Value curI = fused.getRegionIterArg(1); + Value i; + + auto lenInnersIt = + ValueRange(fused.getRegionIterArgs()).begin() + lenInnersStartIdx; + + ArrayRef ivars = fused.getRegionIterArgs().slice(ivarStartIdx); + auto bodyOutsIt = + ValueRange(fused.getRegionIterArgs()).begin() + innerOutsStartIdx; + auto logueOutsIt = + ValueRange(fused.getRegionIterArgs()).begin() + logueOutsStartIdx; + SmallVector prologueIfs, bodyIfs; + for (unsigned k = 0; k <= N; ++k) { + // if T == max(1, len_j0) + ... max(1, len_jk-1) - k + // [[if k == 0]] i += stepi + // prologuek(i) + // jk = lbjk + Value innerStartT = intTyCst(0); + for (unsigned i = 0; i < k; ++i) { + innerStartT = arith::AddIOp::create( + b, innerStartT, arith::MaxSIOp::create(b, intTyCst(1), lenInners[i])); + } + innerStartT = arith::SubIOp::create(b, innerStartT, intTyCst(k)); + Value prologueCond = + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, T, innerStartT); + + // The `scf.if` outputs will be `jk` and the outputs of prologuek. We also + // have to initialize the inner loop iter args. + scf::ForOp inner = innerLoops[k].op; + Logue &prologue = logues[k]; + + SmallVector prologueOutTypes{inner.getInductionVar().getType()}; + llvm::append_range(prologueOutTypes, prologue.getOutputTypes()); + llvm::append_range(prologueOutTypes, inner.getInits().getTypes()); + if (k == 0) { + prologueOutTypes.push_back(curI.getType()); + prologueOutTypes.append(innerLoops.size(), intTy); + prologueOutTypes.push_back(innerLen.getType()); + } + auto prologueIf = scf::IfOp::create(b, prologueOutTypes, prologueCond); + prologueIfs.push_back(prologueIf); + + // Splice prologuek into the `then` region. + Block *thenBlock = b.createBlock(&prologueIf.getThenRegion()); + prologue.moveBefore(thenBlock, thenBlock->end()); + + if (k == 0) { + // Increment `i` and replace its uses inside the prologue. + b.setInsertionPointToStart(thenBlock); + i = arith::AddIOp::create(b, curI, outer.getStep()); + mlir::replaceAllUsesInRegionWith(outer.getInductionVar(), i, + prologueIf.getThenRegion()); + + // Compute the variant inner loop lengths. + IRMapping mapping; + for (auto [loop, lenInner] : llvm::zip(innerLoops, lenInners)) { + if (loop.isOuterLoopInvariant()) + continue; + for (Operation *op : topologicalSort(loop.slicedOps)) { + if (!mapping.contains(op)) + b.clone(*op, mapping); + } + lenInner = + computeNumIters(b, mapping.lookupOrDefault(loop.op.getLowerBound()), + mapping.lookupOrDefault(loop.op.getUpperBound()), + mapping.lookupOrDefault(loop.op.getStep())); + lenInner = castIntIfNecessary(b, lenInner, intTy); + innerLen = arith::AddIOp::create( + b, innerLen, arith::MaxSIOp::create(b, intTyCst(1), lenInner)); + } + } + + // Yield the initialized jk, the prologue outputs, and the initial values of + // the inner loop. + b.setInsertionPointToEnd(thenBlock); + SmallVector thenOuts{inner.getLowerBound()}; + llvm::append_range(thenOuts, prologue.getOutputs()); + llvm::append_range(thenOuts, inner.getInits()); + if (k == 0) { + thenOuts.push_back(i); + llvm::append_range(thenOuts, lenInners); + thenOuts.push_back(innerLen); + } + scf::YieldOp::create(b, thenOuts); + + // In the `else` region, just yield the last values of jk, the outputs, and + // the iter args. + b.createBlock(&prologueIf.getElseRegion()); + Value lastJk = ivars[k]; + unsigned numOuts = prologue.getNumOutputs(); + SmallVector elseOuts{lastJk}; + elseOuts.append(logueOutsIt, logueOutsIt + numOuts); + elseOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); + if (k == 0) { + elseOuts.push_back(curI); + llvm::append_range(elseOuts, lenInnersRange); + // Peephole the passthrough of `innerLen` since MLIR will not optimize it + // away for us. + elseOuts.push_back( + allInvariant ? innerLen : fused.getRegionIterArg(innerLenStartIdx)); + } + logueOutsIt += numOuts; + scf::YieldOp::create(b, elseOuts); + + // The results of the `scf.if` become the values of jk and the prologue + // outputs for the rest of the fused loop. + Value jk = prologueIf.getResult(0); + ValueRange prologueOuts = prologueIf.getResults().slice(1, numOuts); + ValueRange prologueInits = + prologueIf.getResults().slice(1 + numOuts, inner.getNumResults()); + inner.getInductionVar().replaceAllUsesWith(jk); + prologue.replaceAllUsesWith(prologueOuts, prologueIf.getThenRegion()); + for (auto [init, iterArg] : + llvm::zip(prologueInits, inner.getRegionIterArgs())) + iterArg.replaceAllUsesWith(init); + // Replace uses of `i` elsewhere with the prologue result. + if (k == 0) { + ValueRange results = prologueIf.getResults(); + i = results.drop_back(1 + lenInners.size()).back(); + lenInners = results.drop_back().take_back(lenInners.size()); + innerLen = results.back(); + outer.getInductionVar().replaceAllUsesWith(i); + } + + // if T >= max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k + // and T < max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k + + // len_jk + // bodyk(i, jk) + // jk += stepjk + b.setInsertionPointAfter(prologueIf); + Value innerEndT = arith::AddIOp::create( + b, innerStartT, castIntIfNecessary(b, lenInners[k], intTy)); + Value ge = + arith::CmpIOp::create(b, arith::CmpIPredicate::sge, T, innerStartT); + Value lt = + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, T, innerEndT); + Value bodyCond = arith::AndIOp::create(b, ge, lt); + + // The outputs will be the outputs of the inner loop body and the next jk. + SmallVector bodyOutTypes{jk.getType()}; + llvm::append_range(bodyOutTypes, inner->getResultTypes()); + auto bodyIf = scf::IfOp::create(b, bodyOutTypes, bodyCond); + bodyIfs.push_back(bodyIf); + + // Splice bodyk into the `then` region. + inner.getBody()->eraseArguments([](Value arg) { return true; }); + bodyIf.getThenRegion().takeBody(inner.getBodyRegion()); + auto yield = getYield(bodyIf.getThenRegion()); + b.setInsertionPoint(yield); + Value nextJk = arith::AddIOp::create(b, jk, inner.getStep()); + yield->insertOperands(0, nextJk); + + // The `else` region just forwards the values. + b.createBlock(&bodyIf.getElseRegion()); + SmallVector bodyForwardedOuts{jk}; + bodyForwardedOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); + bodyOutsIt += inner->getNumResults(); + scf::YieldOp::create(b, bodyForwardedOuts); + + // Now we can replace the results of the inner loop with the outputs of the + // body if. + inner.replaceAllUsesWith( + bodyIf.getResults().slice(1, inner.getNumResults())); + + // If the inner loop must execute, then its body does not have to be wrapped + // in a conditional. + if (inner->hasAttr(kMustExecuteAttrName)) { + b.setInsertionPoint(bodyIf); + bodyIf.getConditionMutable().assign( + arith::ConstantOp::create(b, b.getBoolAttr(true))); + } + + // Move the insertion point for the next iteration. + b.setInsertionPointAfter(bodyIf); + } + + // if T == len_j0 + len_j1 + ... + len_jN - N - 1: + // epilogue(i) + Logue &epilogue = logues.back(); + + // The only possible use of an epilogue output is the yield. + auto outerYield = cast(outer.getBody()->getTerminator()); + SmallVector usedIterArgs; + for (Value output : epilogue.getOutputs()) { + for (OpOperand &use : output.getUses()) { + if (use.getOwner() == outerYield) { + usedIterArgs.push_back(fused.getRegionIterArgs().drop_front( + outerArgsStartIdx)[use.getOperandNumber()]); + } + } + } + + auto epilogueCond = + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, T, + arith::SubIOp::create(b, innerLen, intTyCst(1))); + auto epilogueIf = + scf::IfOp::create(b, epilogue.getOutputTypes(), epilogueCond); + + Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion()); + epilogue.moveBefore(thenBlock, thenBlock->end()); + + b.setInsertionPointToEnd(thenBlock); + scf::YieldOp::create(b, epilogue.getOutputs()); + b.createBlock(&epilogueIf.getElseRegion()); + scf::YieldOp::create(b, usedIterArgs); + epilogue.replaceAllUsesWith(epilogueIf.getResults(), + epilogueIf.getThenRegion()); + + // T = 0 if T == (inner_len - 1) else T + 1 + b.setInsertionPointToEnd(fused.getBody()); + Value nextT = arith::AddIOp::create(b, T, intTyCst(1)); + Value rollover = + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, T, + arith::SubIOp::create(b, innerLen, intTyCst(1))); + T = arith::SelectOp::create(b, rollover, intTyCst(0), nextT); + + // Finally, create the yield of the fused loop. + SmallVector outerOuts{T, i}; + llvm::append_range(outerOuts, outerYield.getOperands()); + llvm::append_range(outerOuts, lenInners); + outerOuts.push_back(innerLen); + for (scf::IfOp bodyIf : bodyIfs) + outerOuts.push_back(/*jk=*/bodyIf.getResult(0)); + for (auto [bodyIf, loop] : llvm::zip(bodyIfs, innerLoops)) { + llvm::append_range(outerOuts, + bodyIf.getResults().slice(1, loop.op.getNumResults())); + } + for (auto [logueIf, logue] : llvm::zip(prologueIfs, llvm::drop_end(logues))) { + llvm::append_range(outerOuts, + logueIf.getResults().slice(1, logue.getNumOutputs())); + } + + scf::YieldOp::create(b, outerOuts); + outer.replaceAllUsesWith( + fused.getResults().slice(outerArgsStartIdx, outer.getNumResults())); + + // Reduce dependencies across inner loops by hoisting the initialization of + // inner loop iter args to the outer loop when possible, and then placing the + // reset of these values in the epilogue. + auto fusedInitsIt = fused.getInitsMutable().begin() + innerOutsStartIdx; + auto fusedArgsIt = fused.getRegionIterArgs().begin() + innerOutsStartIdx; + auto fusedYieldIt = getYield(fused.getBodyRegion())->getOpOperands().begin() + + innerOutsStartIdx; + SmallVector yieldsToUpdate; + SmallVector reset, forwarded; + for (auto [loop, ifOp, bodyIf, prologue] : + llvm::zip(innerLoops, prologueIfs, bodyIfs, logues)) { + unsigned numResults = loop.op.getNumResults(); + unsigned prologueSkip = 1 + prologue.getNumOutputs(); + + llvm::BitVector removeIndices(prologueSkip + numResults); + SmallVector replaceWith; + for (auto [i, init] : llvm::enumerate(loop.op.getInits())) { + if (init.getParentRegion() == &fused.getBodyRegion()) + continue; + // Initialize this in the outer loop. + fusedInitsIt[i].assign(init); + replaceWith.push_back(fusedArgsIt[i]); + removeIndices.set(prologueSkip + i); + yieldsToUpdate.push_back(&fusedYieldIt[i]); + forwarded.push_back(bodyIf.getResult(1 + i)); + reset.push_back(init); + } + // Remove the initializers in the corresponding prologue. + eraseIfResults(b, ifOp, removeIndices, replaceWith); + + fusedInitsIt += numResults; + fusedArgsIt += numResults; + fusedYieldIt += numResults; + } + if (!yieldsToUpdate.empty()) { + MutableOperandRange(getYield(epilogueIf.getThenRegion())).append(reset); + MutableOperandRange(getYield(epilogueIf.getElseRegion())).append(forwarded); + b.setInsertionPoint(epilogueIf); + TypeRange newTypes = getYield(epilogueIf.getThenRegion()).getOperandTypes(); + auto newIf = scf::IfOp::create(b, newTypes, epilogueIf.getCondition()); + newIf.getThenRegion().takeBody(epilogueIf.getThenRegion()); + newIf.getElseRegion().takeBody(epilogueIf.getElseRegion()); + epilogueIf.replaceAllUsesWith( + newIf.getResults().take_front(epilogueIf.getNumResults())); + ResultRange newResults = + newIf.getResults().drop_front(epilogueIf.getNumResults()); + for (auto [i, yieldOperand] : llvm::enumerate(yieldsToUpdate)) + yieldOperand->set(newResults[i]); + epilogueIf.erase(); + } + + // Propagate warp specialization flags. + if (outer->hasAttr(kWarpSpecializeAttrName) || + llvm::any_of(innerLoops, [](InnerLoop &loop) { + return loop.op->hasAttr(kWarpSpecializeAttrName); + })) + fused->setAttr(kWarpSpecializeAttrName, b.getUnitAttr()); + + // Propagate the `tt.disallow_acc_multi_buffer` attribute to the parent loop. + bool disallowAccMultiBuffer = getDisallowAccMultiBuffer(outer); + for (InnerLoop &loop : innerLoops) { + disallowAccMultiBuffer |= getDisallowAccMultiBuffer(loop.op); + } + if (disallowAccMultiBuffer) + fused->setAttr(kDisallowAccMultiBufferAttrName, b.getUnitAttr()); + + // Update the parent's loop to the fused loop. Set the new stage count to the + // max stage count of the inner loops. + int numStages = 1; + if (auto stageAttr = outer->getAttrOfType(kNumStagesAttrName)) + numStages = stageAttr.getInt(); + for (InnerLoop &loop : innerLoops) { + if (auto stageAttr = + loop.op->getAttrOfType(kNumStagesAttrName)) + numStages = std::max(numStages, stageAttr.getInt()); + loop.op.erase(); + } + outer.erase(); + parent->loop = fused; + if (numStages > 1) + fused->setAttr(kNumStagesAttrName, b.getI32IntegerAttr(numStages)); +} + +//===----------------------------------------------------------------------===// +// flattenLoopNest +//===----------------------------------------------------------------------===// + +// Completely flatten a loop nest by recursively fusing loops in a post-order +// traversal with `fuseOneLevel`. +static void flattenLoopNest(LoopNestNode *node, mlir::DominanceInfo &domInfo) { + for (LoopNestNode *child : node->children) + flattenLoopNest(child, domInfo); + fuseOneLevel(node, domInfo); +} + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +// Fuse simple loop nests with a single outer and inner loop, and where the +// inner loop has a `tt.dot` operation. +static bool shouldFuse(const LoopNest &nest) { + if (nest.root->loop->hasAttr(kAlwaysFuseAttrName)) + return true; + + // Only fuse simple loop nests. + return nest.nodes.size() == 2 && nest.root->children.size() == 1 && + nest.root->loop->hasAttr(kFlattenAttr); +} + +// This function identifies a subgraph of cheap ops that can be sunk between two +// regions in the loop nest and moves them, reducing their liveranges. +static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore, + llvm::iterator_range prologue, + function_ref inSinkRegion) { + llvm::SetVector sunkOps; + auto canBeSunk = [&](Operation &op) -> std::pair { + if (!isPure(&op) || isa(op)) + return {false, false}; + // An op can be sunk if all its users are inside the inner loop or are + // marked for sinking. + bool isRoot = true; + for (Operation *user : op.getUsers()) { + if (inSinkRegion(user)) + continue; + isRoot = false; + if (sunkOps.contains(user)) + continue; + return {false, false}; + } + return {true, isRoot}; + }; + + // Find the subgraph of operations that can be sunk. + SmallVector roots; + for (Operation &op : llvm::reverse(prologue)) { + auto [canSink, isRoot] = canBeSunk(op); + if (canSink) + sunkOps.insert(&op); + if (isRoot) + roots.push_back(&op); + } + if (sunkOps.empty()) + return; + + hoistOpsBefore(sinkBlock, sinkBefore, sunkOps); +} + +// Sink ops from the prologue into the epilogue when possible. +static void optimizeEpilogueDependencies(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { + auto inEpilogue = [&](Operation *op) { + return domInfo.properlyDominates(innerLoop, op, /*enclosingOpOk=*/false); + }; + Region &limit = outerLoop.getBodyRegion(); + sinkOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()), + {outerLoop.getBody()->begin(), innerLoop->getIterator()}, inEpilogue); +} + +// Crudely match llvm.assume(ub > lb) or llvm.assume(lb < ub). +static LogicalResult matchPositiveTripCount(scf::ForOp loop) { + for (Operation *user : loop.getUpperBound().getUsers()) { + if (auto cmp = dyn_cast(user)) { + if (llvm::none_of(cmp->getUsers(), + [](Operation *op) { return isa(op); })) + continue; + if (cmp.getPredicate() == (loop.getUnsignedCmp() + ? arith::CmpIPredicate::ugt + : arith::CmpIPredicate::sgt) && + cmp.getLhs() == loop.getUpperBound() && + cmp.getRhs() == loop.getLowerBound()) + return success(); + if (cmp.getPredicate() == (loop.getUnsignedCmp() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt) && + cmp.getLhs() == loop.getLowerBound() && + cmp.getRhs() == loop.getUpperBound()) + return success(); + } + } + return failure(); +} + +// Speculate the length of the inner loop such that the loop is known to execute +// at least once. This way, the inner loop body does not have to be placed +// inside a conditional in the fused loop, which interacts better with the +// pipeliner. +static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { + Location loc = innerLoop.getLoc(); + ImplicitLocOpBuilder b(loc, outerLoop); + + // Check if the inner loop is known to execute at least once. + if (succeeded(matchPositiveTripCount(innerLoop))) { + innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr()); + return success(); + } + + // The inner loop bounds must be outer-loop invariant to speculate from + // outside the loop nest. + llvm::SetVector toHoist; + if (!isOuterLoopInvariant(domInfo, outerLoop, + {innerLoop.getLowerBound(), + innerLoop.getUpperBound(), innerLoop.getStep()}, + toHoist)) + return failure(); + + // Hoist the inner loop bounds computations if necessary. + hoistOpsBefore(outerLoop, toHoist); + + // Mark the inner loop. + innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr()); + + // Speculate on whether the length of the inner loop is zero. + Value lenInner = computeNumIters(b, innerLoop); + auto zeroAttr = IntegerAttr::get(lenInner.getType(), 0); + Value innerLoopEmpty = + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, lenInner, + arith::ConstantOp::create(b, zeroAttr)); + auto ifOp = scf::IfOp::create(b, outerLoop.getResultTypes(), innerLoopEmpty); + + // In the `then` branch, the inner loop does not execute. Clone the loop nest + // into it and remove the inner loop. + mlir::IRMapping map; + b.createBlock(&ifOp.getThenRegion()); + auto newLoop = cast(b.clone(*outerLoop, map)); + scf::YieldOp::create(b, newLoop.getResults()); + auto newInnerLoop = cast(map.lookup(innerLoop)); + newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits()); + newInnerLoop.erase(); + + // Clear up the warp specialization attributes for the specialized loop. + newLoop->removeAttr(kWarpSpecializeAttrName); + + // Move the loop nest into the `else` branch. + outerLoop.replaceAllUsesWith(ifOp.getResults()); + Block *block = b.createBlock(&ifOp.getElseRegion()); + outerLoop->remove(); + b.insert(outerLoop); + scf::YieldOp::create(b, outerLoop.getResults()); + + return success(); +} + +static LogicalResult preprocessLoopNest(const LoopNest &nest, + mlir::DominanceInfo &domInfo) { + assert(nest.nodes.size() == 2 && nest.root->children.size() == 1); + + scf::ForOp &outerLoop = nest.root->loop; + scf::ForOp &innerLoop = nest.root->children.front()->loop; + + moveLoopInvariantCode(outerLoop); + optimizeEpilogueDependencies(outerLoop, innerLoop, domInfo); + return speculateInnerLoopLength(outerLoop, innerLoop, domInfo); +} + +void FuseNestedLoopsPass::runOnOperation() { + auto &domInfo = getAnalysis(); + + for (auto func : getOperation().getOps()) { + SmallVector nests; + findLoopNests(func, nests); + for (LoopNest &nest : nests) { + if (!shouldFuse(nest)) + continue; + if (!nest.root->loop->hasAttr(kAlwaysFuseAttrName) && + failed(preprocessLoopNest(nest, domInfo))) + continue; + flattenLoopNest(nest.root, domInfo); + } + } +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp new file mode 100644 index 0000000000..86e5e2e774 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp @@ -0,0 +1,586 @@ +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUHOISTTMEMALLOC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +// This CRTP class is an operation type constraint that checks that it has TMEM +// dependency tokens present. HoistTMEMAlloc requires that TMEM tokens are +// present to check aliasing for its transformations. +template struct HasToken : public OpT { + using OpT::OpT; + + static bool classof(Operation *op) { + if (auto tmemOp = dyn_cast(op)) + return !!tmemOp.getToken(); + return false; + } +}; + +using TMEMTokenLoadOp = HasToken; +using TMEMTokenStoreOp = HasToken; +using TMEMTokenAllocOp = HasToken; + +class CombineTMEMStoreAndSelect : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMStoreOp store, + PatternRewriter &rewriter) const override { + if (!store.getDep()) + return failure(); + Value src = store.getSrc(); + auto select = src.getDefiningOp(); + if (!select) { + return failure(); + } + enum { kTrue, kFalse, kUnknown } valueFromTMEM = kUnknown; + Value trueSrc = select.getTrueValue(); + Value falseSrc = select.getFalseValue(); + if (auto load = trueSrc.getDefiningOp()) { + if (store.getDst() == load.getSrc() && load.getToken() == store.getDep()) + valueFromTMEM = kTrue; + } + if (auto load = falseSrc.getDefiningOp()) { + if (store.getDst() == load.getSrc() && load.getToken() == store.getDep()) + valueFromTMEM = valueFromTMEM == kTrue ? kUnknown : kFalse; + } + if (valueFromTMEM == kUnknown) { + return failure(); + } + Value pred = select.getCondition(); + // In case the false operand is overwriting, we need to negate the predicate + // (owerwrite when select would be false) + if (valueFromTMEM == kTrue) { + Value one = arith::ConstantIntOp::create(rewriter, select.getLoc(), 1, 1); + pred = arith::XOrIOp::create(rewriter, select.getLoc(), pred, one); + } + // Store the selected value with the updated predicate + Value overwritingValue = valueFromTMEM == kTrue ? falseSrc : trueSrc; + rewriter.replaceOpWithNewOp( + store, rewriter.getType(), store.getDst(), + store.getDep(), overwritingValue, pred); + return success(); + } +}; + +class RemoveUnusedTMEMLoad : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMLoadOp load, + PatternRewriter &rewriter) const override { + if (!load.getDep()) + return failure(); + if (!load.getResult().use_empty()) + return failure(); + rewriter.replaceAllUsesWith(load.getToken(), load.getDep()); + return success(); + } +}; + +// Load-store forwarding pattern. +class CombineTMEMLoadAndStore : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMStoreOp store, + PatternRewriter &rewriter) const override { + if (!store.getDep()) + return failure(); + auto load = store.getDep().getDefiningOp>(); + if (!load || load.getResult() != store.getSrc() || + load.getSrc() != store.getDst()) + return failure(); + rewriter.replaceOp(store, load.getToken()); + return success(); + } +}; + +class SinkTMEMLoad : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMLoadOp load, + PatternRewriter &rewriter) const override { + if (!load.getDep()) + return failure(); + auto forOp = load->getParentOfType(); + if (!forOp) { + return failure(); + } + DominanceInfo domInfo(forOp); + Operation *domOp = findNearestCommonDominator( + llvm::to_vector(load.getResult().getUsers()), domInfo); + if (!domOp || !domInfo.properlyDominates(load.getOperation(), domOp)) { + return failure(); + } + // Don't sink past potentially aliasing ops. + PostDominanceInfo postDomInfo(forOp); + SmallVector uses; + for (OpOperand &use : load.getToken().getUses()) + uses.push_back(&use); + if (!llvm::all_of(uses, [&](OpOperand *use) { + return postDomInfo.properlyPostDominates(use->getOwner(), domOp); + })) + return failure(); + // In order to not re-ordering multiple tmem load in a loop, don't sink if + // all the ops between the load and the domOp are tmem loads. + Operation *nextNode = load->getNextNode(); + while (auto tmemLoad = dyn_cast(nextNode)) { + nextNode = tmemLoad->getNextNode(); + } + if (domOp == nextNode) { + // The load wasn't moved. + return failure(); + } + rewriter.moveOpBefore(load, domOp); + Value newToken = sinkValueRedefinition(rewriter, load.getDep(), + load.getToken(), domOp->getBlock()); + if (newToken != load.getToken()) { + for (OpOperand *use : uses) + use->set(newToken); + } + return success(); + } +}; + +// Combine back TMEM alloc and store. This is equivalent but gives us a more +// canonical form to do further optimizations. +class CombineTMEMStoreAndAlloc : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMStoreOp store, + PatternRewriter &rewriter) const override { + if (!store.getDep()) + return failure(); + if (!matchPattern(store.getPred(), m_One())) + return failure(); + auto alloc = store.getDep().getDefiningOp(); + if (!alloc) + return failure(); + if (store.getDst() != alloc.getResult()) + return failure(); + if (alloc->getBlock() != store->getBlock()) + return failure(); + if (auto srcDef = store.getSrc().getDefiningOp()) { + if (alloc->getBlock() == srcDef->getBlock() && + alloc->isBeforeInBlock(srcDef)) + return failure(); + } + alloc.getSrcMutable().assign(store.getSrc()); + rewriter.replaceOp(store, alloc.getToken()); + return success(); + } +}; + +// Hoists a tmem alloc outside an if op like this: +// %0 = scf.if { +// %1, %token0 = tmem.alloc %init +// ... +// %2 = tmem.load %1, %token1 +// scf.yield %2 +// } else { +// scf.yield %init +// } +// -> +// %a, %token0 = tmem.alloc %init +// %token2 = scf.if { +// +// ... +// scf.yield %token1 +// } else { +// scf.yield %token0 +// } +// %2 = tmem.load %a, %token2 +class HoistTMEMAllocOutOfIf : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc, + PatternRewriter &rewriter) const override { + if (!alloc.getToken()) + return failure(); + Value init = alloc.getSrc(); + if (!init) + return failure(); + auto ifOp = dyn_cast(alloc->getParentOp()); + if (!ifOp || !ifOp.elseBlock()) + return failure(); + auto thenOp = ifOp.thenBlock()->getTerminator(); + auto elseOp = ifOp.elseBlock()->getTerminator(); + SmallVector yieldArgs; + for (auto [thenOperand, elseOperand] : + llvm::zip(thenOp->getOpOperands(), elseOp->getOpOperands())) { + auto load = thenOperand.get().getDefiningOp(); + if (!load || load.getSrc() != alloc.getResult()) + continue; + if (elseOperand.get() != init) + continue; + yieldArgs.push_back(thenOperand.getOperandNumber()); + } + if (yieldArgs.empty()) + return failure(); + // Since init is used in the else terminator we know that it dominates the + // if op. + alloc->moveBefore(ifOp); + rewriter.setInsertionPointAfter(ifOp); + for (int argNo : yieldArgs) { + auto load = + cast(thenOp->getOperand(argNo).getDefiningOp()); + auto newLoad = cast(rewriter.clone(*load)); + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->getResult(argNo).replaceAllUsesWith(newLoad.getResult()); + newLoad.getDepMutable().assign(ifOp->getResult(argNo)); + thenOp->setOperand(argNo, load.getToken()); + elseOp->setOperand(argNo, alloc.getToken()); + ifOp->getResult(argNo).setType(newLoad.getToken().getType()); + }); + } + return success(); + } +}; + +// Forward a TMEM load into the user allocation. +class TMEMLoadForwarding : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc, + PatternRewriter &rewriter) const override { + if (!alloc.getToken()) + return failure(); + Value init = alloc.getSrc(); + if (!init) + return failure(); + auto load = init.getDefiningOp(); + if (!load || !load->hasOneUse() || !load.getDep().hasOneUse()) + return failure(); + if (alloc.getType() != load.getSrc().getType()) + return failure(); + rewriter.replaceOp(alloc, {load.getSrc(), load.getDep()}); + return success(); + } +}; + +// Remove loop-carried tensor dependencies if they are fed immediately into a +// TMEM store by pulling the store into the previous iteration. +class RotateTMEMStoreInLoop : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMStoreOp store, + PatternRewriter &rewriter) const override { + if (!store.getDep()) + return failure(); + // Pattern match stores whose source comes from a loop region argument and + // whose predicate is loop-invariant. + scf::ForOp forOp = dyn_cast(store->getParentOp()); + if (!forOp || !forOp.isDefinedOutsideOfLoop(store.getPred()) || + !forOp.isDefinedOutsideOfLoop(store.getDst())) { + return failure(); + } + auto getAsLoopArg = [&](Value v) -> BlockArgument { + auto arg = dyn_cast(v); + if (arg && arg.getOwner() == forOp.getBody()) + return arg; + return {}; + }; + BlockArgument src = getAsLoopArg(store.getSrc()); + if (!src || !src.hasOneUse()) { + return failure(); + } + + // Check that rotating the store into the past won't violate any + // write-after-read dependencies. + BlockArgument storeTok = getAsLoopArg(store.getDep()); + if (!storeTok) + return failure(); + int tokArgNo = storeTok.getArgNumber() - 1; + + // Create two copies of the store: one before the loop, storing the initial + // value, and one before the yield, storing the value carried by the loop + // arg. + int argNo = src.getArgNumber() - 1; + Value initVal = forOp.getInitArgs()[argNo]; + rewriter.setInsertionPoint(forOp); + auto tokType = rewriter.getType(); + auto initStore = ttng::TMEMStoreOp::create( + rewriter, store.getLoc(), tokType, store.getDst(), + forOp.getInitArgs()[tokArgNo], initVal, store.getPred()); + forOp.getInitArgsMutable()[tokArgNo].assign(initStore.getToken()); + + auto yield = cast(forOp.getBody()->getTerminator()); + store.getToken().replaceAllUsesWith(forOp.getRegionIterArg(tokArgNo)); + rewriter.moveOpBefore(store, yield); + store.getDepMutable().assign(yield.getOperand(tokArgNo)); + yield.setOperand(tokArgNo, store.getToken()); + store.getSrcMutable().assign(yield.getOperand(argNo)); + + // Load from the tmem after the loop, and use it instead of the loop carried + // value. + rewriter.setInsertionPointAfter(forOp); + auto load = ttng::TMEMLoadOp::create( + rewriter, store.getLoc(), store.getSrc().getType(), tokType, + store.getDst(), forOp.getResult(tokArgNo)); + forOp->getResult(argNo).replaceAllUsesWith(load.getResult()); + // Loop carried value is no longer used, short-circuit it. + yield.setOperand(argNo, forOp.getRegionIterArg(argNo)); + return success(); + } +}; + +// Remove loop-carried tensor dependencies if they are the result of TMEM loads +// at the end of the loop by pushing the load into the next iteration. +class RotateTMEMLoadInLoop : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttng::TMEMLoadOp load, + PatternRewriter &rewriter) const override { + if (!load.getDep()) + return failure(); + // Pattern match loads whose results are only passed into the next iteration + // of a loop. + scf::ForOp forOp = dyn_cast(load->getParentOp()); + if (!forOp || !forOp.isDefinedOutsideOfLoop(load.getSrc()) || + !load.getResult().hasOneUse()) { + return failure(); + } + OpOperand &use = *load.getResult().use_begin(); + auto yield = dyn_cast(use.getOwner()); + if (!yield) + return failure(); + + // By rotating the load into the future, we are essentially merging the + // loop-carried tensor value into the same TMEM allocation as the load. + // Thus, they cannot be live at the same time. Check this by ensuring we + // won't clobber the memory. + + // 1. There are no aliasing stores between the load and the end of the loop. + if (!llvm::is_contained(load.getToken().getUsers(), yield)) + return failure(); + // 2. The TMEM variable is live into the loop with an undefined value. + int tokArgNo = load.getToken().use_begin()->getOperandNumber(); + Value initTok = forOp.getInitArgs()[tokArgNo]; + auto initAlloc = initTok.getDefiningOp(); + if (!initAlloc || initAlloc.getSrc()) + return failure(); + // TODO: 3. The live-in value of the TMEM variable is never read. + + // Create a store before the loop to write the initial value. + int argNo = use.getOperandNumber(); + Value initVal = forOp.getInitArgs()[argNo]; + rewriter.setInsertionPoint(forOp); + auto vTrue = arith::ConstantIntOp::create(rewriter, load.getLoc(), 1, 1); + auto tokType = rewriter.getType(); + auto initStore = ttng::TMEMStoreOp::create( + rewriter, load.getLoc(), tokType, load.getSrc(), initAlloc.getToken(), + initVal, vTrue); + forOp.getInitArgsMutable()[tokArgNo].assign(initStore.getToken()); + + // Move the load to the beginning of the loop to load the tensor value. + yield.setOperand(tokArgNo, load.getDep()); + rewriter.moveOpBefore(load, &forOp.getBody()->front()); + Value tokArg = forOp.getRegionIterArg(tokArgNo); + load.getDepMutable().assign(tokArg); + tokArg.replaceAllUsesExcept(load.getToken(), load); + forOp.getRegionIterArg(argNo).replaceAllUsesWith(load.getResult()); + + // Load from the tmem after the loop, and use it instead of the loop carried + // value. + rewriter.setInsertionPointAfter(forOp); + auto loadAfterLoop = ttng::TMEMLoadOp::create( + rewriter, load.getLoc(), load.getResult().getType(), tokType, + load.getSrc(), forOp.getResult(tokArgNo)); + forOp->getResult(argNo).replaceAllUsesWith(loadAfterLoop.getResult()); + // Loop carried value is no longer used, short-circuit it. + yield.setOperand(argNo, forOp.getRegionIterArg(argNo)); + return success(); + } +}; + +// Given an operation that uses a token, return its forwarded token. This +// assumes the memory variable is not loop carried. +static Value getTokenFromOp(Operation *op) { + if (auto mmaOp = dyn_cast>(op)) { + return mmaOp.getToken(); + } else if (auto loadOp = dyn_cast(op)) { + return loadOp.getToken(); + } else if (auto storeOp = dyn_cast(op)) { + return storeOp.getToken(); + } + assert(!isa(op) && "unexpected loop carried token"); + llvm_unreachable("unknown TMEM memory user"); +} + +// Find all the last uses of a memory variable in a loop body. This traces the +// token lattice to its leaves. +static void findLastMemoryUses(OpResult token, + SmallVectorImpl &lastUses, + DenseSet &seen) { + if (!seen.insert(token).second) + return; + if (token.use_empty()) { + lastUses.push_back(token); + return; + } + for (Operation *user : token.getUsers()) + findLastMemoryUses(cast(getTokenFromOp(user)), lastUses, seen); +} + +// Find the last uses of a memory variable, joining them into a single token if +// necessary. This token can be carried into the next loop iteration. +static Value joinLastMemoryUses(OpBuilder &b, Value token) { + SmallVector lastUses; + DenseSet seenTokens; + findLastMemoryUses(cast(token), lastUses, seenTokens); + assert(!lastUses.empty()); + + if (lastUses.size() == 1 && lastUses.front().getDefiningOp()->getBlock() == + token.getDefiningOp()->getBlock()) + return lastUses.front(); + // We can handle this case as needed. Right now it never happens. + llvm::report_fatal_error( + "FIXME: can't hoist TMEM alloc with multiple or conditional uses"); +} + +ttng::TMEMAllocOp hoistTMEMAlloc(TMEMTokenAllocOp alloc, scf::ForOp &forOp) { + OpBuilder builder(alloc); + builder.setInsertionPoint(forOp); + Value vTrue = arith::ConstantIntOp::create(builder, alloc.getLoc(), 1, 1); + auto src = alloc.getSrc(); + auto newAlloc = cast(builder.clone(*alloc)); + newAlloc.getSrcMutable().clear(); + + // By hoisting the allocation out of the loop, we need to turn the underlying + // memory variable into a loop-carried depdendency. + auto tokType = builder.getType(); + forOp = addIterArgsToLoop(builder, forOp, newAlloc.getToken()); + Value newTok = forOp.getRegionIterArgs().back(); + appendToForOpYield(forOp, joinLastMemoryUses(builder, alloc.getToken())); + + if (src != nullptr) { + builder.setInsertionPoint(alloc); + // Write the initial value of the allocation and replace the token. + auto initStoreOp = + ttng::TMEMStoreOp::create(builder, alloc.getLoc(), tokType, + newAlloc.getResult(), newTok, src, vTrue); + newTok = initStoreOp.getToken(); + } + alloc.replaceAllUsesWith(ValueRange{newAlloc.getResult(), newTok}); + alloc.erase(); + + return newAlloc; +} + +// Hoist invariant tmem_alloc. This could technically be done as general LICM +// but controlling tmem liveranga more precisley is likely to be important. +static void hoistInvariantInputs(Operation *mmaOp, scf::ForOp forOp) { + for (auto operand : mmaOp->getOperands()) { + if (forOp.isDefinedOutsideOfLoop(operand)) + continue; + auto tmemAllocOp = operand.getDefiningOp(); + if (!tmemAllocOp || tmemAllocOp.getType().getMutableMemory()) + continue; + assert(tmemAllocOp.getSrc()); + Value src = tmemAllocOp.getSrc(); + SmallVector opToHoist = {tmemAllocOp.getOperation()}; + // Also hoist simple unary elementwise that may have sinked into the loop. + while (Operation *defOp = src.getDefiningOp()) { + if (forOp.isDefinedOutsideOfLoop(src)) + break; + if (!(isPure(defOp) && defOp->getNumOperands() == 1)) + break; + opToHoist.push_back(defOp); + src = defOp->getOperand(0); + } + if (!forOp.isDefinedOutsideOfLoop(src)) + continue; + for (auto op : llvm::reverse(opToHoist)) { + forOp.moveOutOfLoop(op); + } + } +} +} // namespace + +struct HoistTMEMAlloc + : public impl::TritonGPUHoistTMEMAllocBase { + using impl::TritonGPUHoistTMEMAllocBase< + HoistTMEMAlloc>::TritonGPUHoistTMEMAllocBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + if (!hoistOutOfIf) { + SmallVector mmaOps; + m.walk([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back(mmaOp); }); + for (auto mmaOp : mmaOps) { + auto forOp = dyn_cast(mmaOp->getParentOp()); + if (!forOp) { + continue; + } + hoistInvariantInputs(mmaOp, forOp); + + // Only hoist the TMEM alloc feeding into the accumulator. Leave the + // ones for the scales in the loop. + auto alloc = mmaOp.getAccumulator().getDefiningOp(); + if (!alloc || alloc->getParentRegion() != mmaOp->getParentRegion()) { + continue; + } + hoistTMEMAlloc(alloc, forOp); + } + } + + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (hoistOutOfIf) { + patterns.add(&getContext()); + } + scf::ForOp::getCanonicalizationPatterns(patterns, &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + llvm_unreachable("Failed to hoist tmem_store"); + } + + // TODO: currently some code assumes that a mutable tmem alloc doesn't have + // an initial value. As a workaround we break up the op in order to keep + // this form for the downstream passes. We should remove this once the + // downstread passes are fixed. + m.walk([&](ttng::TMEMAllocOp alloc) { + if (alloc.getType().getMutableMemory() && alloc.getSrc()) { + OpBuilder builder(alloc); + builder.setInsertionPointAfter(alloc); + auto store = ttng::TMEMStoreOp::create( + builder, alloc.getLoc(), builder.getType(), + alloc.getResult(), alloc.getToken(), alloc.getSrc(), + arith::ConstantIntOp::create(builder, alloc.getLoc(), 1, 1)); + alloc.getToken().replaceAllUsesExcept(store.getToken(), store); + alloc.getSrcMutable().clear(); + } + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.cpp new file mode 100644 index 0000000000..8ab0a818dd --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.cpp @@ -0,0 +1,49 @@ +#include "triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include +#include + +namespace mlir::triton::gpu { + +std::optional> +inferSourceLoadLayout(const LinearLayout &dstLayout, Operation *defOp) { + if (!defOp) + return std::nullopt; + return inferSourceLoadLayout( + LinearEncodingAttr::get(defOp->getContext(), dstLayout), defOp); +} + +std::optional> +inferSourceLoadLayout(LinearEncodingAttr dstLayout, Operation *defOp) { + Attribute curLayout = dstLayout; + Operation *curOp = defOp; + while (curOp) { + if (isa(curOp)) + break; // Found the load op; we are done here. + + if (auto cvtOp = dyn_cast(curOp)) { + // For convert op we keep the current layout to push through further. + curOp = cvtOp.getSrc().getDefiningOp(); + } else { + if (curOp->getNumOperands() != 1) + break; + curLayout = inferSrcEncoding(curOp, curLayout); + curOp = curOp->getOperand(0).getDefiningOp(); + } + } + auto loadOp = dyn_cast_or_null(curOp); + if (!loadOp) + return std::nullopt; + auto loadType = dyn_cast(loadOp.getType()); + if (!loadType) + return std::nullopt; + + return std::make_pair( + loadOp, + toLinearLayout(loadType.getShape(), cast(curLayout))); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp new file mode 100644 index 0000000000..05468b9aab --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -0,0 +1,311 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +class TMEMAllocWithUnusedInit + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::TMEMAllocOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + if (op.getSrc() == nullptr) + return failure(); + SmallVector users(op.getResult().getUsers().begin(), + op.getResult().getUsers().end()); + if (users.size() > 2) + return failure(); + triton::nvidia_gpu::MMAv5OpInterface mmaOp = nullptr; + triton::nvidia_gpu::TMEMLoadOp tmemLoad = nullptr; + for (auto user : users) { + if (auto load = dyn_cast(user)) { + tmemLoad = load; + } else if (auto mma = + dyn_cast(user)) { + mmaOp = mma; + } + } + if (!mmaOp) + return failure(); + if (tmemLoad && !mmaOp->isBeforeInBlock(tmemLoad)) + return failure(); + Value useAccFlag = mmaOp.useAccumulator(); + if (!useAccFlag) + return failure(); + auto flagConstOp = useAccFlag.getDefiningOp(); + if (!flagConstOp) + return failure(); + if (cast(flagConstOp.getValue()).getInt() != 0) + return failure(); + op.getSrcMutable().clear(); + return success(); + } +}; + +bool dotSupportsAccInitFlag(Operation *op) { + assert(isa(op) && + "Expected an op which implements a DotOpInterface"); + + if (auto wgDotOp = dyn_cast(op)) { + // Partial accumulation would require a select op to handle the + // initialization that would degrade the performance. + return !wgDotOp.needsPartialAccumulator(); + } + if (isa(op)) { + return true; + } + return false; +} + +std::pair getAccumulatorUseAndDef(Operation *op) { + assert(isa(op) && + "Expected an op which implements a DotOpInterface"); + + if (auto wgDotOp = dyn_cast(op)) { + return std::make_pair(wgDotOp.getC(), wgDotOp); + } + if (auto tc05MmaOp = dyn_cast(op)) { + auto accVal = tc05MmaOp.getAccumulator(); + auto tmemAlloc = accVal.getDefiningOp(); + if (!tmemAlloc || + tmemAlloc->getParentRegion() != tc05MmaOp->getParentRegion()) + return std::make_pair(nullptr, nullptr); + triton::nvidia_gpu::TMEMLoadOp tmemLoad = nullptr; + for (auto user : tmemAlloc.getResult().getUsers()) { + if (auto load = dyn_cast(user)) { + tmemLoad = load; + break; + } + } + if (!tmemLoad || + tmemLoad->getParentRegion() != tc05MmaOp->getParentRegion()) + return std::make_pair(nullptr, nullptr); + return std::make_pair(tmemAlloc.getSrc(), tmemLoad); + } + assert(false && "Unexpected op which implements a DotOpInterface"); + return std::make_pair(nullptr, nullptr); +} + +void setUseAccFlag(Operation *op, Value useAcc) { + assert(isa(op) && + "Expected an op which implements a DotOpInterface"); + + if (auto wgDotOp = dyn_cast(op)) { + wgDotOp.getUseCMutable().assign(useAcc); + } else if (auto tc05MmaOp = + dyn_cast(op)) { + tc05MmaOp.setUseAccumulator(useAcc); + } else { + assert(false && "Unexpected op which implements a DotOpInterface"); + } +} + +Value getUseAccFlag(Operation *op) { + assert(isa(op) && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + return wgDotOp.getUseC(); + } else if (auto tc05MmaOp = + dyn_cast(op)) { + return tc05MmaOp.useAccumulator(); + } else { + assert(false && "Unexpected dot-like operation"); + } + return nullptr; +} + +bool isConstantZeroTensor(Value v) { + return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat())); +} + +std::optional> +findZeroInitOp(Value accUse, scf::ForOp forOp, bool &loopArgIsZero) { + Value v = accUse; + if (auto arg = dyn_cast(v)) { + assert(arg.getOwner() == forOp.getBody()); + if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) { + loopArgIsZero = true; + } + v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + + auto defOp = v.getDefiningOp(); + if (!defOp) { + return std::nullopt; + } + if (auto selOp = dyn_cast(defOp)) { + if (!selOp.getCondition().getType().isInteger(1)) + return std::nullopt; + if (isConstantZeroTensor(selOp.getTrueValue()) || + isConstantZeroTensor(selOp.getFalseValue())) { + return std::make_pair(selOp, 0); + } + } + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIndex = cast(v).getResultNumber(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) { + // Make sure that the other value is not defined in the if itself, but + // passed from outside + if (thenVal.getParentBlock()->getParentOp() == ifOp || + elseVal.getParentBlock()->getParentOp() == ifOp) { + return std::nullopt; + } + return std::make_pair(ifOp, resultIndex); + } + } + return std::nullopt; +} + +} // namespace + +class OptimizeAccumulatorInitPass + : public impl::TritonGPUOptimizeAccumulatorInitBase< + OptimizeAccumulatorInitPass> { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + SmallVector mmaOps; + m.walk([&](Operation *op) { + if (isa(op) && dotSupportsAccInitFlag(op)) + mmaOps.push_back(op); + }); + + // for each mma op, find where the accumulator is initialized with zero + // It can be: + // 1. A constant zero + // 2. Initialized with zero as the loop argument + // 3. Initialized with zero in the if op or with a select op in current + // or any of the previous loop iterations + for (Operation *mmaOp : mmaOps) { + Location loc = mmaOp->getLoc(); + + scf::ForOp forOp = dyn_cast(mmaOp->getParentOp()); + if (!forOp) { + continue; + } + + IRRewriter rewriter(forOp); + rewriter.setInsertionPoint(forOp); + + Value vTrue = + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true)); + Value vFalse = + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)); + + // Find the accumulator + auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp); + if (!accUse || !accDef) { + continue; + } + if (isConstantZeroTensor(accUse)) { + setUseAccFlag(mmaOp, vFalse); + continue; + } + + bool loopArgIsZero = false; + std::optional> zeroInitOp = + findZeroInitOp(accUse, forOp, loopArgIsZero); + + if (!zeroInitOp && !loopArgIsZero) { + continue; + } + + if (auto useAccValue = getUseAccFlag(mmaOp)) { + auto useAcc = getBoolFromConstant(useAccValue); + if (!useAcc || *useAcc == false) { + // Do not run this optimization if there is already a non-constant + // flag (this pass has already run), or if this MMA does not use the + // accumulator (e.g. the peeled MMA in the prologue, the first dot + // in attention) + continue; + } + } + + Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue; + forOp = addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue}); + loopArgFlagValue = + forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + + if (zeroInitOp) { + Value condition = nullptr; + Value oldValue = nullptr; + Value zeroValue = nullptr; + bool thenInitsToZero = false; + if (auto selOp = dyn_cast(zeroInitOp->first)) { + condition = selOp.getCondition(); + oldValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getFalseValue() + : selOp.getTrueValue(); + zeroValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getTrueValue() + : selOp.getFalseValue(); + thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue()); + } else { + assert(isa(*zeroInitOp->first) && "Expected an if op"); + auto ifOp = cast(zeroInitOp->first); + unsigned resultIndex = zeroInitOp->second; + condition = ifOp.getCondition(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal; + zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal; + thenInitsToZero = isConstantZeroTensor(thenVal); + } + + // Create a select op that updates the flag + rewriter.setInsertionPoint(zeroInitOp->first); + bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp); + Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue; + auto selectFlagOp = arith::SelectOp::create( + rewriter, loc, condition, thenInitsToZero ? vFalse : prevFlagValue, + thenInitsToZero ? prevFlagValue : vFalse); + setUseAccFlag(mmaOp, + zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), + {zeroingBeforeMMA ? vTrue : selectFlagOp}); + + // Stop clearing out the accumulator with zero + if (auto selOp = dyn_cast(zeroInitOp->first)) { + rewriter.setInsertionPoint(selOp); + rewriter.replaceOp(selOp, oldValue); + } else { + auto ifOp = cast(zeroInitOp->first); + int resultIndex = zeroInitOp->second; + auto zeroingYield = + thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield(); + zeroingYield.setOperand(resultIndex, oldValue); + } + } else if (loopArgIsZero) { + setUseAccFlag(mmaOp, loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), vTrue); + } + } + + // Cleanup unused init values in tmem allocs + mlir::RewritePatternSet patterns(m.getContext()); + patterns.add(m.getContext()); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp new file mode 100644 index 0000000000..19977d4856 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -0,0 +1,358 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include + +namespace mlir::triton::gpu { + +namespace { +// Given +// dot(convert(trans(src)) #dot_operand) -> +// dot(convert(local_load(trans(alloc(src))))) +// change the encoding of the inner convert to a special, swizzled shared +// encoding. +class SwizzleShmemConvert : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp, + PatternRewriter &rewriter) const override { + if (!cvtOp->hasOneUse() || + !isa(cvtOp->use_begin()->getOwner())) + return failure(); + // Match outerCvt(trans(innerCvt(x))). + auto trans = cvtOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef{1, 0}) + return failure(); + // Only rewrite when transpose is exclusively consumed by this convert. + if (!trans->hasOneUse()) + return failure(); + + RankedTensorType srcTy = trans.getSrc().getType(); + + if (auto srcCvt = trans.getSrc().getDefiningOp()) { + srcTy = srcCvt.getSrc().getType(); + } + RankedTensorType sharedLoadTy = cvtOp.getType(); + auto cvtEncoding = + dyn_cast(sharedLoadTy.getEncoding()); + if (!cvtEncoding) + return failure(); + + // Set needTrans to true here. newInnerCvtEnc is computed based on + // argEncoding which is before the transpose. Without needTrans we will + // compute vec and maxPhase based on incorrect m, n and k size of mma. The + // type inference of MemDescTransOp simply swap the order but doesn't fix + // the vec and maxPhase for the YType, hence it would causing incorrect + // swizzling code. + auto ctx = getContext(); + auto oldCGALayout = triton::gpu::getCGALayout(srcTy.getEncoding()); + auto newLl = + transposeLinearLayout(oldCGALayout.getLinearLayout(), trans.getOrder()); + auto newCGALayout = CGAEncodingAttr::get(ctx, std::move(newLl)); + auto newInnerCvtEnc = + SwizzledSharedEncodingAttr::get(ctx, cvtEncoding, srcTy.getShape(), + /*order=*/getOrderForMemory(srcTy), + newCGALayout, srcTy.getElementType(), + /*needTrans=*/true); + if (newInnerCvtEnc == cvtEncoding) + return failure(); + rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); + auto alloc = LocalAllocOp::create( + rewriter, trans.getLoc(), + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), + newInnerCvtEnc, sharedMemorySpace), + trans.getSrc()); + auto newTrans = MemDescTransOp::create(rewriter, trans.getLoc(), alloc, + ArrayRef({1, 0})); + auto localLoadOp = + LocalLoadOp::create(rewriter, trans.getLoc(), sharedLoadTy, newTrans); + rewriter.modifyOpInPlace(cvtOp, [&]() { + cvtOp.getSrcMutable().assign(localLoadOp.getResult()); + }); + return success(); + } +}; + +// Rewrite +// +// dot(alloc(trans() #shared1) -> +// dot(trans(alloc() #shared2)) +// +// if dot is an MMAv3/v5 (because MMAv3/v5 allows us to fold transposes). +class FuseTransMMAV3Plus : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp.getSrc() || !allocOp->hasOneUse() || + !isa( + *allocOp->getUsers().begin())) + return failure(); + + auto dot = *allocOp->getUsers().begin(); + if (auto dotTy = dyn_cast(dot->getResult(0).getType())) { + if (isa(dotTy.getEncoding())) + return failure(); + } + // Match outerCvt(trans(innerCvt(x))). + auto trans = allocOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef({1, 0})) + return failure(); + + MemDescType allocType = allocOp.getType(); + auto allocEncoding = cast(allocType.getEncoding()); + RankedTensorType srcTy = trans.getSrc().getType(); + + auto ctx = getContext(); + Dialect &dialect = allocEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + Attribute newInnerEnc; + if (failed(inferLayoutInterface->inferTransOpEncoding( + allocEncoding, srcTy.getShape(), trans.getOrder(), newInnerEnc, + allocOp.getLoc()))) { + return failure(); + } + + MemDescType innerTy = + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc, + allocType.getMemorySpace()); + auto newAlloc = LocalAllocOp::create(rewriter, allocOp.getLoc(), innerTy, + trans.getSrc()); + rewriter.replaceOpWithNewOp(allocOp, newAlloc, + ArrayRef({1, 0})); + return success(); + } +}; + +// Rewrite +// +// alloc(reshape(), #shared1) -> +// memdesc_reshape(alloc() #shared2)) +// +// if dot is an MMAv3/v5 (because MMAv3/v5 allows us to fold transposes). +class ReshapeMemDesc : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp.getSrc()) + return failure(); + + auto reshapeOp = allocOp.getSrc().getDefiningOp(); + if (!reshapeOp) + return failure(); + + MemDescType allocType = allocOp.getType(); + auto allocEncoding = allocType.getEncoding(); + + RankedTensorType srcTy = reshapeOp.getSrc().getType(); + auto srcShape = srcTy.getShape(); + auto dstShape = allocType.getShape(); + + // We use the fact that forward and backward inference are the same for + // MemDescReshapeOp to infer the source MemDescType that would produce + // `allocType` after a reshape. + MemDescType innerTy; + if (failed(MemDescReshapeOp::inferReturnTypes( + getContext(), allocOp.getLoc(), allocType, srcShape, innerTy))) + return failure(); + + // For now don't apply the transformation if the new encoding is not an + // MMAv3/v5 encoding as it may not be compatible with the user. + // The heuristic can be refined once we have more flexible mma ops. + if (!isa(innerTy.getEncoding())) + return failure(); + + auto newAlloc = LocalAllocOp::create(rewriter, allocOp.getLoc(), innerTy, + reshapeOp.getSrc()); + rewriter.replaceOpWithNewOp(allocOp, allocOp.getType(), + newAlloc); + return success(); + } +}; + +// Inject TMEM copy instructions into IR to efficiently load blocked scales for +// scaled dot +class UseShmemForScales + : public OpRewritePattern { +public: + using OpRewritePattern< + triton::nvidia_gpu::TCGen5MMAScaledOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::TCGen5MMAScaledOp mmaOp, + PatternRewriter &rewriter) const override { + auto aScale = mmaOp.getAScale(); + auto bScale = mmaOp.getBScale(); + LogicalResult ret = failure(); + if (aScale && isa( + aScale.getType().getEncoding())) { + if (rewriteOperand(mmaOp.getAScaleMutable(), rewriter).succeeded()) + ret = success(); + } + if (bScale && isa( + bScale.getType().getEncoding())) { + if (rewriteOperand(mmaOp.getBScaleMutable(), rewriter).succeeded()) + ret = success(); + } + return ret; + } + +private: + LogicalResult rewriteOperand(OpOperand &opOperand, + PatternRewriter &rewriter) const { + auto src = cast>(opOperand.get()); + auto tmemAlloc = src.getDefiningOp(); + if (!tmemAlloc) { + return failure(); + } + auto dstType = tmemAlloc.getResult().getType(); + + if (!tmemAlloc.getSrc()) { + return failure(); + } + + // Look for a sequence + // local_load + // -> reshape(..., (BLOCK_MN / 128, BLOCK_K / scale_vec_size / 4, 32, 4, + // 4) + // -> transpose(..., (0, 3, 2, 1, 4)) + // -> reshape(..., (BLOCK_MN, BLOCK_K / scale_vec_size) + // -> tmem_alloc + // -> tc_gen_mma_scaled + // and replace it with local_alloc -> tc_gen_mma_scaled + auto scale2DShape = dstType.getShape(); + auto blockMN = scale2DShape[0]; + auto numScales = scale2DShape[1]; + const SmallVector transposeOrder{0, 3, 2, 1, 4}; + const SmallVector reshape5DShape{blockMN / 128, numScales / 4, 32, + 4, 4}; + + auto reshapeOp2D = getNextOp(tmemAlloc.getSrc()); + if (!reshapeOp2D || + reshapeOp2D.getResult().getType().getShape() != scale2DShape) { + return failure(); + } + + auto transOp = getNextOp(reshapeOp2D.getSrc()); + if (!transOp || transOp.getOrder() != ArrayRef(transposeOrder)) { + return failure(); + } + + auto reshapeOp5D = getNextOp(transOp.getSrc()); + if (!reshapeOp5D || reshapeOp5D.getResult().getType().getShape() != + ArrayRef(reshape5DShape)) { + return failure(); + } + + auto localLoad = getNextOp(reshapeOp5D.getSrc()); + if (!localLoad) { + return failure(); + } + auto localAlloc = getNextOp(localLoad.getSrc()); + bool usesTMAload = + (localAlloc && localAlloc.getSrc() && + (getNextOp(localAlloc.getSrc()) != nullptr)); + if (!isTmemCopyCompatible(localLoad.getSrc().getType(), usesTMAload)) + return failure(); + + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(tmemAlloc); + + Value shared = localLoad.getSrc(); + + Value reshaped5D = MemDescReshapeOp::create(rewriter, reshapeOp5D.getLoc(), + shared, reshape5DShape); + SmallVector transposeOrder32(transposeOrder.begin(), + transposeOrder.end()); + Value transposed = MemDescTransOp::create(rewriter, transOp.getLoc(), + reshaped5D, transposeOrder32); + SmallVector scale2DShapeVec(scale2DShape.begin(), + scale2DShape.end()); + Value reshaped2D = MemDescReshapeOp::create(rewriter, reshapeOp2D.getLoc(), + transposed, scale2DShapeVec); + + opOperand.assign(reshaped2D); + rewriter.eraseOp(tmemAlloc); + return success(); + } + + template Op getNextOp(Value op) const { + while (auto cvtOp = op.getDefiningOp()) { + op = cvtOp.getSrc(); + } + return op.getDefiningOp(); + } + + bool isTmemCopyCompatible(triton::gpu::MemDescType scaleType, + bool usesTMAload) const { + // TMEM copy expects that blocked scale "chunks" in SMEM are stored in + // innermost axes contiguously. + if (!isInnermostContiguous(scaleType, 512)) + return false; + + if (usesTMAload) { + return true; + } + + if (scaleType.getRank() != 2) { + // TODO: Add support for higher rank when 5D coalesced load is fixed + return false; + } + + auto elemBits = scaleType.getElementType().getIntOrFloatBitWidth(); + + // We assume that 32x128b chunks are flattened into the inner most axis. + auto innerMostBits = + scaleType.getDimSize(scaleType.getRank() - 1) * elemBits; + return innerMostBits % (32 * 128) == 0; + } +}; + +} // namespace + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUOptimizeDotOperandsPass + : public impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass> { +public: + using impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass>::TritonGPUOptimizeDotOperandsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + OpPassManager pm; + pm.addPass(mlir::createCanonicalizerPass()); + if (failed(runPipeline(pm, m))) + return signalPassFailure(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp new file mode 100644 index 0000000000..327cb4a240 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -0,0 +1,583 @@ +#include +#include + +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZETHREADLOCALITY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +// Change the destination layout of reshape ops allowing reorder when used by a +// reduction in order to minimize the amount of cross thread communication for +// the reduction. +struct OptimizeReshapeLayoutPattern : public OpRewritePattern { + OptimizeReshapeLayoutPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(ReshapeOp viewOp, + PatternRewriter &rewriter) const override { + if (!viewOp.getAllowReorder()) + return failure(); + std::optional reductionAxis; + for (Operation *user : viewOp.getResult().getUsers()) { + if (auto reduceOp = dyn_cast(user)) { + if (reductionAxis) { + if (reductionAxis != reduceOp.getAxis()) + return failure(); + } else { + reductionAxis = reduceOp.getAxis(); + } + } + } + if (!reductionAxis) + return failure(); + RankedTensorType tensorType = viewOp.getType(); + if (auto blocked = + mlir::dyn_cast(tensorType.getEncoding())) { + // If the layout already has all the elements along the reduction + // dimension in the same thread we can skip. + if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 && + blocked.getWarpsPerCTA()[*reductionAxis] == 1 && + blocked.getCGALayout().getCTAsPerCGA()[*reductionAxis] == 1) + return failure(); + } + ArrayRef shape = tensorType.getShape(); + SmallVector order; + for (int i : triton::gpu::getOrder(tensorType)) { + if (i != *reductionAxis) + order.push_back(i); + } + // Make the reduction axis last so that elements won't be distributed + // amongst threads along this dimension. + order.push_back(*reductionAxis); + SmallVector sizePerThread(shape.size(), 1); + auto mod = viewOp->getParentOfType(); + int numWarps = lookupNumWarps(viewOp); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + auto encoding = + BlockedEncodingAttr::get(viewOp.getContext(), shape, sizePerThread, + order, numWarps, threadsPerWarp, numCTAs); + if (encoding == tensorType.getEncoding()) + return failure(); + RankedTensorType newType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + if (triton::gpu::isExpensiveView(viewOp.getSrc().getType(), newType)) + return failure(); + rewriter.setInsertionPointAfter(viewOp); + rewriter.modifyOpInPlace(viewOp, [&]() { + viewOp.getResult().setType(newType); + viewOp.setEfficientLayout(true); + }); + auto cvt = ConvertLayoutOp::create(rewriter, viewOp.getLoc(), tensorType, + viewOp.getResult()); + rewriter.replaceAllUsesExcept(viewOp.getResult(), cvt.getResult(), cvt); + return success(); + } +}; +} // namespace + +// This function considers a gather op in isolation and attempts to determine +// whether an optimized layout can be applied to the source and index tensors. +static LogicalResult setOptimizedGatherLayout(GatherOp op, RewriterBase &b) { + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Determine a warp-local gather layout that minimizes the number of emitted + // warp shuffles. + unsigned numThreadsPerWarp = lookupThreadsPerWarp(b); + unsigned numWarps = lookupNumWarps(op); + + // If in a gather column, each thread owns `srcSizePerThread[axis]` elements + // in the source tensor and `idxSizePerThread[axis]` elements in the index + // tensor (including broadcasting), then the number of index shuffles per + // column is `srcSizePerThread[axis] * idxSizePerThread[axis]`. This is then + // replicated over the number of columns in which a thread owns (an equal + // number of) elements, which is `product(srcSizePerThread[i] for i != axis)`. + // + // Thus, the total number of index shuffles is `product(srcSizePerThread) * + // idxSizePerThread[axis]`. Since we cannot alter the number of threads per + // warp or the number of warps, `product(srcSizePerThread)` is just a function + // of the shape. + // + // So we want to minimize `idxSizePerThread[axis]`. Note that broadcasting is + // forbidden in the source tensor but allowed in the index tensor. Choose the + // smallest value while still ensuring that a warp spans whole columns. + // + // In order to prevent broadcasting in the source tensor layout, ensure + // + // sizePerThread(i) * threadsPerWarp(i) * warpsPerCTA(i) = shape(i) + // + // For all i != axis in the source tensor. The same relationship must hold for + // the index tensor. This means we can't just set `idxSizePerThread[axis]` to + // 1 and compute the rest from that. Find the smallest value where this + // relationship is still respected. + + // We know that the layouts will be the same between the two tensors except + // for `sizePerThread[axis]`. + unsigned axis = op.getAxis(); + unsigned rank = srcType.getRank(); + if (rank == 1) + return failure(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector order; + order.push_back(axis); + + // Minimize `sizePerThread[axis]` by putting as many theads along the axis as + // possible, limited to the actual size of the dimension. + unsigned maxThreadsInAxis = + std::min(srcType.getDimSize(axis), numThreadsPerWarp); + threadsPerWarp[axis] = maxThreadsInAxis; + + // Now spread them along the other dimensions. Do this according to order + // (arbitrary). + unsigned threadsToAlloc = numThreadsPerWarp / maxThreadsInAxis; + for (unsigned dim : getThreadOrder(srcType)) { + if (dim == axis) + continue; + // The gather axis is now the fastest-changing dimension. + order.push_back(dim); + unsigned nextThreadAlloc = + std::min(srcType.getDimSize(dim), threadsToAlloc); + threadsPerWarp[dim] = nextThreadAlloc; + threadsToAlloc /= nextThreadAlloc; + } + assert(llvm::none_of(threadsPerWarp, [](unsigned c) { return c == 0; })); + + // There must be one warp along the gather axis. + warpsPerCTA[axis] = 1; + // Allocate the remaining warps in the same manner. + unsigned warpsToAlloc = numWarps; + for (unsigned dim : getWarpOrder(srcType)) { + if (dim == axis) + continue; + unsigned warpsCanFit = srcType.getDimSize(dim) / threadsPerWarp[dim]; + assert(warpsCanFit != 0); + unsigned nextWarpAlloc = std::min(warpsCanFit, warpsToAlloc); + warpsPerCTA[dim] = nextWarpAlloc; + warpsToAlloc /= nextWarpAlloc; + } + assert(llvm::none_of(warpsPerCTA, [](unsigned c) { return c == 0; })); + + // Just set `sizePerThread` to 1 along other dimensions and let broadcasting + // handling it. This also means we can use the same layout between the source + // and index tensors for simplicity. + SmallVector sizePerThread(rank, 1); + sizePerThread[axis] = srcType.getDimSize(axis) / threadsPerWarp[axis]; + + // Overflow by broadcasting along the gather axis since this is the most + // predictable. + threadsPerWarp[axis] *= threadsToAlloc; + warpsPerCTA[axis] *= warpsToAlloc; + + assert(product(threadsPerWarp) == numThreadsPerWarp); + assert(product(warpsPerCTA) == numWarps); + + // Construct the new layout. + MLIRContext *ctx = srcType.getContext(); + auto baseLayout = cast(srcType.getEncoding()); + auto cgaLayout = getCGALayout(baseLayout); + auto newLayout = BlockedEncodingAttr::get(ctx, sizePerThread, threadsPerWarp, + warpsPerCTA, order, cgaLayout); + + // Update the layout on the gather op and insert conversions. + auto cvtSrc = ConvertLayoutOp::create( + b, op.getLoc(), srcType.cloneWithEncoding(newLayout), op.getSrc()); + auto cvtIdx = ConvertLayoutOp::create( + b, op.getLoc(), idxType.cloneWithEncoding(newLayout), op.getIndices()); + + b.setInsertionPointAfter(op); + auto cvtOut = + ConvertLayoutOp::create(b, op.getLoc(), op.getType(), op.getResult()); + b.replaceAllUsesExcept(op.getResult(), cvtOut, cvtOut); + + b.modifyOpInPlace(op, [&] { + op.getSrcMutable().set(cvtSrc); + op.getIndicesMutable().set(cvtIdx); + op.getResult().setType(op.getType().cloneWithEncoding(newLayout)); + + // Mark the layout as optimized on the op to prevent it from being changed. + op.setEfficientLayout(true); + }); + + // Make sure we did this right. + assert(GatherLoweringHelper(op).isWarpLocal()); + + return success(); +} + +namespace { +struct OptimizeGatherLayoutPattern : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp op, + PatternRewriter &rewriter) const override { + if (op.getEfficientLayout()) + return failure(); + return setOptimizedGatherLayout(op, rewriter); + } +}; +} // namespace + +namespace { +class TritonGPUOptimizeThreadLocalityPass + : public impl::TritonGPUOptimizeThreadLocalityBase< + TritonGPUOptimizeThreadLocalityPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First try to optimize the layout of views and gathers. + mlir::RewritePatternSet layoutPatterns(&getContext()); + layoutPatterns.add(&getContext()); + layoutPatterns.add(&getContext()); + if (mlir::applyPatternsGreedily(mod, std::move(layoutPatterns)).failed()) { + signalPassFailure(); + } + + DenseSet reduceOps; + mod.walk([&](triton::ReduceOp reduce) -> void { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto reductionOp = getReductionOp(reduce); + if (!reductionOp || + !isa( + reductionOp.value())) + return; + // TODO: relax this restriction + if (!(isa(srcEncoding) && rank > 1)) + return; + // The code currently assumes that the reduction is happening on the most + // inner dim. + if (reduce.getAxis() != rank - 1) + return; + for (auto operand : reduce->getOperands()) { + if (!operand.getDefiningOp()) + return; + } + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + // Not worth applying this optimization if there is only one element per + // thread on the reduction axis + if (elemsPerThread == 1) + return; + if (!reduce->hasOneUse()) + return; + Operation *user = *(reduce->getUsers().begin()); + if (!user->hasOneUse()) + return; + OpOperand &yieldOpOperand = *(user->getUses().begin()); + auto yieldOp = dyn_cast(yieldOpOperand.getOwner()); + if (!yieldOp) + return; + auto operandNumber = yieldOpOperand.getOperandNumber(); + Block *block = reduce->getBlock(); + Operation *parentOp = block->getParentOp(); + auto forOp = dyn_cast(parentOp); + if (!forOp) + return; + auto argNum = yieldOpOperand.getOperandNumber(); + auto oldAccum = forOp.getInitArgs()[argNum]; + auto cstOp = oldAccum.getDefiningOp(); + if (!cstOp) + return; + reduceOps.insert(reduce); + }); + + IRRewriter builder(&getContext()); + for (auto reduce : reduceOps) { + builder.setInsertionPoint(reduce); + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + assert(isa(srcEncoding) && + "Thread locality optimization only supports blocked encoding"); + auto blocked = dyn_cast(srcEncoding); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto rank = srcShape.size(); + // create new layouts + auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); + auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce); + auto viewOpTensorType = RankedTensorType::get( + viewOpTensorShape, srcType.getElementType(), blocked3d); + auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank, + blocked3d); + // Get forOp + assert(reduce->hasOneUse()); + OpOperand &use = *(reduce->getUses().begin()); + auto operandNumber = use.getOperandNumber(); + auto oldUpdate = use.getOwner(); + assert(oldUpdate->getNumOperands() == 2); + auto accumOperandNumber = (operandNumber == 0) ? 1 : 0; + auto accumOperand = oldUpdate->getOperand(accumOperandNumber); + assert(isa(accumOperand)); + auto blockArg = dyn_cast(accumOperand); + auto blockArgNum = blockArg.getArgNumber(); + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + // get oldAccum + auto oldAccum = + forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()]; + // get old loop user + Value loopResult = + forOp.getResult(blockArgNum - forOp.getNumInductionVars()); + assert(loopResult.hasOneUse()); + OpOperand &loopUse = *(loopResult.getUses().begin()); + Operation *loopUser = loopUse.getOwner(); + // get old loop yield + auto oldYield = cast(forOp.getBody()->getTerminator()); + // create newAccum initialization + auto newAccum = + createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d); + // create new loop by copying the old for op signature and appending + // newAccum to the block arguments + auto newLoop = replaceForOpWithNewSignature( + builder, forOp, ValueRange{newAccum->getResult(0)}); + // create thread local reduction (also adds viewOps) + auto newReduce = createReduce(builder, reduce, viewOpTensorType); + + // create new accum update + auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); + // create new yield + auto newYield = createYield(builder, newLoop, oldYield, + newUpdate->getResult(0), blockArgNum); + // create post loop reduction on the original reduce axis + auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce); + // add convert_layout to get back to original layout, the result layout + // should now match the layout of the old accumulator (%cst) + Type destType = loopResult.getType(); + auto cvtLayout = createConvertLayout(builder, destType, newReduce2); + // incorporate the original accumulator value into the final result + auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate, + cvtLayout, oldAccum); + // Replace the old loop user with the final result + loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0)); + + // cleanup + oldYield.erase(); + forOp.erase(); + } + }; + +private: + std::optional getReductionOp(triton::ReduceOp reduce) const { + auto numRegions = reduce->getNumRegions(); + if (numRegions != 1) + return std::nullopt; + Region ®ion = reduce->getRegion(0); + auto numBlocks = region.getBlocks().size(); + if (numBlocks != 1) + return std::nullopt; + Block &block = region.front(); + auto blockWithoutTerminator = block.without_terminator(); + auto blockSizeWithoutTerminator = std::distance( + blockWithoutTerminator.begin(), blockWithoutTerminator.end()); + if (blockSizeWithoutTerminator != 1) + return std::nullopt; + Operation *op = &block.front(); + return std::optional(op); + } + Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder, + Operation *oldUpdate, + Operation *cvtLayout, + Value oldAccum) const { + builder.setInsertionPointAfter(cvtLayout); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), oldAccum); + mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0)); + auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping); + return finalOp; + } + Operation *createConvertLayout(OpBuilder &builder, Type destType, + Operation *newReduce) const { + builder.setInsertionPointAfter(newReduce); + auto newCvt = triton::gpu::ConvertLayoutOp::create( + builder, newReduce->getLoc(), destType, newReduce->getResult(0)); + return newCvt; + } + + Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop, + triton::ReduceOp &reduce) const { + auto resultIndex = + loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars(); + auto newLoopResult = loop.getResult(resultIndex); + builder.setInsertionPointAfter(loop); + IRMapping mapping; + mapping.map(*(reduce.getOperands().begin()), newLoopResult); + auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping); + return newReduce2; + } + + Operation *createYield(OpBuilder &builder, scf::ForOp &loop, + scf::YieldOp &oldYield, Value newUpdate, + int oldAccumBlockArgNum) const { + builder.setInsertionPoint(oldYield); + SmallVector yieldValues = llvm::to_vector(oldYield.getOperands()); + yieldValues[oldAccumBlockArgNum - 1] = + loop.getBody()->getArgument(oldAccumBlockArgNum); + yieldValues.push_back(newUpdate); + auto newYield = + scf::YieldOp::create(builder, oldYield.getLoc(), yieldValues); + return newYield; + } + + Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop, + Operation *newReduce, Operation *oldUpdate) const { + auto blockArgNum = loop.getBody()->getNumArguments() - 1; + auto newArg = loop.getBody()->getArgument(blockArgNum); + builder.setInsertionPointAfter(newReduce); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), newArg); + mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0)); + auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping); + return newUpdate; + } + + Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce, + Type viewOpTensorType) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + builder.setInsertionPointAfter(reduce); + IRMapping mapping; + for (auto operand : reduce.getOperands()) { + auto viewOp = triton::ReshapeOp::create( + builder, reduce.getLoc(), viewOpTensorType, operand, + /*allowReorder=*/true, /*efficientLayout=*/true); + mapping.map(operand, viewOp); + } + + auto newReduce = cloneWithInferType(builder, &(*reduce), mapping); + newReduce->setAttr("axis", builder.getI32IntegerAttr(rank)); + auto typeInfer = dyn_cast(newReduce); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newReduce->getContext(), newReduce->getLoc(), + newReduce->getOperands(), newReduce->getAttrDictionary(), + newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newReduce->getResult(i).setType(newTypes[i]); + } + } + return newReduce; + } + + // Work around the lack of support for MaxNumFOp and MinNumFOp in + // arith::getNeutralElement. + std::optional getNeutralElement(Operation *op) const { + if (isa(op)) { + OpBuilder builder(op->getContext()); + + Type resultType = op->getResult(0).getType(); + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/true)); + } + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/false)); + } + } else { + return mlir::arith::getNeutralElement(op); + } + llvm_unreachable("Unhandled reduction op"); + return std::nullopt; + } + + Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce, + Value &oldAccum, SmallVector &shape, + Attribute &slice2d) const { + // Drop the last dimension (thread locality dimension) + SmallVector accumShape(shape.begin(), shape.end() - 1); + auto elemType = cast(oldAccum.getType()).getElementType(); + // Create tensor type for the new accumulator + auto accumType = RankedTensorType::get(accumShape, elemType, slice2d); + // Create new accumulator + builder.setInsertionPointAfter(oldAccum.getDefiningOp()); + auto reductionOp = getReductionOp(reduce); + assert(reductionOp && "Processing a reduce that is not supported!"); + auto neutralVal = getNeutralElement(reductionOp.value()); + assert(neutralVal && "Could not find neutral value for reduction op!"); + auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value()); + auto newAccum = arith::ConstantOp::create(builder, oldAccum.getLoc(), + accumType, denseAttr); + return newAccum; + } + + SmallVector + getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto rank = srcShape.size(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto viewOpTensorShape = insertValue(srcShape, rank, 1); + viewOpTensorShape[reduce.getAxis()] /= elemsPerThread; + viewOpTensorShape[rank] = elemsPerThread; + return viewOpTensorShape; + } + + BlockedEncodingAttr + getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto blocked = dyn_cast(srcEncoding); + auto sizePerThread3d = + insertValue(blocked.getSizePerThread(), rank, + blocked.getSizePerThread()[reduce.getAxis()]); + sizePerThread3d[reduce.getAxis()] = 1; + auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); + auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); + auto order3d = insertValue(blocked.getOrder(), 0, rank); + auto ctaLl = blocked.getCGALayout().getLinearLayout(); + auto kBlocked = *ctaLl.getInDimNames().begin(); + auto *ctx = kBlocked.getContext(); + auto dim = standardOutDimNames(ctx, rank + 1)[rank]; + ctaLl *= LinearLayout::identity1D(1, kBlocked, dim); + auto ctaLayout3d = CGAEncodingAttr::get(ctx, std::move(ctaLl)); + auto blocked3d = triton::gpu::BlockedEncodingAttr::get( + reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, + order3d, ctaLayout3d); + return blocked3d; + } + + template + SmallVector insertValue(ArrayRef vec, unsigned index, int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } + template + SmallVector insertValue(const SmallVector &vec, unsigned index, + int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } +}; +} // namespace + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp new file mode 100644 index 0000000000..16f61e1ca4 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -0,0 +1,395 @@ +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir::triton::gpu { +namespace { + +//===----------------------------------------------------------------------===// +// assignLatencies +//===----------------------------------------------------------------------===// + +inline constexpr llvm::StringLiteral kDisableGenericDotPipelineAttr = + "tt.disable_generic_dot_pipeline"; + +// Return true if the preconditions for pipelining the loop are met. +bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (loopHasDistGreaterThanOne(forOp)) + return false; + // Don't pipeline outer loops. + if (isOuterLoop(forOp)) + return false; + return true; +} + +bool hasLatenciesAssigned(scf::ForOp forOp) { + auto helper = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (helper.getAttr(&op)) + return true; + } + return false; +} + +void assignUserProvidedLatencies(scf::ForOp forOp, + DenseMap &opLatency) { + auto helper = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto latencyAttr = helper.getAttr(&op)) { + opLatency[&op] = latencyAttr.getInt(); + } + } +} + +static bool supportsGenericDotLoadPipelining(Operation *op) { + if (auto disableAttr = + op->getAttrOfType(kDisableGenericDotPipelineAttr); + disableAttr && disableAttr.getValue()) + return false; + auto dotOp = dyn_cast(op); + if (!dotOp) + return false; + return isa(dotOp.getA().getType()) && + isa(dotOp.getB().getType()); +} + +class AssignLoadLatencies { +public: + AssignLoadLatencies(scf::ForOp forOp, int numStages, + DenseMap &opLatency) + : forOp(forOp), numStages(numStages), opLatency(opLatency) {}; + + void run() { + bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + llvm::MapVector> loadOpToIndLevel = + loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis, + numStages); + if (loadOpToIndLevel.empty()) + return; + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = 0; + for (auto &[loadOp, info] : loadOpToIndLevel) + maxIndirectionLevel = std::max(maxIndirectionLevel, info.first); + unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1); + + for (auto [loadOp, dist] : loadOpToIndLevel) { + opLatency[loadOp] = loadLatency; + } + } + +private: + scf::ForOp forOp; + int numStages; + DenseMap &opLatency; + +public: + static bool canHaveSharedEncoding(tt::LoadOp op) { + // If used by an user with DotOp encoding, all the uses must be compatible. + bool incompatible = false; + getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible); + return !incompatible; + } + + static bool + isPipeliningBeneficial(Operation *op, Operation *finalUser, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool filterSmall) { + if (auto loadOp = dyn_cast(op)) { + if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) { + LDBG("Load " << *loadOp << " is too small for pipelining"); + return false; + } + } + if (isa(op)) + return true; + if (!canHaveSharedEncoding(cast(op))) { + LDBG("Load " << *op << " cannot have shared encoding"); + return false; + } + + ttg::SharedEncodingTrait localAllocEnc; + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) { + // If the load is used by a LocalAllocOp, all the users need to have + // the same encoding. + return false; + } + } + } + + if (localAllocEnc) { + auto registerTy = cast(op->getResultTypes()[0]); + auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc); + if (filterSmall && vecBytes < 4) { + // At least 4 bytes need to be consecutive for cp.async + return false; + } + } + + return true; + } +}; + +class AssignMMALatencies { +public: + AssignMMALatencies(scf::ForOp forOp, DenseMap &opLatency) + : forOp(forOp), opLatency(opLatency) {}; + + void run() { + DenseMap mmaSelfLatency; + // Check if the load op (mma operand) is pipelineable. + auto isLoadToBePipelined = [&](Operation *op) { + return opLatency.count(op) && opLatency[op] > 0; + }; + for (auto &op : forOp.getBody()->without_terminator()) { + // If the acc can not be multibuffered, do not pipeline the uses of + // the MMA to later stages. + if (auto mma = dyn_cast(&op)) { + // Try to push out the wait by one stage even if the operands are not + // pipelineable, but we know where the loads are scheduled, so we can + // place the wait right before the loads. + + if (hasSyncDots(forOp)) { + // Skip pipelining MMA in the loops where sync dots are used. This + // is a dirty heuristic for performance drops in kernels where we + // would rather want to have last iteration peeled instead of having a + // full iteration of masked operations only to execute single wait. + continue; + } + auto pipeHelper = ttng::MMAv5PipelineableOperandsHelper( + mma, forOp, isLoadToBePipelined); + if (pipeHelper.isPipelineable || + (pipeHelper.isOperandsStateDetermined && + !ttng::hasLoadsAfterMMA(mma, forOp))) { + // MMA can be overlapped with itself + mmaSelfLatency[mma] = 1; + if (!ttng::requiresAccMultiBuffering(mma, forOp) || + (ttng::isAccMultibufferingPossible(mma, forOp) && + !getDisallowAccMultiBuffer(forOp))) { + // MMA's users can be pushed to the next stage + opLatency[&op] = 1; + } + // HACK: A pipelined MMA's latency should equal the number of buffers + // for the accumulator, but when the user is in an `scf.if` in SWP, + // the `scf.if` is pushed to the end of the loop rather than peeled + // before the MMA op, requiring an extra buffer due to liverange + // overlap. WS does not have this problem because the MMA is placed in + // a different partition than the MMA, so we can correctly set the + // latency. + if (isWarpSpecialized(forOp)) { + if (ttng::hasAccReadModifyWrite(mma, forOp)) + opLatency.erase(&op); // can't pipeline the MMA + else + opLatency[&op] += 1; + // If all inputs to the MMA are warp specialized, set the self + // latency to 0 since the MMA won't need to wait on itself. + auto cantWarpSpec = [](Operation *op) { return isa(op); }; + auto warpSpecHelper = ttng::MMAv5PipelineableOperandsHelper( + mma, forOp, [&](Operation *op) { + return isLoadToBePipelined(op) && !cantWarpSpec(op); + }); + if (warpSpecHelper.isPipelineable || + (warpSpecHelper.isOperandsStateDetermined && + llvm::none_of(warpSpecHelper.unpipelineableOperandDefs, + cantWarpSpec))) + mmaSelfLatency[mma] = 0; + } + } + } + } + serializeSelfLatencies(forOp->getParentOfType(), mmaSelfLatency); + } + +private: + scf::ForOp forOp; + DenseMap &opLatency; + + bool hasSyncDots(scf::ForOp forOp) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + return true; + } + return false; + } + + bool isWarpSpecialized(scf::ForOp forOp) { + scf::ForOp current = forOp; + do { + if (current->hasAttr(kWarpSpecializeAttrName)) { + return true; + } + current = current->getParentOfType(); + } while (current); + return false; + }; +}; + +// Discover operations that should become async and assign latencies to them +// based on the numStages value provided by the user. +// +// Look for load ops that directly or indirectly feed into dot ops. Based on the +// requested number of stages assign the latencies in a way that cover all the +// stages with the sum of latencies in the chain from the first load to the +// final dot op. +void assignLatencies(ModuleOp moduleOp, int defaultNumStages) { + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (preCondition(forOp) && + getNumStagesOrDefault(forOp, defaultNumStages) > 1) + loops.push_back(forOp); + }); + if (loops.empty()) + return; + + DenseMap opLatency; + for (auto forOp : loops) { + if (hasLatenciesAssigned(forOp)) { + assignUserProvidedLatencies(forOp, opLatency); + continue; + } + int numStages = getNumStagesOrDefault(forOp, defaultNumStages); + AssignLoadLatencies(forOp, numStages, opLatency).run(); + AssignMMALatencies(forOp, opLatency).run(); + } + serializeLatencies(moduleOp, opLatency); +} + +} // namespace + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +llvm::MapVector> +loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis, + int numStages, bool filterSmall) { + llvm::MapVector> loadOpToIndLevel; + DenseSet seen; + DenseSet excluded; + + std::function dfs = + [&](Operation *op, Operation *finalUser, int distance) { + if (!seen.insert(op).second || excluded.count(op)) + return; + if (isa(op)) { + if (!AssignLoadLatencies::isPipeliningBeneficial( + op, finalUser, axisInfoAnalysis, filterSmall)) + return; + if (loadOpToIndLevel.count(op)) { + int level = loadOpToIndLevel[op].first; + if (level != distance) { + // If we have multiple uses at different distances, we don't + // know which one to pick. + LDBG("Load " << *op + << " has multiple uses at different distances:" + << level << " and " << distance); + loadOpToIndLevel.erase(op); + excluded.insert(op); + return; + } + } else { + LDBG("Load " << *op << " considered for pipelining with distance " + << distance); + loadOpToIndLevel[op] = {distance, finalUser}; + } + finalUser = op; + distance++; + } + for (Value operand : getNestedOperands(op)) { + if (supportsGenericDotLoadPipelining(op)) { + // Heuristic: only pipeline A and B operands of the dot op. + if (operand == op->getOperand(2)) + continue; + } + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, finalUser, distance); + } + } + }; + + bool seenDot = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + // Arbitrary heuristic. TMEMStoreOp is included to keep logic consistent + // with legacy code when we weren't hoisting tmem allocas. + if (!supportsGenericDotLoadPipelining(&op) && !isa(op)) + continue; + seenDot = true; + seen.clear(); + dfs(&op, &op, 0); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (pipelineWithoutDot) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, &op, 0); + } + } + + // We assume loads with different dist are assigned to different stages. + // If numStages is 2, we will have no stage available for indirect loads + // with dist >= 1. In general, when dist is equal to numStages - 1, we + // should not pipeline it. + for (auto iter = loadOpToIndLevel.begin(); iter != loadOpToIndLevel.end();) { + if (iter->second.first >= numStages - 1) + iter = loadOpToIndLevel.erase(iter); + else + ++iter; + } + + return loadOpToIndLevel; +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_TRITONGPUASSIGNLATENCIES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct AssignLatencies + : public impl::TritonGPUAssignLatenciesBase { + using TritonGPUAssignLatenciesBase::TritonGPUAssignLatenciesBase; + + void runOnOperation() override { assignLatencies(getOperation(), numStages); } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp new file mode 100644 index 0000000000..1b11165097 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -0,0 +1,1184 @@ +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Dominance.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +///////////////////////////// +// UTILS +///////////////////////////// + +int getSelfLatencyFromAttr(Operation *op) { + auto module = op->getParentOfType(); + auto helper = TritonDialect::getLoaded(module)->getSelfLatencyAttrHelper(); + if (!helper.isAttrPresent(op)) + return 0; + int val = helper.getAttr(op).getInt(); + helper.removeAttr(op); + return val; +} + +// Check if the load can be pipelined entirely in shared memory, +// or if we need to load to registers. +bool mustLoadToRegisters(Operation *op) { + if (auto loadOp = dyn_cast(op)) { + // AsyncCopyGlobalToLocalOp does not support the non-zero "other" value. + // With consumer consuming directly the shared memory, there would be no way + // to replace masked values with the "other" value. + if (loadOp.getOther() && !isZeroConst(loadOp.getOther())) + return true; + } + + if (!op->hasOneUse()) + return true; + Operation *user = *op->getUsers().begin(); + auto alloc = dyn_cast(user); + if (!alloc) + return true; + + Attribute loadEncoding; + if (auto descLoad = dyn_cast(op)) { + loadEncoding = nvidia_gpu::getEncodingFromDescriptor(op, descLoad.getType(), + descLoad.getDesc()); + } else if (auto descGather = dyn_cast(op)) { + loadEncoding = nvidia_gpu::getEncodingFromDescriptor( + op, descGather.getType(), descGather.getDesc()); + } + return loadEncoding && (loadEncoding != alloc.getType().getEncoding()); +} + +static ttg::SharedEncodingTrait +getMusaSqmmaPipelinedSharedEncoding(RankedTensorType tensorTy) { + auto cgaLayout = ttg::getCGALayout(tensorTy.getEncoding()); + auto order = ttg::getOrder(tensorTy); + auto fallback = [&]() -> ttg::SharedEncodingTrait { + return ttg::SwizzledSharedEncodingAttr::get(tensorTy.getContext(), 1, 1, 1, + order, cgaLayout); + }; + if (order.empty()) + return fallback(); + + int elemBitWidth = tensorTy.getElementTypeBitWidth(); + if (elemBitWidth <= 0) + return fallback(); + + auto shapePerCTA = + ttg::getShapePerCTA(cgaLayout.getCTASplitNum(), tensorTy.getShape()); + if (order[0] >= shapePerCTA.size()) + return fallback(); + + int64_t contigDimSizeInBytes = shapePerCTA[order[0]] * elemBitWidth / 8; + unsigned perPhase = 1; + unsigned maxPhase = 1; + if (contigDimSizeInBytes >= 128 && contigDimSizeInBytes % 128 == 0) { + perPhase = 1; + maxPhase = 8; + } else if (contigDimSizeInBytes >= 64 && contigDimSizeInBytes % 64 == 0) { + perPhase = 2; + maxPhase = 4; + } else if (contigDimSizeInBytes >= 32 && contigDimSizeInBytes % 32 == 0) { + perPhase = 4; + maxPhase = 2; + } else if (contigDimSizeInBytes >= 16 && contigDimSizeInBytes % 16 == 0) { + perPhase = 8; + maxPhase = 1; + } + + unsigned vec = std::max(1u, 128u / static_cast(elemBitWidth)); + return ttg::SwizzledSharedEncodingAttr::get( + tensorTy.getContext(), vec, perPhase, maxPhase, order, cgaLayout); +} + +int getDefUseStageDiff(Operation *op, scf::ForOp forOp, + CoarseSchedule &schedule) { + assert(schedule.count(op) && "Op not found in the schedule"); + int defStage = schedule[op].first; + CoarseSchedule::Cluster defCluster = schedule[op].second; + std::optional useStage; + DenseSet topLevelUsers = + triton::getTopLevelUsersInLoop(op, forOp); + // Special case for loads used by local_alloc: + // we must consider the uses of the local_alloc, as it may be removed and its + // uses will become direct uses of the async load. + // TODO: This is overly conservative, we may need to restrict to cases where + // local_alloc is used by a dot product and has correct encoding. + if (isa(op)) { + DenseSet allocUsers; + for (Operation *topLevelUser : topLevelUsers) { + if (auto localAlloc = dyn_cast(topLevelUser)) { + DenseSet users = + triton::getTopLevelUsersInLoop(localAlloc, forOp); + allocUsers.insert(users.begin(), users.end()); + } + } + topLevelUsers.insert(allocUsers.begin(), allocUsers.end()); + } + DenseSet topLevelWaitUsers; + for (Operation *topLevelUser : topLevelUsers) { + if (isa(topLevelUser)) { + topLevelWaitUsers.insert(topLevelUser); + } + } + for (Operation *topLevelUser : topLevelUsers) { + int _useStage = schedule[topLevelUser].first; + CoarseSchedule::Cluster _useCluster = schedule[topLevelUser].second; + if (*_useCluster > *defCluster) { + // Check if we need extra buffer due to unusual execution order + // The issue occurs when users of the load are scheduled in a later + // cluster, which happens when conditional code gets moved to epilogue + // cluster. This creates a race condition where the local load happens + // after the global-to-local copy for the next pipeline stage starts. + _useStage++; + } + useStage = std::min(_useStage, useStage.value_or(_useStage)); + } + // Waits tells us the buffer is still in use until the wait completes, we + // can't simply load from the buffer and replace the uses of the buffer with + // the load. The stage diff needs to account for the furthest wait. + for (Operation *topLevelUser : topLevelWaitUsers) { + int _useStage = schedule[topLevelUser].first; + useStage = std::max(_useStage, useStage.value_or(_useStage)); + } + if (!useStage) + return 0; + assert(useStage >= defStage && "Op used before defined"); + return useStage.value() - defStage; +} + +void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, Value oldValue, + DominanceInfo &domInfo) { + if (newValue == oldValue) + return; + oldValue.replaceUsesWithIf(newValue, [&](OpOperand &use) { + return domInfo.properlyDominates(domOp, use.getOwner()); + }); +} + +///////////////////////////// +// LOWER LOADS +///////////////////////////// + +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingTrait sharedEnc, + unsigned distance) { + return triton::createAlloc( + forOp, cast(loadOp->getResultTypes().front()), + loadOp->getLoc(), sharedEnc, distance); +} + +void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, int contiguity, + CoarseSchedule &schedule) { + OpBuilderForStage builder(loadOp.getLoc(), forOp, schedule); + Value zero = arith::ConstantIntOp::create(builder, forOp.getLoc(), 0, 32); + + Operation *firstUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule); + assert(firstUse && "LoadOp has no users"); + // Replace the load with async copy, wait and loal_load. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(loadOp); + builder.setStageCluster(schedule[loadOp]); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + ttg::MemDescType allocTy = cast(alloc.getType()); + + // Create async copy + Value view = createSingleBufferView(builder, alloc, insertIdx); + Operation *copy = ttg::AsyncCopyGlobalToLocalOp::create( + builder, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile(), contiguity); + Operation *commit = + ttg::AsyncCommitGroupOp::create(builder, copy->getResult(0)); + + // Create wait and local load + builder.setStageCluster(schedule[firstUse]); + auto wait = ttg::AsyncWaitOp::create(builder, commit->getResult(0), 0); + auto viewLoad = createSingleBufferView(builder, alloc, extractIdx); + + if (!loadOp.getOther() || isZeroConst(loadOp.getOther())) { + // If masking isn't required, load directly from shared + replaceUsesWithLocalLoad(builder, loadOp->getResult(0), viewLoad, + wait.getResult()); + } else if (loadOp->use_begin() != loadOp->use_end()) { + // Otherwise, create a select for non-zero other values as they are not + // handled by AsyncCopyGlobalToLocalOp for now. + auto sharedLoad = ttg::LocalLoadOp::create(builder, loadOp.getType(), + viewLoad, wait.getResult()); + auto select = arith::SelectOp::create( + builder, loadOp.getType(), + // Use the mask operand from the original load, not the one with a + // potentially transformed layout. + loadOp.getMask(), sharedLoad.getResult(), other); + loadOp->replaceAllUsesWith(select->getResults()); + } + schedule.erase(loadOp); + loadOp->erase(); +} + +void createTMAAsyncCopy( + scf::ForOp forOp, Operation *loadOp, Value desc, Value alloc, + Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, + CoarseSchedule &schedule, + function_ref + createCopy) { + OpBuilderForStage builder(loadOp->getLoc(), forOp, schedule); + Value zero = arith::ConstantIntOp::create(builder, forOp.getLoc(), 0, 32); + + Operation *firstUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule); + assert(firstUse && "LoadOp has no users"); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + + builder.setInsertionPoint(loadOp); + builder.setStageCluster(schedule[loadOp]); + ttg::MemDescType allocTy = cast(alloc.getType()); + + // Create async copy + Value view = createSingleBufferView(builder, alloc, insertIdx); + + Value pred = arith::ConstantIntOp::create(builder, 1, 1); + createCopy(builder, desc, barrier, view, pred); + + // Create local load after the wait + builder.setInsertionPointAfter(waitOp); + builder.setStageCluster(schedule[firstUse]); + auto viewLoad = createSingleBufferView(builder, alloc, extractIdx); + replaceUsesWithLocalLoad(builder, loadOp->getResult(0), viewLoad); + + schedule.erase(loadOp); + loadOp->erase(); +} + +void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp, + Value alloc, Value insertIdx, Value extractIdx, + Value barrier, Operation *waitOp, + CoarseSchedule &schedule) { + return createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, + extractIdx, barrier, waitOp, schedule, + [&](OpBuilderForStage &builder, Value desc, + Value barrier, Value view, Value pred) { + ttng::AsyncTMACopyGlobalToLocalOp::create( + builder, loadOp.getLoc(), desc, + loadOp.getIndices(), barrier, view, pred); + }); +} + +void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp, + Value alloc, Value insertIdx, Value extractIdx, + Value barrier, Operation *waitOp, + CoarseSchedule &schedule) { + return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc, + insertIdx, extractIdx, barrier, waitOp, schedule, + [&](OpBuilderForStage &builder, Value desc, + Value barrier, Value view, Value pred) { + ttng::AsyncTMAGatherOp::create( + builder, gatherOp.getLoc(), desc, + gatherOp.getXOffsets(), gatherOp.getYOffset(), + barrier, view, pred); + }); +} + +struct AsyncLoad { + int stageDiff; + int contiguity = 1; + Value alloc; + Value barrier; + Operation *waitOp; + SharedEncodingTrait sharedEncoding; +}; +struct LoadGroupInfo { + Value insertIdx; + Value extractIdx; + Value phase; + bool hasTMALoad = false; +}; + +// Convert a scalar load to a load of a tensor of shape <1>. +void convertScalarToTensorLoad(Operation *op, CoarseSchedule &schedule, + scf::ForOp forOp) { + auto scalarLoad = cast(op); + Type scalarTy = scalarLoad.getType(); + OpBuilderForStage builder(op->getLoc(), op, schedule); + builder.setInsertionPoint(op); + MLIRContext *ctx = op->getContext(); + auto nWarps = lookupNumWarps(op); + ModuleOp mod = forOp->getParentOfType(); + auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto numCTAs = TritonGPUDialect::getNumCTAs(mod); + auto blockedEnc = + getDefaultBlockedEncoding(ctx, {1}, nWarps, threadsPerWarp, numCTAs); + auto newPtrTy = + RankedTensorType::get({1}, scalarLoad.getPtr().getType(), blockedEnc); + auto newPtr = + tt::SplatOp::create(builder, op->getLoc(), newPtrTy, scalarLoad.getPtr()); + scalarLoad.getPtrMutable().assign(newPtr); + if (scalarLoad.getMask()) { + auto newMaskTy = + RankedTensorType::get({1}, scalarLoad.getMask().getType(), blockedEnc); + auto newMask = tt::SplatOp::create(builder, op->getLoc(), newMaskTy, + scalarLoad.getMask()); + scalarLoad.getMaskMutable().assign(newMask); + } + if (scalarLoad.getOther()) { + auto newOtherTy = + RankedTensorType::get({1}, scalarLoad.getOther().getType(), blockedEnc); + auto newOther = tt::SplatOp::create(builder, op->getLoc(), newOtherTy, + scalarLoad.getOther()); + scalarLoad.getOtherMutable().assign(newOther); + } + auto newDstTy = RankedTensorType::get({1}, scalarLoad.getType(), blockedEnc); + scalarLoad.getResult().setType(newDstTy); + builder.setInsertionPointAfter(op); + Operation *firstUse = getFirstUseOfPipelinedOp({op}, forOp, schedule); + builder.setStageCluster(schedule[firstUse]); + Operation *unsplat = tt::UnsplatOp::create(builder, op->getLoc(), scalarTy, + scalarLoad.getResult()); + scalarLoad.getResult().replaceAllUsesExcept(unsplat->getResult(0), unsplat); +} + +void createTMABarrierAndWait( + scf::ForOp forOp, llvm::MapVector &asyncLoads, + const llvm::MapVector &loadGroups, + CoarseSchedule &schedule) { + SmallVector> commonWaitGroups; + llvm::SmallDenseSet visited; + // Find groups of loads that can share the same barrier. We look consecutive + // loads and check that there are uses in between. + for (auto &[loadOp, asyncLoad] : asyncLoads) { + if (!isTMALoad(loadOp) || visited.count(loadOp)) + continue; + llvm::SmallDenseSet users; + SmallVector group; + Block *loadBlock = loadOp->getBlock(); + auto addToGroup = [&](Operation *loadOp) { + group.push_back(loadOp); + visited.insert(loadOp); + for (Operation *user : loadOp->getUsers()) { + // Special case for MMAv3 loads, we can ignore the alloc and only + // consider uses of the alloc op since it will be removed. + if (!mustLoadToRegisters(loadOp)) { + assert(loadOp->hasOneUse()); + auto alloc = cast(*loadOp->getUsers().begin()); + if (alloc->getBlock() == loadBlock) { + users.insert(alloc->getUsers().begin(), alloc->getUsers().end()); + continue; + } + } + Operation *userInBlock = loadBlock->findAncestorOpInBlock(*user); + if (userInBlock) + users.insert(userInBlock); + } + }; + addToGroup(loadOp); + Operation *nextOp = loadOp->getNextNode(); + int numBuffers = asyncLoad.stageDiff; + while (nextOp) { + if (users.count(nextOp) || visited.count(nextOp)) + break; + if (isTMALoad(nextOp) && asyncLoads.count(nextOp)) { + if (asyncLoads[nextOp].stageDiff != numBuffers) + break; + if (group.size() > 0 && schedule[group[0]] == schedule[nextOp]) { + addToGroup(nextOp); + } + } + nextOp = nextOp->getNextNode(); + } + commonWaitGroups.push_back(group); + } + + // For each group calculate the size and insert the barrier after the last + // load. + for (SmallVector &group : commonWaitGroups) { + int sizeInBytes = 0; + int numBuffers = asyncLoads[group[0]].stageDiff; + const LoadGroupInfo loadGroup = loadGroups.find(numBuffers)->second; + for (Operation *op : group) { + auto tensorTy = cast(op->getResultTypes()[0]); + int loadSize = product(getShapePerCTA(tensorTy)); + sizeInBytes += loadSize * tensorTy.getElementTypeBitWidth() / 8; + } + + Value barrierAlloc = triton::createBarrierAlloc(forOp, numBuffers); + OpBuilderForStage builder(forOp.getLoc(), group[0], schedule); + Value barrier = triton::createSingleBufferView(builder, barrierAlloc, + loadGroup.insertIdx); + Value pred = arith::ConstantIntOp::create(builder, 1, 1); + ttng::BarrierExpectOp::create(builder, barrier, sizeInBytes, pred); + + builder.setInsertionPointAfter(group.back()); + Operation *firstUse = getFirstUseOfPipelinedOp(group, forOp, schedule); + builder.setStageCluster(schedule[firstUse]); + Value barrierViewWait = triton::createSingleBufferView( + builder, barrierAlloc, loadGroup.extractIdx); + auto wait = + ttng::WaitBarrierOp::create(builder, barrierViewWait, loadGroup.phase); + + // Update the async loads info. + for (Operation *op : group) { + asyncLoads[op].barrier = barrier; + asyncLoads[op].waitOp = wait; + } + } +} + +// Check if load requires additional buffer for a mma pipelining +bool loadRequiresAdditionalBuffer(Operation *loadOp) { + auto isMusaTarget = [&]() { + auto module = loadOp->getParentOfType(); + if (!module) + return false; + auto targetAttr = module->getAttrOfType(ttg::AttrTargetName); + return targetAttr && targetAttr.getValue().starts_with("musa:"); + }; + + std::function & out)> + collectNonViewUsers = [&](Operation *op, SmallVector &out) { + for (Operation *user : op->getUsers()) { + if (user->hasTrait()) + collectNonViewUsers(user, out); + else + out.push_back(user); + } + }; + std::function &)> hasDotConsumer = + [&](Operation *op, DenseSet &visited) -> bool { + if (!visited.insert(op).second) + return false; + if (isa(op)) + return true; + for (Operation *user : op->getUsers()) { + if (hasDotConsumer(user, visited)) + return true; + } + return false; + }; + // Pattern match the op sequence used for loading mmav3 operands + if (!mustLoadToRegisters(loadOp)) { + assert(loadOp->hasOneUse()); + ttg::LocalAllocOp alloc = + dyn_cast(*loadOp->getUsers().begin()); + if (alloc) { + SmallVector nonViewUsers; + collectNonViewUsers(alloc, nonViewUsers); + return llvm::any_of(nonViewUsers, [&](Operation *op) { + if (isa(op)) + return true; + if (isMusaTarget()) { + DenseSet visited; + if (hasDotConsumer(op, visited)) + return true; + } + return false; + }); + } + } + return false; +} + +scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + auto module = forOp->getParentOfType(); + auto targetAttr = module + ? module->getAttrOfType(ttg::AttrTargetName) + : StringAttr(); + bool isMusaTarget = targetAttr && targetAttr.getValue().starts_with("musa:"); + auto hasDotConsumer = [&](Operation *loadOp) { + DenseSet visited; + std::function dfs = [&](Operation *op) -> bool { + if (!visited.insert(op).second) + return false; + if (isa(op)) + return true; + for (Operation *user : op->getUsers()) { + if (dfs(user)) + return true; + } + return false; + }; + for (Value result : loadOp->getResults()) { + for (Operation *user : result.getUsers()) { + if (dfs(user)) + return true; + } + } + return false; + }; + + llvm::MapVector asyncLoads; + llvm::MapVector loadGroups; + llvm::SmallVector scalarLoads; + // Only visit the top level ops, we do not support pipelining conditional + // loads for now + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op)) { + int stageDiff = getDefUseStageDiff(&op, forOp, schedule); + if (stageDiff == 0) { + // Don't care about non-pipelined loads. Scalar loads will be converted + // to tensor loads if they are pipelined. + continue; + } + SharedEncodingTrait sharedEncoding; + bool canUseAsyncCp = false; + int contiguity = 1; + if (!isa(op.getResultTypes()[0])) { + canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32; + auto numCTAs = lookupNumCTAs(forOp); + sharedEncoding = ttg::SwizzledSharedEncodingAttr::get( + forOp.getContext(), 1, 1, 1, {0}, + ttg::CGAEncodingAttr::get1DLayout(forOp.getContext(), numCTAs)); + if (canUseAsyncCp) { + scalarLoads.push_back(&op); + } + } else { + sharedEncoding = getSharedEncoding(&op); + // Do not create async loads for small loads (cp.async requires at least + // 4 bytes) + canUseAsyncCp = + isa(op) && + canBeConvertedToAsyncLoad(cast(op), axisInfoAnalysis); + int copyVecBytes = getCopyVecBytes( + cast(op.getResultTypes()[0]), sharedEncoding); + + canUseAsyncCp &= copyVecBytes >= 4; + if (canUseAsyncCp) { + auto loadOp = cast(op); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, + axisInfoAnalysis.getMaskAlignment(mask)); + contiguity = vec; + } + } + if (canUseAsyncCp || isTMALoad(&op)) { + if (loadRequiresAdditionalBuffer(&op)) { + // Allocate additional buffer required by the wgmma pipelining. + stageDiff += 1; + } + if (isMusaTarget && isa(op) && hasDotConsumer(&op)) { + // MUSA SQMMA async-copy path requires one extra ring buffer to match + // the expected software-pipeline staging depth. + stageDiff += 1; + sharedEncoding = getMusaSqmmaPipelinedSharedEncoding( + cast(op.getResultTypes()[0])); + } + auto &asyncLoad = asyncLoads[&op]; + asyncLoad.stageDiff = stageDiff; + asyncLoad.contiguity = contiguity; + asyncLoad.sharedEncoding = sharedEncoding; + } else if (stageDiff > 1) { + // Distance-1 loads can in most cases be pipelined in registers without + // any performance degradation, as the schedule will usually reorder the + // user and the producer so there is no liverange overlap, and no copy + // needed. + op.emitRemark() << "Pipelining load that cannot use vectorized " + "copy. This will likely " + "lead to pipelining in registers and severe " + "performance degradation."; + } + } + } + + // Convert scalar loads to be able to use async copy. + for (auto op : scalarLoads) { + convertScalarToTensorLoad(op, schedule, forOp); + } + + if (asyncLoads.empty()) + return forOp; + + for (auto &[loadOp, asyncLoad] : asyncLoads) { + Value alloc = createAlloc(forOp, loadOp, asyncLoad.sharedEncoding, + asyncLoad.stageDiff); + asyncLoad.alloc = alloc; + loadGroups.insert({asyncLoad.stageDiff, {}}); + if (isTMALoad(loadOp)) { + loadGroups[asyncLoad.stageDiff].hasTMALoad = true; + } + } + + IRRewriter builder(forOp); + builder.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + // Create a counter to index into the allocations per loop iteration. + // NOTE: We create two duplicates values, insertIdx and extractIdx so that the + // pipeliner will re-materialize the value in later stages of the pipeline + // instead of carrying it as a dependency across multiple iterations. + Value minusOne = arith::ConstantIntOp::create(builder, loc, -1, 32); + Value zero = arith::ConstantIntOp::create(builder, loc, 0, 32); + Value one = arith::ConstantIntOp::create(builder, loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + for (auto [_, loadGroup] : loadGroups) { + newOperands.push_back(minusOne); // insertIdx + newOperands.push_back(minusOne); // extractIdx + if (loadGroup.hasTMALoad) { + // A single barrier arrival sequence is a "phase" and two phases can + // overlap, provided the phases are differentiated with an alternating + // boolean value. + newOperands.push_back(zero); // phase + } + } + + // Patch the loop to add the new loop carried dependencies. + forOp = addIterArgsToLoop(builder, forOp, newOperands); + + // Update yield op with temporary yield values + auto forYield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + builder.setInsertionPoint(forOp); + loc = forOp.getLoc(); + int argIdx = newOperandIndex; + for (auto &[numBuffers, loadGroup] : loadGroups) { + Value insertIdx = forOp.getBody()->getArgument(argIdx); + argIdx++; + Value extractIdx = forOp.getBody()->getArgument(argIdx); + argIdx++; + Value phase = nullptr; + if (loadGroup.hasTMALoad) { + phase = forOp.getBody()->getArgument(argIdx); + argIdx++; + } + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); + + Value numBuffersVal = + arith::ConstantIntOp::create(builder, loc, numBuffers, 32); + loadGroup.insertIdx = createIncrementModulo(builder, loc, insertIdx, + numBuffersVal, zero, one); + Value cndExt = nullptr; + loadGroup.extractIdx = createIncrementModulo( + builder, loc, extractIdx, numBuffersVal, zero, one, &cndExt); + if (phase) { + Value nextPhase = arith::XOrIOp::create(builder, loc, phase, one); + phase = arith::SelectOp::create(builder, loc, cndExt, nextPhase, phase); + loadGroup.phase = phase; + } + } + + createTMABarrierAndWait(forOp, asyncLoads, loadGroups, schedule); + + bool hasAsyncLoads = false; + for (auto [op, asyncLoad] : asyncLoads) { + auto [insertIdx, extractIdx, phase, _] = loadGroups[asyncLoad.stageDiff]; + if (auto loadOp = dyn_cast(op)) { + createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + asyncLoad.contiguity, schedule); + hasAsyncLoads = true; + } else if (auto loadOp = dyn_cast(op)) { + createTMAAsyncLoad(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + asyncLoad.barrier, asyncLoad.waitOp, schedule); + } else if (auto loadOp = dyn_cast(op)) { + createTMAAsyncGather(forOp, loadOp, asyncLoad.alloc, insertIdx, + extractIdx, asyncLoad.barrier, asyncLoad.waitOp, + schedule); + } + } + // Patch the yield with the updated counters. Subtract to account for the loop + // counter. + argIdx = newOperandIndex - 1; + for (auto &[numBuffers, loadGroup] : loadGroups) { + forYield.setOperand(argIdx++, loadGroup.insertIdx); + forYield.setOperand(argIdx++, loadGroup.extractIdx); + if (loadGroup.phase) + forYield.setOperand(argIdx++, loadGroup.phase); + } + + // Automatically discover dependencies and schedule new insert/extract ops to + // correct stages. + scheduleDependencies(forOp, schedule); + + if (hasAsyncLoads) { + // Insert sync point for any possibly outstanding loads after the loop. This + // can happen as we speculatively execute loads in the loop. + builder.setInsertionPointAfter(forOp); + ttg::AsyncWaitOp::create(builder, loc, ValueRange({}), 0); + } + + // Make sure all ops have attributes. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!schedule.count(&op)) { + op.emitError() << "op not found in the schedule"; + } + assert(schedule.count(&op) && "op not found in the schedule"); + } + return forOp; +} + +///////////////////////////// +// LOWER MMA +///////////////////////////// + +std::pair +getTmemUseStageBoundOps(Value alloc, scf::ForOp forOp, + CoarseSchedule &schedule) { + std::pair bounds = {nullptr, nullptr}; + for (auto user : alloc.getUsers()) { + if (!forOp->isAncestor(user->getParentOp())) { + continue; + } + auto topLevelUser = forOp.getBody()->findAncestorOpInBlock(*user); + if (!bounds.first) { + bounds.first = topLevelUser; + } + if (!bounds.second) { + bounds.second = topLevelUser; + } + if (schedule.isOpBefore(topLevelUser, bounds.first)) { + bounds.first = topLevelUser; + } + if (schedule.isOpBefore(bounds.second, topLevelUser)) { + bounds.second = topLevelUser; + } + } + return bounds; +} + +Operation *hoistBufferOutOfLoop(scf::ForOp forOp, Operation *op, + CoarseSchedule &schedule) { + Operation *newStore = nullptr; + if (!isa(op)) + return nullptr; + // If the alloc is already out of the loop, there is nothing to do. + if (!forOp->isAncestor(op)) + return nullptr; + OpBuilderForStage builder(op->getLoc(), forOp, schedule); + auto allocType = dyn_cast(op->getResult(0).getType()); + auto newType = triton::gpu::MemDescType::get( + allocType.getShape(), allocType.getElementType(), allocType.getEncoding(), + allocType.getMemorySpace(), + /*mutableMemory=*/true); + auto newAlloc = builder.clone(*op); + newAlloc->getResult(0).setType(newType); + builder.setStageCluster(schedule[op]); + if (auto tmemAlloc = dyn_cast(newAlloc)) { + tmemAlloc.getSrcMutable().clear(); + builder.setInsertionPointAfter(op); + Value trueVal = arith::ConstantIntOp::create(builder, 1, 1); + newStore = ttng::TMEMStoreOp::create(builder, tmemAlloc.getResult(), + op->getOperand(0), trueVal); + } else { + auto localAlloc = cast(newAlloc); + localAlloc.getSrcMutable().clear(); + builder.setInsertionPointAfter(op); + newStore = ttg::LocalStoreOp::create(builder, op->getOperand(0), + localAlloc.getResult()); + } + replaceUsesAndPropagateType(builder, op, newAlloc->getResult(0)); + op->erase(); + return newStore; +} + +void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule, + ttng::MMAv5OpInterface mma, int mmaSelfLatency, + Value alloc, int phaseArgIdx, + int barrierIdxArgIdx) { + auto isLoadToBePipelined = [&](Operation *op) { + return schedule[mma].first > schedule[op].first; + }; + + llvm::SmallDenseSet syncCandidates; + + for (auto user : alloc.getUsers()) { + if (auto load = dyn_cast(user)) { + if (load->getBlock() != mma->getBlock()) { + continue; + } + syncCandidates.insert(load); + } + } + + ttng::MMAv5PipelineableOperandsHelper mmaPipeHelper(mma, forOp, + isLoadToBePipelined); + + for (auto def : mmaPipeHelper.unpipelineableOperandDefs) { + auto newStore = hoistBufferOutOfLoop(forOp, def, schedule); + // If the operands are not pipelineable, we need to consider the stores as + // well. + if (!mmaPipeHelper.isPipelineable && + mmaPipeHelper.isOperandsStateDetermined) { + if (newStore) { + syncCandidates.insert(newStore); + } else { + syncCandidates.insert(def); + } + } + } + + // Find the first sync candidate that appears after the MMA + // in the linearized schedule. This is either the first op to appear + // after the MMA or the first op + auto linearizedSchedule = schedule.linearized(forOp, mma); + std::optional latestSyncPoint = linearizedSchedule.findNext( + [&](Operation *op) { return syncCandidates.contains(op); }); + + int mainWaitStage = schedule[mma].first + mmaSelfLatency; + CoarseSchedule::Cluster mainWaitCluster = schedule[mma].second; + if (latestSyncPoint && mmaPipeHelper.isOperandsStateDetermined) { + if (schedule.isOpBefore(*latestSyncPoint, mma)) { + mainWaitStage = schedule[mma].first + 1; + mainWaitCluster = schedule.clusters.newBefore( + schedule.splitClusterBefore(*latestSyncPoint, forOp)); + } else { + mainWaitStage = schedule[*latestSyncPoint].first; + mainWaitCluster = schedule.clusters.newBefore( + schedule.splitClusterBefore(*latestSyncPoint, forOp)); + } + } + + int numStages = mainWaitStage - schedule[mma].first + 1; + + OpBuilderForStage builder(mma.getLoc(), mma, schedule); + Value barrierAlloc = createBarrierAlloc(forOp, numStages); + Value vTrue = arith::ConstantIntOp::create(builder, 1, 1); + Value phase = forOp.getRegionIterArg(phaseArgIdx); + Value zero = arith::ConstantIntOp::create(builder, forOp.getLoc(), 0, 32); + Value barrierIdx; + if (numStages > 1) { + barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx); + } else { + barrierIdx = zero; + } + Value one = arith::ConstantIntOp::create(builder, forOp.getLoc(), 1, 32); + Value numStagesVal = + arith::ConstantIntOp::create(builder, forOp.getLoc(), numStages, 32); + + Value barrierSlice = + triton::createSingleBufferView(builder, barrierAlloc, barrierIdx); + mma.addCompletionBarrier(barrierSlice, vTrue); + mma.setIsAsync(true); + + // List of buffers that may be used until wait completes + SmallVector waitBuffers; + auto mmaAsDotOp = cast(mma.getOperation()); + waitBuffers.push_back(mmaAsDotOp.getA()); + waitBuffers.push_back(mmaAsDotOp.getB()); + if (auto mmaAsScaledDotOp = + dyn_cast(mma.getOperation())) { + waitBuffers.push_back(mmaAsScaledDotOp.getAScale()); + waitBuffers.push_back(mmaAsScaledDotOp.getBScale()); + } + + builder.setInsertionPointAfter(mma); + builder.setStageCluster({mainWaitStage, mainWaitCluster}); + ttng::WaitBarrierOp::create(builder, barrierSlice, phase, waitBuffers); + + // Add waits before loads in conditional blocks + for (auto user : alloc.getUsers()) { + if (auto load = dyn_cast(user)) { + if (load->getBlock() == mma->getBlock()) { + continue; + } + auto topLevelUser = forOp.getBody()->findAncestorOpInBlock(*load); + if (!topLevelUser) { + continue; + } + auto [loadStage, loadCluster] = schedule[topLevelUser]; + if (loadStage < mainWaitStage) { + builder.setStageCluster({loadStage, loadCluster}); + builder.setInsertionPoint(load); + ttng::WaitBarrierOp::create(builder, barrierSlice, phase, waitBuffers); + } + } + } + + builder.setStageCluster(schedule[mma]); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + builder.setInsertionPoint(yieldOp); + Value newPhase = arith::XOrIOp::create(builder, phase, one); + Value newBarrierIdx = barrierIdx; + if (numStages > 1) { + Value barWrap; + newBarrierIdx = createIncrementModulo(builder, builder.getLoc(), barrierIdx, + numStagesVal, zero, one, &barWrap); + newPhase = arith::SelectOp::create(builder, phase.getType(), barWrap, + newPhase, phase); + } + yieldOp->replaceUsesOfWith(phase, newPhase); + yieldOp->replaceUsesOfWith(barrierIdx, newBarrierIdx); +} + +void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule, + ttng::TMEMAllocOp alloc, int bufIdxArgIdx, + int tmemUseNumStages) { + DominanceInfo domInfo(forOp); + Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx); + SmallVector> bufIdxDefs; + auto getCurrBufIdx = [&](Operation *op) { + for (auto [_op, _val] : llvm::reverse(bufIdxDefs)) { + if (domInfo.properlyDominates(_op, op)) { + return _val; + } + } + return Value(); + }; + bufIdxDefs.push_back({&forOp.getBody()->front(), bufIdx}); + + OpBuilderForStage builder(alloc.getLoc(), alloc, schedule); + auto newAlloc = createTMemAlloc(builder, alloc, true, tmemUseNumStages); + Value numStagesVal = + arith::ConstantIntOp::create(builder, tmemUseNumStages, 32); + Value zero = arith::ConstantIntOp::create(builder, 0, 32); + Value one = arith::ConstantIntOp::create(builder, 1, 32); + + bool multibufferingIsValid = false; + + SmallVector allocUsers = + llvm::to_vector(alloc.getResult().getUsers()); + auto auxBuilder = OpBuilder(forOp); + Value replTok = ub::PoisonOp::create(auxBuilder, forOp.getLoc(), + builder.getType()); + if (newAlloc.getToken()) { + newAlloc.getToken().replaceAllUsesWith(replTok); + } + for (auto user : allocUsers) { + if (auto store = dyn_cast(user)) { + store.getDepMutable().clear(); + store.getToken().replaceAllUsesWith(replTok); + if (forOp->isAncestor(store)) { + // We can multibuffer, since the store is a point where we can + // change the buffer index + multibufferingIsValid = true; + builder.setStageCluster(schedule[store]); + builder.setInsertionPoint(store); + // Change the buffer index to the new buffer index on store. + Value curBufIdx = getCurrBufIdx(store); + Value newBufIdx = createIncrementModulo( + builder, forOp.getLoc(), curBufIdx, numStagesVal, zero, one); + if (Value pred = store.getPred()) { + newBufIdx = arith::SelectOp::create(builder, newBufIdx.getType(), + pred, newBufIdx, curBufIdx); + } + replaceAllUsesDominatedBy(store, newBufIdx, curBufIdx, domInfo); + bufIdxDefs.push_back({store, newBufIdx}); + auto tmemSlice = + triton::createSingleBufferView(builder, newAlloc, newBufIdx); + store.getDstMutable().assign(tmemSlice); + } else { + // Store before the loop + assert(store->isBeforeInBlock(forOp) && "Store is not before the loop"); + builder.setInsertionPoint(store); + auto tmemSlice = + triton::createSingleBufferView(builder, newAlloc, zero); + store.getDstMutable().assign(tmemSlice); + } + } else if (auto load = dyn_cast(user)) { + load.getDepMutable().clear(); + load.getToken().replaceAllUsesWith(replTok); + if (forOp->isAncestor(load)) { + builder.setStageCluster(schedule[load]); + builder.setInsertionPoint(load); + Value curBufIdx = getCurrBufIdx(load); + auto tmemSlice = + triton::createSingleBufferView(builder, newAlloc, curBufIdx); + load.getSrcMutable().assign(tmemSlice); + } else { + // Load after the loop + assert(forOp->isBeforeInBlock(load) && "Load is not after the loop"); + builder.setInsertionPoint(load); + auto tmemSlice = triton::createSingleBufferView( + builder, newAlloc, forOp->getResult(bufIdxArgIdx)); + load.getSrcMutable().assign(tmemSlice); + } + } else if (auto mma = dyn_cast(user)) { + mma.getAccDepMutable().clear(); + mma.getToken().replaceAllUsesWith(replTok); + builder.setStageCluster(schedule[mma]); + builder.setInsertionPoint(mma); + // We can legally switch to next buffer index if the mma does not use the + // accumulator + auto isConstTrue = [](Value v) { + if (auto constOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constOp.getValueAttr())) { + return attr.getValue(); + } + } + return false; + }; + multibufferingIsValid = !isConstTrue(mma.useAccumulator()); + Value curBufIdx = getCurrBufIdx(mma.getOperation()); + Value newBufIdx = createIncrementModulo( + builder, forOp.getLoc(), curBufIdx, numStagesVal, zero, one); + newBufIdx = + arith::SelectOp::create(builder, newBufIdx.getType(), + mma.useAccumulator(), curBufIdx, newBufIdx); + replaceAllUsesDominatedBy(mma.getOperation(), newBufIdx, curBufIdx, + domInfo); + bufIdxDefs.push_back({mma.getOperation(), newBufIdx}); + auto tmemSlice = + triton::createSingleBufferView(builder, newAlloc, newBufIdx); + mma.setAccumulator(tmemSlice); + } else { + llvm::errs() << "Unsupported user of the accumulator: " << *user << "\n"; + llvm::report_fatal_error("Unsupported user of the accumulator"); + } + } + if (!multibufferingIsValid) { + llvm::report_fatal_error( + "Trying to multibuffer TMEM while there is no store to the " + "accumulator, and the mma uses the accumulator all the time."); + } + alloc.getToken().replaceAllUsesWith(newAlloc.getToken()); + alloc->erase(); + + Value newBufIdx = bufIdxDefs.back().second; + replaceAllUsesDominatedBy(newBufIdx.getDefiningOp(), newBufIdx, bufIdx, + domInfo); +} + +scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp, + CoarseSchedule &schedule) { + auto isLoadToBePipelined = [&](Operation *op) { + return schedule[mma].first > schedule[op].first; + }; + Value alloc = mma.getAccumulator(); + + int mmaSelfLatency = getSelfLatencyFromAttr(mma.getOperation()); + if (mmaSelfLatency == 0) { + return forOp; + } + + // Create barrier and wait ops + std::pair tmemUseStageBoundOps = + getTmemUseStageBoundOps(alloc, forOp, schedule); + int tmemUseNumStages = schedule[tmemUseStageBoundOps.second].first - + schedule[tmemUseStageBoundOps.first].first; + // If def is in the earlier cluster than the use, we will have a liverange + // overlap and need to add an extra buffer. + if (schedule.isOpInEarlierCluster(tmemUseStageBoundOps.first, + tmemUseStageBoundOps.second) || + (schedule.isOpInSameCluster(tmemUseStageBoundOps.first, + tmemUseStageBoundOps.second) && + tmemUseStageBoundOps.first->isBeforeInBlock( + tmemUseStageBoundOps.second))) { + tmemUseNumStages += 1; + } + + // If the accumulator needs to be double-buffered but we can't find the alloc + // op, then bail out. + if (tmemUseNumStages > 1 && !alloc.getDefiningOp()) + return forOp; + + OpBuilder builder(forOp); + Value minusOne = + arith::ConstantIntOp::create(builder, forOp.getLoc(), -1, 32); + Value zero = arith::ConstantIntOp::create(builder, forOp.getLoc(), 0, 32); + + // Add arguments to the forOp + unsigned newOperandIndex = forOp.getInitArgs().size(); + SmallVector newOperands = { + zero, // phase + zero, // barrierIdx + }; + if (tmemUseNumStages > 1) { + newOperands.push_back(minusOne); // bufIdx + } + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + + int phaseArgIdx = newOperandIndex + 0; + int barrierIdxArgIdx = newOperandIndex + 1; + int bufIdxArgIdx = newOperandIndex + 2; + Value phase = forOp.getRegionIterArg(phaseArgIdx); + Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx); + + SmallVector newYieldOperands = {phase, barrierIdx}; + if (tmemUseNumStages > 1) { + Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx); + newYieldOperands.push_back(bufIdx); + } + cast(forOp.getBody()->getTerminator()) + .getResultsMutable() + .append(newYieldOperands); + + createBarrierAndWaitOps(forOp, schedule, mma, mmaSelfLatency, alloc, + phaseArgIdx, barrierIdxArgIdx); + + if (tmemUseNumStages > 1) { + multibufferTensorMemory(forOp, schedule, + alloc.getDefiningOp(), + bufIdxArgIdx, tmemUseNumStages); + } + + return forOp; +} + +scf::ForOp lowerMMAs(scf::ForOp forOp, CoarseSchedule &schedule) { + SmallVector mmas; + forOp.walk([&](ttng::MMAv5OpInterface mma) { mmas.push_back(mma); }); + for (auto mma : mmas) { + forOp = lowerMMA(mma, forOp, schedule); + } + return forOp; +} + +///////////////////////////// +// LOWER LOOP +///////////////////////////// + +void lowerLoop(scf::ForOp forOp, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + CoarseSchedule schedule; + if (failed(schedule.deSerialize(forOp))) { + return; + } + scf::ForOp newForOp = lowerMMAs(forOp, schedule); + newForOp = lowerLoads(newForOp, schedule, axisInfoAnalysis); + newForOp = lowerTMADescriptors(newForOp, schedule); + schedule.serialize(newForOp); +} + +} // namespace + +void lowerLoops(ModuleOp moduleOp) { + triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + if (loops.empty()) + return; + for (auto forOp : loops) { + lowerLoop(forOp, axisInfoAnalysis); + } +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp new file mode 100644 index 0000000000..917bfde3d6 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp @@ -0,0 +1,317 @@ +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "mlir/IR/Dominance.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +//===----------------------------------------------------------------------===// +// MMA Pipeline Analysis +//===----------------------------------------------------------------------===// + +bool triton::nvidia_gpu::areScalesPipelineable(ttng::TCGen5MMAScaledOp scaledOp, + scf::ForOp forOp) { + if (!isa( + scaledOp.getAScale().getType().getEncoding()) && + !forOp.isDefinedOutsideOfLoop(scaledOp.getAScale()) || + !isa( + scaledOp.getBScale().getType().getEncoding()) && + !forOp.isDefinedOutsideOfLoop(scaledOp.getBScale())) { + return false; + } + + return true; +} + +bool ttng::MMAv5PipelineableOperandsHelper::isOperandPipelineable( + Value v, Operation *&foundDef) { + return ttng::isOperandPipelineableBase( + v, forOp, foundDef, [](Operation *) { return false; }, + isLoadToBePipelined); +} + +bool ttng::isOperandPipelineableBase( + Value v, scf::ForOp forOp, Operation *&foundDef, + std::function isPipelineable, + std::function isLoadToBePipelined) { + + if (forOp.isDefinedOutsideOfLoop(v)) { + return true; + } + if (!v.getDefiningOp()) { + return false; + } + while (isa(v.getDefiningOp())) { + v = v.getDefiningOp()->getOperand(0); + } + if (isPipelineable(v.getDefiningOp())) { + return true; + } + if (isa( + v.getDefiningOp())) { + foundDef = v.getDefiningOp(); + return false; + } + auto localAlloc = dyn_cast(v.getDefiningOp()); + if (!localAlloc) { + return false; + } + foundDef = localAlloc; + if (!localAlloc.getSrc()) { + return false; + } + if (forOp.isDefinedOutsideOfLoop(localAlloc.getSrc())) { + return true; + } + auto localAllocSrc = localAlloc.getSrc().getDefiningOp(); + if (!isa_and_nonnull(localAllocSrc)) { + return false; + } + foundDef = localAllocSrc; + if (!isLoadToBePipelined(localAllocSrc)) { + return false; + } + if (canBeAsyncLoad(localAllocSrc)) { + return true; + } + return false; +} + +void ttng::MMAv5PipelineableOperandsHelper::run() { + unpipelineableOperandDefs.clear(); + isOperandsStateDetermined = true; + // Accumulator alloc must be outside the loop. + auto tmemAlloc = mmaOp.getAccumulator().getDefiningOp(); + if (!tmemAlloc) { + return; + } + if (!forOp.isDefinedOutsideOfLoop(tmemAlloc)) { + return; + } + if (auto dotOp = dyn_cast(mmaOp.getOperation())) { + Operation *foundDef = nullptr; + if (!isOperandPipelineable(dotOp.getA(), foundDef)) { + if (foundDef) { + unpipelineableOperandDefs.push_back(foundDef); + } else { + isOperandsStateDetermined = false; + } + } + if (!isOperandPipelineable(dotOp.getB(), foundDef)) { + if (foundDef) { + unpipelineableOperandDefs.push_back(foundDef); + } else { + isOperandsStateDetermined = false; + } + } + } + // For scaled MMA check if the scales are passed through shared memory, and + // also coming from load or outside the loop. + if (auto scaledOp = dyn_cast(mmaOp.getOperation())) { + if (!ttng::areScalesPipelineable(scaledOp, forOp)) { + // Undecidable, we could follow the tmem use-def chain to find the first + // tmem_load. + isOperandsStateDetermined = false; + return; + } + Operation *foundDef = nullptr; + if (!isOperandPipelineable(scaledOp.getAScale(), foundDef)) { + if (foundDef) { + unpipelineableOperandDefs.push_back(foundDef); + } else { + isOperandsStateDetermined = false; + } + } + if (!isOperandPipelineable(scaledOp.getBScale(), foundDef)) { + if (foundDef) { + unpipelineableOperandDefs.push_back(foundDef); + } else { + isOperandsStateDetermined = false; + } + } + } + isPipelineable = + isOperandsStateDetermined && unpipelineableOperandDefs.empty(); +} + +bool ttng::hasAccReadModifyWrite(ttng::MMAv5OpInterface mma, scf::ForOp forOp) { + auto tmemAlloc = mma.getAccumulator().getDefiningOp(); + if (!tmemAlloc || !forOp.isDefinedOutsideOfLoop(tmemAlloc)) { + // Alloc not hoisted, or IR is not canonicalized. Pessimistically assume + // the accumulator is read-modify-written. + return true; + } + SmallVector stores; + SmallVector loads; + for (auto user : tmemAlloc->getUsers()) { + if (isa(user) && + forOp->isAncestor(user->getParentOp())) { + stores.push_back(cast(user)); + } + if (isa(user) && forOp->isAncestor(user->getParentOp())) { + loads.push_back(cast(user)); + } + } + if (stores.empty() || loads.empty()) { + return false; + } + SmallVector readValues; + DenseSet seen; + llvm::SetVector modifiedValues; + for (auto load : loads) { + readValues.push_back(load->getResult(0)); + } + while (!readValues.empty()) { + Value v = readValues.pop_back_val(); + if (!seen.insert(v).second) { + continue; + } + for (auto &use : v.getUses()) { + if (llvm::is_contained(stores, use.getOwner())) { + continue; // R-W, not midified, this is safe + } + if (auto yieldOp = dyn_cast(use.getOwner())) { + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + readValues.push_back(ifOp.getResult(use.getOperandNumber())); + } + if (forOp == yieldOp->getParentOp()) { + readValues.push_back(forOp.getRegionIterArg(use.getOperandNumber())); + } + } else { + modifiedValues.insert(use.getOwner()->getResults().begin(), + use.getOwner()->getResults().end()); + } + } + } + while (!modifiedValues.empty()) { + Value v = modifiedValues.pop_back_val(); + if (!seen.insert(v).second) { + continue; + } + for (auto &use : v.getUses()) { + if (llvm::is_contained(stores, use.getOwner())) { + return true; // RMW! + } + if (auto yieldOp = dyn_cast(use.getOwner())) { + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + modifiedValues.insert(ifOp.getResult(use.getOperandNumber())); + } + if (forOp == yieldOp->getParentOp()) { + modifiedValues.insert(forOp.getRegionIterArg(use.getOperandNumber())); + } + } else { + modifiedValues.insert(use.getOwner()->getResults().begin(), + use.getOwner()->getResults().end()); + } + } + } + return false; +} + +static bool accUseFlagSetToFalse(ttng::MMAv5OpInterface mma, scf::ForOp forOp) { + Value accUseFlag = mma.useAccumulator(); + if (matchPattern(accUseFlag, m_Zero())) { + return true; + } + auto yieldOp = cast(forOp.getBody()->getTerminator()); + Value accUseFlagInit; + while (auto blockArg = dyn_cast(accUseFlag)) { + accUseFlag = yieldOp.getOperand(blockArg.getArgNumber() - 1); + accUseFlagInit = forOp.getInitArgs()[blockArg.getArgNumber() - 1]; + } + + if (accUseFlagInit && matchPattern(accUseFlagInit, m_Zero()) && + matchPattern(accUseFlag, m_One())) { + // A simple case for nested loops - the use flag is initialized to false + // and uncondionally set to true in later iterations + return true; + } + + // If the accUseFlag is overwritten in the loop, we treat it as a 'false' + // with condition being ~accUseFlag. + return accUseFlag.getDefiningOp() && + forOp->isAncestor(accUseFlag.getDefiningOp()); +} + +static bool accOverwrittenInLoop(ttng::MMAv5OpInterface mma, scf::ForOp forOp) { + auto tmemAlloc = mma.getAccumulator().getDefiningOp(); + if (!tmemAlloc || !forOp.isDefinedOutsideOfLoop(tmemAlloc)) { + return false; + } + for (auto user : tmemAlloc->getUsers()) { + if (isa(user) && + forOp->isAncestor(user->getParentOp())) { + return true; + } + } + return false; +} + +bool ttng::isAccMultibufferingPossible(ttng::MMAv5OpInterface mma, + scf::ForOp forOp) { + // If the accumulator is never overwritten in the loop, we can't multibuffer + // it, as the overwrite point is the only place where we can swap the + // buffer. + return accUseFlagSetToFalse(mma, forOp) || accOverwrittenInLoop(mma, forOp); +} + +bool ttng::requiresAccMultiBuffering(ttng::MMAv5OpInterface mma, + scf::ForOp forOp) { + auto tmemAlloc = mma.getAccumulator().getDefiningOp(); + if (!tmemAlloc || !forOp.isDefinedOutsideOfLoop(tmemAlloc)) { + return true; // Pessimistically assume the accumulator requires + // multi-buffering. + } + // If the accumulator is being read in the loop, we will need to multibuffer + // when pipelining. + for (auto user : tmemAlloc->getUsers()) { + if (isa(user) && forOp->isAncestor(user->getParentOp())) { + return true; + } + } + return false; +} + +bool ttng::hasLoadsAfterMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp) { + auto tmemAlloc = mma.getAccumulator().getDefiningOp(); + if (!tmemAlloc || !forOp.isDefinedOutsideOfLoop(tmemAlloc)) { + return false; + } + for (auto user : tmemAlloc->getUsers()) { + if (isa(user)) { + auto ancestorOp = forOp.getBody()->findAncestorOpInBlock(*user); + if (ancestorOp && mma->isBeforeInBlock(ancestorOp)) { + return true; + } + } + } + return false; +} + +//===----------------------------------------------------------------------===// +// MMA Pipeline Rewriters +//===----------------------------------------------------------------------===// + +ttng::TMEMAllocOp ttng::createTMemAlloc(OpBuilder &builder, + ttng::TMEMAllocOp oldTMemAllocOp, + bool multiBufferred, int numStages) { + Location loc = oldTMemAllocOp.getLoc(); + auto oldRetType = oldTMemAllocOp.getType(); + SmallVector shape = {oldRetType.getShape().begin(), + oldRetType.getShape().end()}; + if (multiBufferred) { + shape.insert(shape.begin(), numStages); + } + Type accMemDescType = triton::gpu::MemDescType::get( + shape, oldRetType.getElementType(), oldRetType.getEncoding(), + oldRetType.getMemorySpace(), /*mutableMemory=*/true); + return ttng::TMEMAllocOp::create( + builder, oldTMemAllocOp.getLoc(), accMemDescType, + builder.getType(), /*src=*/Value()); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 0000000000..c7034e4183 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,869 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. +// -Fix bug when a value yield is used outside the loop and the value def is not +// in the last stage. If we are not peeling the epilgue we need to remap the +// output correctly. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" + +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" + +// FIXME: PipelineExpander should not depend on Triton-specific headers! +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + triton::PipeliningOption::EmitPredicateStageFnType emitPredicateStageFn = + nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + /// Return the defining op of the given value, if the Value is an argument of + /// the loop return the associated defining op in the loop and its distance to + /// the Value. + std::pair getDefiningOpAndDistance(Value value); + + /// Return true if the schedule is possible and return false otherwise. A + /// schedule is correct if all definitions are scheduled before uses. + bool verifySchedule(); + +public: + /// Initialize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + LogicalResult emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + LogicalResult emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues); +}; + +/// Find operands of all the nested operations within `op`. +static SetVector getNestedOperands(Operation *op) { + SetVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + operands.insert(operand); + } + }); + return operands; +} + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); + if (numIteration >= maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + emitPredicateStageFn = options.emitPredicateStageFn; + if (emitPredicateStageFn == nullptr) { + emitPredicateStageFn = mlir::triton::emitPredicateForStage; + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + if (!verifySchedule()) { + LDBG("--invalid schedule: " << op << " -> BAIL"); + return false; + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Support only loop-carried dependencies with a distance of one iteration or + // those defined outside of the loop. This means that any dependency within a + // loop should either be on the immediately preceding iteration, the current + // iteration, or on variables whose values are set before entering the loop. + for (auto &op : forOp.getBody()->without_terminator()) { + for (auto operand : getNestedOperands(&op)) { + auto [def, distance] = getDefiningOpAndDistance(operand); + if (!def) + continue; + if (distance > 1) { + LDBG("--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"); + return false; + } + } + } + annotateFn = options.annotateFn; + return true; +} + +/// Compute unrolled cycles of each op (consumer) and verify that each op is +/// scheduled after its operands (producers) while adjusting for the distance +/// between producer and consumer. +bool LoopPipelinerInternal::verifySchedule() { + int64_t numCylesPerIter = opOrder.size(); + // Pre-compute the unrolled cycle of each op. + DenseMap unrolledCyles; + for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { + Operation *def = opOrder[cycle]; + auto it = stages.find(def); + assert(it != stages.end()); + int64_t stage = it->second; + unrolledCyles[def] = cycle + stage * numCylesPerIter; + } + for (Operation *consumer : opOrder) { + int64_t consumerCycle = unrolledCyles[consumer]; + for (Value operand : getNestedOperands(consumer)) { + auto [producer, distance] = getDefiningOpAndDistance(operand); + if (!producer) + continue; + auto it = unrolledCyles.find(producer); + // Skip producer coming from outside the loop. + if (it == unrolledCyles.end()) + continue; + int64_t producerCycle = it->second; + if (consumerCycle < producerCycle - numCylesPerIter * distance) { + InFlightDiagnostic diag = + consumer->emitWarning("operation scheduled before its operands. " + "Pipelining will be disabled."); + diag.attachNote(producer->getLoc()) + .append("operand defined here: ") + .appendOp(*producer, OpPrintingFlags().printGenericOpForm()); + return false; + } + } + } + return true; +} + +/// Clone `op` and call `callback` on the cloned op's operands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + clone->walk([&](Operation *nested) { + // 'clone' itself will be visited first. + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { + setValueMapping(arg, operand.get(), 0); + } + + // If the incoming value to an iter arg from the loop yield is defined outside + // the loop, then that means the iter arg takes that value for all stages + // after the first stage. + auto yield = cast(forOp.getBody()->getTerminator()); + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), yield->getOpOperands())) { + if (forOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) + continue; + for (int64_t i = 1; i < maxStage; ++i) + setValueMapping(arg, operand.get(), i); + } + + Location loc = forOp.getLoc(); + SmallVector predicates(maxStage); + for (int64_t i = 0; i < maxStage; i++) { + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = arith::AddIOp::create( + rewriter, loc, lb, + arith::MulIOp::create( + rewriter, loc, step, + arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + + if (dynamicLoop) { + // pred = ub > lb + (i * step) + predicates[i] = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::slt, iv, ub); + } + + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + int predicateIdx = i - stages[op]; + if (predicates[predicateIdx]) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); + if (newOp == nullptr) + return failure(); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + Value source = newOp->getResult(destId); + // If the value is a loop carried dependency update the loop argument + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + if (predicates[predicateIdx] && + !forOp.getResult(operand.getOperandNumber()).use_empty()) { + // If the value is used outside the loop, we need to make sure we + // return the correct version of it. + Value prevValue = valueMapping + [forOp.getRegionIterArgs()[operand.getOperandNumber()]] + [i - stages[op]]; + source = arith::SelectOp::create( + rewriter, loc, predicates[predicateIdx], source, prevValue); + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + source, i - stages[op] + 1); + } + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + } + } + } + return success(); +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + return triton::getDefiningOpAndDistance(forOp, value); +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value valueVersion = + valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage->second]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } else + newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + Value maxStageValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageByStep = + arith::MulIOp::create(rewriter, loc, step, maxStageValue); + newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep); + } + auto newForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + newForOp->setAttrs(forOp->getAttrs()); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + predicates[i] = emitPredicateStageFn(rewriter, newForOp.getInductionVar(), + ub, step, maxStage, i); + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = arith::MulIOp::create( + rewriter, forOp.getLoc(), step, + arith::ConstantOp::create( + rewriter, forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(), + newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + if (forOp.isDefinedOutsideOfLoop(ret)) { + // Special case for values defined outside the loop accessed with + // distance 1. + if (useStage != maxStage) { + nestedNewOp->setOperand(operand->getOperandNumber(), ret); + } + continue; + } + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yieldOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yieldOperand.get()); + // When we don't peel the epilogue and the yield value is used outside the + // loop we need to make sure we return the version from numStages - + // defStage. + if (!peelEpilogue && + !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { + Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end() && defStage->second < maxStage) { + Value pred = predicates[defStage->second]; + source = arith::SelectOp::create( + rewriter, pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yieldOperand.getOperandNumber() + 1]); + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + auto defStage = stages.find(def); + if (defStage == stages.end()) { + for (unsigned int stage = 1; stage <= maxStage; stage++) + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + retVal.value(), stage); + } else if (defStage->second > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage->second + 1); + } + } + scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands); + return success(); +} + +LogicalResult +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + + auto createConst = [&](int v) { + return arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(t, v)); + }; + + // total_iterations = cdiv(range_diff, step); + // - range_diff = ub - lb + // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step + Value zero = createConst(0); + Value one = createConst(1); + Value stepLessZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, step, zero); + Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one, + createConst(-1)); + + Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb); + Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step); + Value rangeDecr = + arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr); + Value totalIterations = + arith::DivSIOp::create(rewriter, loc, rangeDecr, step); + + // If total_iters < max_stage, start the epilogue at zero to match the + // ramp-up in the prologue. + // start_iter = max(0, total_iters - max_stage) + Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations, + createConst(maxStage)); + iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI); + + // Capture predicates for dynamic loops. + SmallVector predicates(maxStage + 1); + + for (int64_t i = 1; i <= maxStage; i++) { + // newLastIter = lb + step * iterI + Value newlastIter = arith::AddIOp::create( + rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI)); + + setValueMapping(forOp.getInductionVar(), newlastIter, i); + + // increment to next iterI + iterI = arith::AddIOp::create(rewriter, loc, iterI, one); + + if (dynamicLoop) { + // Disable stages when `i` is greater than total_iters. + // pred = total_iters >= i + predicates[i] = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, + totalIterations, createConst(i)); + } + } + + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + SmallVector> returnMap(returnValues.size()); + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + unsigned currentVersion = maxStage - stages[op] + i; + unsigned nextVersion = currentVersion + 1; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[currentVersion]; + newOperand->set(replacement); + } + }); + if (dynamicLoop) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); + if (!newOp) + return failure(); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (auto [opRes, newRes] : + llvm::zip(op->getResults(), newOp->getResults())) { + setValueMapping(opRes, newRes, currentVersion); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != opRes) + continue; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + unsigned ri = operand.getOperandNumber(); + returnValues[ri] = newRes; + Value mapVal = forOp.getRegionIterArgs()[ri]; + returnMap[ri] = std::make_pair(mapVal, currentVersion); + if (nextVersion <= maxStage) + setValueMapping(mapVal, newRes, nextVersion); + } + } + } + if (dynamicLoop) { + // Select return values from this stage (live outs) based on predication. + // If the stage is valid select the peeled value, else use previous stage + // value. + for (auto pair : llvm::enumerate(returnValues)) { + unsigned ri = pair.index(); + auto [mapVal, currentVersion] = returnMap[ri]; + if (mapVal) { + unsigned nextVersion = currentVersion + 1; + Value pred = predicates[currentVersion]; + Value prevValue = valueMapping[mapVal][currentVersion]; + auto selOp = arith::SelectOp::create(rewriter, loc, pred, + pair.value(), prevValue); + returnValues[ri] = selOp; + if (nextVersion <= maxStage) + setValueMapping(mapVal, selOp, nextVersion); + } + } + } + } + return success(); +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + if (failed(pipeliner.emitPrologue(rewriter))) + return failure(); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + if (failed(pipeliner.emitEpilogue(rewriter, returnValues))) + return failure(); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} + +Value mlir::triton::emitPredicateForStage(RewriterBase &rewriter, + Value inductionVar, Value upperBound, + Value step, uint64_t maxStage, + uint64_t stage) { + auto loc = inductionVar.getLoc(); + auto type = inductionVar.getType(); + Value c = arith::SubIOp::create( + rewriter, loc, upperBound, + arith::MulIOp::create( + rewriter, loc, step, + arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(type, maxStage - stage)))); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + inductionVar, c); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp new file mode 100644 index 0000000000..b8170bfad8 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -0,0 +1,910 @@ +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +//===----------------------------------------------------------------------===// +// Hoisting Utilities +//===----------------------------------------------------------------------===// + +bool triton::isPureScalarOp(Operation *op) { + auto isScalar = [](Type type) { return type.isIntOrIndexOrFloat(); }; + return isPure(op) && llvm::all_of(op->getOperandTypes(), isScalar) && + llvm::all_of(op->getResultTypes(), isScalar); +} + +bool triton::getDominatingValueSetOpsToHoist( + DominanceInfo &domInfo, Operation *refOp, ArrayRef valueSet, + llvm::SetVector &toHoist, + function_ref canHoist, + function_ref canUseArg) { + // The set of operations below `refOp` that are being checked if they can be + // hoisted. This set prevents checking operations twice but also if the + // computation can be hoisted, this becomes the set of operations to hoist. + llvm::SetVector visited; + + // Climb the use-def chain breadth-first so that operations can be hoisted in + // the reverse visitation order. + std::queue queue; + for (Value value : valueSet) + queue.push(value); + + while (!queue.empty()) { + Value value = queue.front(); + queue.pop(); + + // If the value properly dominates the outer loop, then it must be invariant + // to it. + if (domInfo.properlyDominates(value, refOp)) + continue; + // If the value is a block argument, check if it can be used. + if (auto arg = dyn_cast(value)) { + if (!canUseArg(arg)) + return false; + continue; + } + + Operation *op = value.getDefiningOp(); + // Check if the op was already visited. + if (visited.contains(op)) + continue; + // If the defining op cannot be hoisted, then the value cannot be made loop + // invariant. + if (!canHoist(op)) + return false; + visited.insert(op); + // Recurse on the operands of the op. + for (Value operand : op->getOperands()) + queue.push(operand); + } + + // The operations in `visited` must be hoisted. Note that operations are not + // added to `toHoist` unless all of `values` can be hoisted. This is to avoid + // hoisting operations for loops that don't end up getting fused if one of + // their bounds operands cannot be hoisted. + toHoist.insert(visited.begin(), visited.end()); + + return true; +} + +void triton::hoistOpsBefore(Operation *refOp, + const llvm::SetVector &toHoist) { + return hoistOpsBefore(refOp->getBlock(), refOp->getIterator(), toHoist); +} +void triton::hoistOpsBefore(Block *block, Block::iterator it, + const llvm::SetVector &toHoist) { + for (Operation *op : topologicalSort(toHoist)) { + op->moveBefore(block, it); + } +} + +//===----------------------------------------------------------------------===// +// Sinking Utilities +//===----------------------------------------------------------------------===// + +Value triton::sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out, + Block *block) { + OpBuilder::InsertionGuard guard(rewriter); + for (; block != in.getParentBlock(); + block = block->getParentOp()->getBlock()) { + Operation *op = block->getParentOp(); + rewriter.setInsertionPoint(op); + + // `in` is live into the loop body. `out` becomes the live-out if the + // loop executes at least once. + if (auto forOp = dyn_cast(op)) { + forOp = addIterArgsToLoop(rewriter, forOp, in); + appendToForOpYield(forOp, out); + out = forOp.getResults().back(); + continue; + } + + // `in` is live into both branches. `out` becomes the live-out if the + // particular branch is taken. + if (auto ifOp = dyn_cast(op)) { + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, out.getType()); + scf::YieldOp taken = newIfOp.thenYield(); + scf::YieldOp other = newIfOp.elseYield(); + if (block == newIfOp.elseBlock()) + std::swap(taken, other); + taken->insertOperands(taken.getNumOperands(), out); + other->insertOperands(other.getNumOperands(), in); + out = newIfOp.getResults().back(); + rewriter.eraseOp(ifOp); + continue; + } + + // TODO: Handle `scf.while`, etc. + llvm::report_fatal_error("FIXME: sinking into unhandled control flow op: " + + op->getName().getStringRef()); + } + + return out; +} + +//===----------------------------------------------------------------------===// +// Loop Pipelining Utilities +//===----------------------------------------------------------------------===// + +bool mlir::triton::loopHasDistGreaterThanOne(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + }); +} + +bool mlir::triton::isOuterLoop(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getOperations(), [](Operation &op) { + return isa(op); + }); +} + +// Function to mask operations during scheduling. +Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isConstantIntValue(pred, 1)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (op->hasTrait()) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + if (auto asyncCopyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(asyncCopyOp); + Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(), + asyncCopyOp.getMask(), pred); + asyncCopyOp.getMaskMutable().assign(mask); + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto gatherOp = dyn_cast(op)) { + rewriter.setInsertionPoint(gatherOp); + Value mask = getPredMask(rewriter, gatherOp.getPred().getType(), + gatherOp.getPred(), pred); + gatherOp.getPredMutable().assign(mask); + return op; + } + if (auto expectOp = dyn_cast(op)) { + rewriter.setInsertionPoint(expectOp); + Value mask = getPredMask(rewriter, expectOp.getPred().getType(), + expectOp.getPred(), pred); + expectOp.getPredMutable().assign(mask); + return op; + } + if (auto mmav5Op = dyn_cast(op)) { + rewriter.setInsertionPoint(mmav5Op); + auto currPred = mmav5Op.getPredicate(); + Value mask = getPredMask(rewriter, currPred.getType(), currPred, pred); + mmav5Op.setPredicate(mask); + return op; + } + if (auto tmemStoreOp = dyn_cast(op)) { + rewriter.setInsertionPoint(tmemStoreOp); + Value mask = getPredMask(rewriter, tmemStoreOp.getPred().getType(), + tmemStoreOp.getPred(), pred); + tmemStoreOp.getPredMutable().assign(mask); + return op; + } + if (auto waitBarrier = dyn_cast(op)) { + rewriter.setInsertionPoint(waitBarrier); + Value mask = pred; + Value currentPred = waitBarrier.getPred(); + if (currentPred) { + mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred); + } + waitBarrier.getPredMutable().assign(mask); + return op; + } + if (auto arriveBarrier = dyn_cast(op)) { + rewriter.setInsertionPoint(arriveBarrier); + Value mask = pred; + Value currentPred = arriveBarrier.getPred(); + if (currentPred) { + mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred); + } + arriveBarrier.getPredMutable().assign(mask); + return op; + } + if (auto commit = dyn_cast(op)) { + rewriter.setInsertionPoint(commit); + Value mask = pred; + Value currentPred = commit.getPred(); + if (currentPred) { + mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred); + } + commit.getPredMutable().assign(mask); + return op; + } + if (auto storeOp = dyn_cast(op)) { + rewriter.setInsertionPoint(storeOp); + Value mask = getPredMask(rewriter, storeOp.getPtr().getType(), + storeOp.getMask(), pred); + storeOp.getMaskMutable().assign(mask); + return op; + } + if (auto atomicRMWOp = dyn_cast(op)) { + rewriter.setInsertionPoint(atomicRMWOp); + Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(), + atomicRMWOp.getMask(), pred); + atomicRMWOp.getMaskMutable().assign(mask); + return op; + } + if (!op->isRegistered()) { + // Skip ops from unregistered dialects to make writing lit tests easier. + return op; + } + + op->emitOpError("pipeliner doesn't know how to predicate this op."); + llvm::report_fatal_error("Fatal pipeliner error"); + return op; +} + +Operation *mlir::triton::wrapInMaskOp(RewriterBase &rewriter, Operation *op, + Value pred) { + auto mask = + ttg::MaskOp::create(rewriter, op->getLoc(), op->getResultTypes(), pred); + rewriter.createBlock(&mask->getRegion(0)); + rewriter.setInsertionPointToStart(&mask->getRegion(0).front()); + auto newOp = rewriter.clone(*op); + ttg::MaskReturnOp::create(rewriter, op->getLoc(), newOp->getResults()); + op->replaceAllUsesWith(mask->getResults()); + rewriter.eraseOp(op); + return mask; +} + +void mlir::triton::resolveMaskOp(ModuleOp moduleOp) { + IRRewriter rewriter(moduleOp); + + // Canonicalize the IR to simplify the arithmetic ops defining the mask + auto arithDialect = + moduleOp.getContext()->getLoadedDialect(); + RewritePatternSet patterns(moduleOp.getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (mlir::applyPatternsGreedily(moduleOp, std::move(patterns)).failed()) + return llvm::report_fatal_error("Failed to canonicalize the IR"); + + SmallVector maskOps; + moduleOp->walk([&](ttg::MaskOp maskOp) { maskOps.push_back(maskOp); }); + for (auto maskOp : maskOps) { + rewriter.setInsertionPoint(maskOp); + while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) { + Operation *op = &maskOp.getBody()->front(); + rewriter.moveOpBefore(op, maskOp); + op = triton::predicateOp(rewriter, op, maskOp.getPred()); + } + maskOp->replaceAllUsesWith( + maskOp.getBody()->getTerminator()->getOperands()); + maskOp->erase(); + } +} + +// Return true if the given ForOp has the attribute +// `tt.disallow_acc_multi_buffer` set to true. +bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) { + return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName); +} + +std::pair +mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) { + int64_t distance = 0; + DenseSet seen; + while (auto arg = dyn_cast(value)) { + // Ignore implicit captures. + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + ++distance; + value = forOp.getYieldedValues()[arg.getArgNumber() - 1]; + if (!seen.insert(value).second) + return {nullptr, 0}; + } + return {cast(value), distance}; +} + +std::pair +mlir::triton::getDefiningOpAndDistance(scf::ForOp forOp, Value value) { + auto [definition, distance] = getDefinitionAndDistance(forOp, value); + return {definition ? definition.getDefiningOp() : nullptr, distance}; +} + +int mlir::triton::getCopyVecBytes(RankedTensorType registerTy, + ttg::SharedEncodingTrait sharedEnc) { + auto shape = registerTy.getShape(); + auto regLayout = triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); + // FIXME: Here we should pass a MemDescType instead of a SharedEncodingTrait!! + // This is currently broken for memdesc_subslice! + auto sharedLayout = triton::gpu::toLinearLayout(shape, sharedEnc); + auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + const int vecElems = regToSharedLayout.getNumConsecutiveInOut(); + return vecElems * registerTy.getElementTypeBitWidth() / 8; +} + +bool mlir::triton::canBeConvertedToAsyncLoad( + tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + unsigned width = 0; + if (tensorTy) { + auto ty = cast(tensorTy.getElementType()).getPointeeType(); + width = vec * ty.getIntOrFloatBitWidth(); + } else { + width = cast(ptr.getType()) + .getPointeeType() + .getIntOrFloatBitWidth(); + } + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + return width >= 32; +} + +void mlir::triton::serializeLatencies(ModuleOp module, + DenseMap &opLatency) { + auto helper = TritonDialect::getLoaded(module)->getLatencyAttrHelper(); + auto builder = Builder(module); + for (auto &[op, latency] : opLatency) { + helper.setAttr(op, builder.getI32IntegerAttr(latency)); + } +} + +void mlir::triton::serializeSelfLatencies( + ModuleOp module, DenseMap &opSelfLatency) { + auto helper = TritonDialect::getLoaded(module)->getSelfLatencyAttrHelper(); + auto builder = Builder(module); + for (auto &[op, latency] : opSelfLatency) { + helper.setAttr(op, builder.getI32IntegerAttr(latency)); + } +} + +DenseMap mlir::triton::deserializeLatencies(Operation *op) { + DenseMap opLatency; + auto latencyHelper = TritonDialect::getLoaded(op)->getLatencyAttrHelper(); + op->walk([&](Operation *op) { + if (auto attr = latencyHelper.getAttr(op)) { + opLatency[op] = attr.getInt(); + latencyHelper.removeAttr(op); + } + }); + return opLatency; +} + +Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type, + unsigned numBuffers) { + MLIRContext *ctx = rewriter.getContext(); + unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs( + rewriter.getBlock()->getParentOp()->getParentOfType()); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(rewriter.getContext()); + auto barrierCGALayout = ttg::CGAEncodingAttr::get1DLayout(ctx, numCTAs); + auto barrierEncoding = + ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCGALayout); + ttg::MemDescType memDescType = ttg::MemDescType::get( + {numBuffers, numCTAs}, type, barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + return ttg::LocalAllocOp::create(rewriter, memDescType, Value()); +} + +// Create an allocation and init the mbarriers. +Value mlir::triton::createBarrierAlloc(Operation *op, int numBarriers, + int arriveCount) { + ImplicitLocOpBuilder rewriter(op->getLoc(), op); + + Value barrierAlloc = + createScalarAlloc(rewriter, rewriter.getI64Type(), numBarriers); + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(rewriter, barrierAlloc, i); + ttng::InitBarrierOp::create(rewriter, barrierView, arriveCount); + } + // Invalidate and deallocate the barriers. + rewriter.setInsertionPointAfter(op); + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(rewriter, barrierAlloc, i); + ttng::InvalBarrierOp::create(rewriter, barrierView); + } + ttg::LocalDeallocOp::create(rewriter, barrierAlloc); + return barrierAlloc; +} + +Value mlir::triton::createAlloc(Operation *insertBefore, RankedTensorType ty, + Location loc, + gpu::SharedEncodingTrait sharedEnc, + unsigned distance) { + OpBuilder builder(insertBefore); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(insertBefore->getContext()); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + Value alloc = ttg::LocalAllocOp::create(builder, loc, memdescType); + + builder.setInsertionPointAfter(insertBefore); + ttg::LocalDeallocOp::create(builder, insertBefore->getLoc(), alloc); + return alloc; +} + +bool mlir::triton::isTMALoad(Operation *op) { + return isa(op); +} + +bool mlir::triton::canBeAsyncLoad(Operation *op) { + if (mlir::triton::isTMALoad(op)) { + return true; + } + assert(isa(op)); + ttg::SharedEncodingTrait sharedEncoding = mlir::triton::getSharedEncoding(op); + // Do not create async loads for small loads (cp.async requires at least 4 + // bytes) + int copyVecBytes = mlir::triton::getCopyVecBytes( + cast(op->getResultTypes()[0]), sharedEncoding); + if (copyVecBytes >= 4) { + return true; + } + return false; +} + +void mlir::triton::combineRedundantWaitOps( + llvm::SmallSetVector &waitOps) { + llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; + SmallVector waitGroup = {waitOp}; + SmallVector depTokens = waitOp.getOperands(); + unsigned minWaitNumber = waitOp.getNum(); + Operation *next = waitOp->getNextNode(); + // Stop if we reach the end of the block or if there is another commit group + // or a branching op (forOp, ifOp, whileOp) in between the waits + while (next && + !isa(next)) { + if (auto nextWait = dyn_cast(next)) { + waitGroup.push_back(nextWait); + minWaitNumber = std::min(minWaitNumber, nextWait.getNum()); + depTokens.append(nextWait.getOperands().begin(), + nextWait.getOperands().end()); + } + next = next->getNextNode(); + } + if (waitGroup.size() == 1) + continue; + OpBuilder builder(waitGroup.front()); + auto newWaitOp = ttg::AsyncWaitOp::create(builder, waitOp.getLoc(), + depTokens, minWaitNumber); + for (auto waitOp : waitGroup) { + toDelete[waitOp] = newWaitOp; + } + } + for (auto waitOp : toDelete) { + waitOp.first->replaceAllUsesWith(waitOp.second); + waitOp.first->erase(); + } +} + +ttg::MemDescType mlir::triton::getBufferViewType(ttg::MemDescType allocTy, + bool mutableMemory) { + return ttg::MemDescType::get(allocTy.getShape().drop_front(), + allocTy.getElementType(), allocTy.getEncoding(), + allocTy.getMemorySpace(), mutableMemory, + /*allocShape=*/allocTy.getAllocShape()); +} + +ttg::MemDescType +mlir::triton::getMultiBufferedType(ttg::MemDescType memDescType, + int32_t depth) { + auto shape = memDescType.getShape(); + SmallVector bufferShape(shape.begin(), shape.end()); + bufferShape.insert(bufferShape.begin(), depth); + return ttg::MemDescType::get( + bufferShape, memDescType.getElementType(), memDescType.getEncoding(), + memDescType.getMemorySpace(), /*mutableMemory*/ true); +} + +ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) { + auto cgaLayout = ttg::getCGALayout(ty.getEncoding()); + auto order = ttg::getOrder(ty); + // Use generic layout. This won't be optimal for 2D tensors. + return ttg::SwizzledSharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + cgaLayout); +} + +ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) { + // Try to use local alloc encoding if possible. + ttg::SharedEncodingTrait localAllocEnc; + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) { + // Some users have different encoding than others. + // Use one of the encodings, and warn about the performance issue. + op->emitRemark() + << "Pipelining load with different use encodings. This will lead " + "to layout conversions and performance degradation."; + continue; + } + } + } + + auto ty = cast(op->getResultTypes()[0]); + auto cgaLayout = ttg::getCGALayout(ty.getEncoding()); + auto order = ttg::getOrder(ty); + if (isTMALoad(op)) { + // TMA encoding is set on the descriptor type + TypedValue desc; + if (auto load = dyn_cast(op)) { + desc = load.getDesc(); + } else if (auto gather = dyn_cast(op)) { + desc = gather.getDesc(); + } else { + op->emitError() << "unrecognized tma load type"; + llvm::report_fatal_error("unrecognized tma load type"); + } + return ttng::getEncodingFromDescriptor(op, ty, desc); + } + + if (localAllocEnc) + return localAllocEnc; + + // Try to use dot encoding if possible. + bool incompatible = false; + localAllocEnc = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + + if (localAllocEnc) + return localAllocEnc; + + // Use generic layout. This won't be optimal for 2D tensors. + return ttg::SwizzledSharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + cgaLayout); +} + +int mlir::triton::getNumStagesOrDefault(scf::ForOp forOp, + int defaultNumStages) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + auto helper = TritonDialect::getLoaded(forOp)->getNumStagesAttrHelper(); + if (auto attr = helper.getAttr(forOp)) + return attr.getInt(); + return defaultNumStages; +} + +TypedValue +triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) { + assert(isa(alloc.getType()) && "Expected MemDescType"); + auto allocDescType = cast(alloc.getType()); + SmallVector shape; + assert(allocDescType.getShape().size() > 1 && + "Expected multi-dimensional memdesc (e.g., Nx...) for subview"); + shape.insert(shape.end(), allocDescType.getShape().begin() + 1, + allocDescType.getShape().end()); + auto viewDescType = ttg::MemDescType::get( + shape, allocDescType.getElementType(), allocDescType.getEncoding(), + allocDescType.getMemorySpace(), allocDescType.getMutableMemory()); + return ttg::MemDescIndexOp::create(builder, alloc.getLoc(), viewDescType, + alloc, idx); +} + +TypedValue +triton::createSingleBufferView(OpBuilder &builder, Value alloc, int idx) { + Value idxVal = arith::ConstantIntOp::create(builder, alloc.getLoc(), idx, 32); + return createSingleBufferView(builder, alloc, idxVal); +} + +Value triton::createIncrementModulo(OpBuilder &builder, Location loc, + Value counter, Value modulus, Value zero, + Value one, Value *outWrapCond) { + Value addOne = arith::AddIOp::create(builder, loc, counter, one); + Value outOfRangeCond = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sge, addOne, modulus); + if (outWrapCond) + *outWrapCond = outOfRangeCond; + return arith::SelectOp::create(builder, loc, outOfRangeCond, zero, addOne); +} + +///////////////////////////// +// LOWER TMA DESCRIPTORS +///////////////////////////// + +static void +allocTMABuffers(scf::ForOp forOp, + llvm::MapVector &tmaBufferMapping, + int maxStage) { + IRRewriter rewriter(forOp); + + // Create a multi-buffered allocation for each MakeTensorDescOp call in the + // loop + forOp.walk([&](tt::MakeTensorDescOp op) { + // TODO peter: walk to loop yield to find the init value if this is a + // loop-carried value. That would save us from allocating another buffer + // just for the init value + auto loc = op.getLoc(); + Value alloc = triton::gpu::GlobalScratchAllocOp::create( + rewriter, loc, triton::getPointerType(rewriter.getI8Type()), + maxStage * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN); + tmaBufferMapping[op.getOperation()] = alloc; + }); +} + +static Value subviewTMADescriptor(OpBuilder &builder, Location loc, Value alloc, + Value counter) { + Value tmaSizeVal = + arith::ConstantIntOp::create(builder, loc, ttng::TMA_SIZE_BYTES, 32); + Value offset = arith::MulIOp::create(builder, loc, tmaSizeVal, counter); + return triton::AddPtrOp::create(builder, loc, alloc.getType(), alloc, offset); +} + +static LogicalResult rewriteTMABufferUpdates( + scf::ForOp forOp, + const llvm::MapVector &tmaBufferMapping, + ArrayRef tmaCounters, int numBuffers, Value one, Value zero, + triton::CoarseSchedule &schedule) { + assert(tmaBufferMapping.size() == tmaCounters.size()); + + auto auxBuilder = mlir::OpBuilder(forOp); + Value numBuffersVal = + arith::ConstantIntOp::create(auxBuilder, forOp.getLoc(), numBuffers, 32); + + for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) { + auto &[op, alloc] = pair; + + // Rewriter MakeTensorDescOp as writing a TMA descriptor + auto makeDescOp = cast(op); + + triton::OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp, + schedule); + + BlockArgument counter = tmaCounters[iOp]; + Value nextBuf = + subviewTMADescriptor(builder, builder.getLoc(), alloc, counter); + if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) { + return failure(); + } + ttng::TensormapFenceproxyAcquireOp::create(builder, nextBuf); + Value nextDesc = ttng::ReinterpretTensorDescOp::create( + builder, makeDescOp.getType(), nextBuf); + + makeDescOp.getResult().replaceAllUsesWith(nextDesc); + + // Increment the buffer index counter + Value nextCounter = createIncrementModulo( + builder, builder.getLoc(), counter, numBuffersVal, zero, one); + + // If we are in a (potentially nested) if region, propagate the counter + // up to the main for op body scope + IRRewriter rewriter(forOp); + nextCounter = triton::sinkValueRedefinition(rewriter, counter, nextCounter, + op->getBlock()); + + // Finally, rewrite the loop level yield + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield.setOperand(counter.getArgNumber() - 1, nextCounter); + makeDescOp.erase(); + } + return success(); +} + +scf::ForOp triton::lowerTMADescriptors(scf::ForOp forOp, + CoarseSchedule &schedule) { + llvm::MapVector tmaBufferMapping; + int maxStage = schedule.getNumStages() - 1; + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto wgMmaOp = dyn_cast(&op)) { + // Hopper only: Add one more buffer slice if there is a WarpGroupDotOp, + // as if it will be pipelined, we will effectively make the pipeline + // one stage longer. + maxStage += 1; + break; + } + } + allocTMABuffers(forOp, tmaBufferMapping, maxStage); + if (tmaBufferMapping.empty()) + return forOp; + + IRRewriter builder(forOp); + Location loc = forOp.getLoc(); + Value zero = arith::ConstantIntOp::create(builder, loc, 0, 32); + Value one = arith::ConstantIntOp::create(builder, loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Create one counter per TMA buffer. This allows the descriptors to be + // updated independently without needing to write duplicate of existing tma + // descriptors. + unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size(); + for (int i = 0; i < tmaBufferMapping.size(); ++i) { + newOperands.push_back(zero); + } + + forOp = addIterArgsToLoop(builder, forOp, newOperands); + + auto tmaCounters = ArrayRef(forOp.getBody()->getArguments()) + .slice(tmaCounterArgsStartIdx); + + // Update yield op with temporary yield values + auto forYield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters, + maxStage, one, zero, schedule))) { + llvm_unreachable("Failed to rewrite TMA ops"); + } + return forOp; +} + +DenseSet +triton::getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp, + std::function filter) { + DenseSet topLevelUsers; + SmallVector q; + for (auto &use : op->getUses()) + q.push_back(&use); + while (!q.empty()) { + auto use = q.pop_back_val(); + auto yieldOp = dyn_cast(use->getOwner()); + if (yieldOp && yieldOp->getParentOp() == forOp) { + for (auto &use : + forOp.getRegionIterArgs()[use->getOperandNumber()].getUses()) + q.push_back(&use); + continue; + } + // Don't count view operations as uses. Follow them through to their + // users. + if (use->getOwner()->hasTrait()) { + for (auto &use : use->getOwner()->getUses()) + q.push_back(&use); + continue; + } + if (filter && !filter(use->getOwner())) + continue; + Operation *topLevelUser = + forOp.getBody()->findAncestorOpInBlock(*use->getOwner()); + topLevelUsers.insert(topLevelUser); + } + return topLevelUsers; +} + +// Helper function that finds an operation based on a comparison predicate +static Operation *getUseOfPipelinedOp( + ArrayRef ops, scf::ForOp forOp, + triton::CoarseSchedule &schedule, + std::function filterUse, + std::function shouldPrefer) { + DenseSet topLevelUsers; + Operation *selectedUser = nullptr; + for (Operation *op : ops) { + auto users = triton::getTopLevelUsersInLoop(op, forOp, filterUse); + topLevelUsers.insert(users.begin(), users.end()); + } + for (Operation *topLevelUser : topLevelUsers) { + assert(schedule.count(topLevelUser) && "op user not found in the schedule"); + if (!selectedUser || shouldPrefer(topLevelUser, selectedUser)) { + selectedUser = topLevelUser; + } + } + return selectedUser; +} + +Operation * +triton::getFirstUseOfPipelinedOp(ArrayRef ops, scf::ForOp forOp, + triton::CoarseSchedule &schedule, + std::function filterUse) { + return getUseOfPipelinedOp( + ops, forOp, schedule, filterUse, + [&](Operation *candidate, Operation *current) { + auto [candidateStage, candidateCluster] = schedule[candidate]; + auto [currentStage, currentCluster] = schedule[current]; + + return candidateStage < currentStage || + (candidateStage == currentStage && + schedule.clusters.isBefore(candidateCluster, currentCluster)) || + (candidateStage == currentStage && + candidateCluster == currentCluster && + candidate->isBeforeInBlock(current)); + }); +} + +Operation * +triton::getLastUseOfPipelinedOp(ArrayRef ops, scf::ForOp forOp, + triton::CoarseSchedule &schedule, + std::function filterUse) { + return getUseOfPipelinedOp( + ops, forOp, schedule, filterUse, + [&](Operation *candidate, Operation *current) { + auto [candidateStage, candidateCluster] = schedule[candidate]; + auto [currentStage, currentCluster] = schedule[current]; + + return candidateStage > currentStage || + (candidateStage == currentStage && + schedule.clusters.isBefore(currentCluster, candidateCluster)) || + (candidateStage == currentStage && + candidateCluster == currentCluster && + current->isBeforeInBlock(candidate)); + }); +} + +void triton::removePipeliningAttributes(ModuleOp moduleOp) { + moduleOp->walk([&](Operation *op) { + op->removeAttr(mlir::triton::kLoopStageAttrName); + op->removeAttr(mlir::triton::kLoopClusterAttrName); + op->removeAttr(mlir::triton::kScheduledMaxStageAttrName); + }); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp new file mode 100644 index 0000000000..46c10b94c4 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -0,0 +1,419 @@ +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +namespace tt = mlir::triton; + +bool tt::CoarseSchedule::insertMinimum(Operation *op, int stage, + Cluster cluster) { + auto res = opToStageAndCluster.insert({op, {stage, cluster}}); + if (res.second) { + return true; + } + + auto &[existingStage, existingCluster] = res.first->second; + + // Always insert if the stage is earlier. + if (stage < existingStage) { + existingStage = stage; + existingCluster = cluster; + return true; + } + + // If the stage is later, no change. + if (stage > existingStage) { + return false; + } + + // If existingCluster is reachable from cluster, + // then cluster is earlier in the list + for (auto it = std::next(cluster); it != clusters.end(); ++it) { + if (it == existingCluster) { + if (existingCluster == cluster) + return false; + existingCluster = cluster; + return true; + } + } + + // Didn't change the cluster. + return false; +} + +bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster, + bool includeArg, bool insertIfEarlier) { + auto tryInsert = [&](Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster) { + if (!insertIfEarlier) + return insertIfAbsent(op, stage, cluster); + return insertMinimum(op, stage, cluster); + }; + + bool inserted = false; + for (Value operand : getNestedOperands(op)) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (tryInsert(defOp, stage, cluster)) { + inserted = true; + insertDepsOfOp(defOp, stage, cluster, includeArg, insertIfEarlier); + } + } + } + return inserted; +} + +void tt::CoarseSchedule::shrinkToFit() { + int minStage = std::numeric_limits::max(); + int maxStage = std::numeric_limits::min(); + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + auto [stage, cluster] = stageAndCluster; + minStage = std::min(minStage, stage); + maxStage = std::max(maxStage, stage); + } + for (auto &[op, stageAndCluster] : opToStageAndCluster) + stageAndCluster.first -= minStage; + numStages = maxStage - minStage + 1; +} + +// Split the cluster containing op into two clusters, one containing all +// operations before the op and one containing op and all operations after the +// op. Return the cluster containing op and all operations after the op. Do not +// split if the op is the first operation in the cluster. +tt::CoarseSchedule::Cluster +tt::CoarseSchedule::splitClusterBefore(Operation *op, scf::ForOp forOp) { + auto cluster = opToStageAndCluster[op].second; + std::optional newCluster = std::nullopt; + for (auto &_op : forOp.getBody()->without_terminator()) { + if (&_op == op) { + break; + } + if (opToStageAndCluster[&_op].second == cluster) { + if (!newCluster) { + newCluster = clusters.newBefore(cluster); + } + opToStageAndCluster[&_op].second = *newCluster; + } + } + return cluster; +} + +// Check if op a will show up before op b in the final unrolled code. +bool tt::CoarseSchedule::isOpBefore(Operation *a, Operation *b) const { + assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) && + "Operations must be in the schedule"); + auto [aStage, aCluster] = opToStageAndCluster.lookup(a); + auto [bStage, bCluster] = opToStageAndCluster.lookup(b); + if (aStage != bStage) { + return aStage < bStage; + } + if (aCluster != bCluster) { + return clusters.isBefore(aCluster, bCluster); + } + return a->isBeforeInBlock(b); +} + +bool tt::CoarseSchedule::isOpInEarlierCluster(Operation *a, + Operation *b) const { + assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) && + "Operations must be in the schedule"); + return clusters.isBefore(opToStageAndCluster.lookup(a).second, + opToStageAndCluster.lookup(b).second); +} + +bool tt::CoarseSchedule::isOpInSameCluster(Operation *a, Operation *b) const { + assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) && + "Operations must be in the schedule"); + return opToStageAndCluster.lookup(a).second == + opToStageAndCluster.lookup(b).second; +} + +SmallVector> +tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) const { + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + auto it = opToStageAndCluster.find(&op); + if (it == opToStageAndCluster.end()) { + continue; + } + auto [stage, cluster] = it->second; + assert(cluster != Cluster{} && "Op with invalid cluster!"); + assert(stage < numStages && "Op with invalid stage!"); + int clusterId = *cluster; + assert(clusterId == std::distance(clusters.begin(), + ClusterList::const_iterator(cluster)) && + "Cluster ID mismatch!"); + orderClusters[clusterId].push_back(make_tuple(&op, stage, cluster)); + } + SmallVector> opsInOrder; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, stage, cluster}); + } + } + + return opsInOrder; +} + +std::vector> +tt::CoarseSchedule::createFinalSchedule(scf::ForOp forOp) const { + SmallVector> + opsInOrder = getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) + schedule.push_back({op, stage}); + return schedule; +} + +void tt::CoarseSchedule::dump() { + assert(numStages > 0 && "Invalid number of stages"); + for (int i = 0; i < numStages; i++) { + llvm::dbgs() << "\n---- Ops in stage " << i << "\n"; + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::dbgs() << " cluster: " << *stageAndCluster.second + << ":\n\t" << *op << "\n"; + } + } + } +} + +static void setStageCluster(Operation *op, int stage, int cluster) { + auto ctx = op->getContext(); + op->setAttr(mlir::triton::kLoopStageAttrName, + IntegerAttr::get(IntegerType::get(ctx, 32), stage)); + op->setAttr(mlir::triton::kLoopClusterAttrName, + IntegerAttr::get(IntegerType::get(ctx, 32), cluster)); +} + +static std::pair getStageCluster(Operation *op) { + auto stage = op->getAttrOfType(tt::kLoopStageAttrName); + auto clusterId = op->getAttrOfType(tt::kLoopClusterAttrName); + assert(stage && clusterId && + "Operation is missing stage & cluster attribute"); + return {stage.getValue().getSExtValue(), clusterId.getValue().getSExtValue()}; +} + +static std::pair getMinMaxCluster(scf::ForOp &forOp) { + int minClusterId = -1, maxClusterId = -1; + for (auto &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName) || + !op.hasAttr(mlir::triton::kLoopClusterAttrName)) + continue; + auto [_, cluster] = getStageCluster(&op); + if (maxClusterId < 0) { + minClusterId = cluster; + maxClusterId = cluster; + continue; + } + maxClusterId = cluster > maxClusterId ? cluster : maxClusterId; + minClusterId = cluster < minClusterId ? cluster : minClusterId; + } + return std::make_pair(minClusterId, maxClusterId); +} + +static std::optional tryGetMaxStage(scf::ForOp &forOp) { + std::optional maxStage = std::nullopt; + if (forOp->hasAttr(mlir::triton::kScheduledMaxStageAttrName)) { + return forOp + ->getAttrOfType(mlir::triton::kScheduledMaxStageAttrName) + .getValue() + .getSExtValue(); + } + return maxStage; +} + +// Set based on CoarseSchedule. +void tt::CoarseSchedule::serialize(scf::ForOp &forOp) const { + for (auto [op, stage, cluster] : getOpsInOrder(forOp)) { + setStageCluster(op, stage, *cluster); + } + + Builder b(forOp.getContext()); + int maxStages = numStages - 1; + if (auto maxStageAttr = tryGetMaxStage(forOp)) + maxStages = std::max(maxStages, *maxStageAttr); + forOp->setAttr(mlir::triton::kScheduledMaxStageAttrName, + b.getI32IntegerAttr(maxStages)); +} + +// Create a CoarseSchedule based on forOp's . +LogicalResult tt::CoarseSchedule::deSerialize(scf::ForOp &forOp, + bool normalizeClusterId) { + auto [minClusterId, maxClusterId] = getMinMaxCluster(forOp); + std::optional maxStage = tryGetMaxStage(forOp); + if (!maxStage) { + return failure(); + } + numStages = *maxStage + 1; + + DenseMap clustersMap; + if (normalizeClusterId) { + for (int i = minClusterId; i < maxClusterId + 1; i++) { + clustersMap.insert({i, clusters.newAtBack()}); + } + } else { + for (int i = 0; i < maxClusterId + 1; i++) { + clustersMap.insert({i, clusters.newAtBack()}); + } + } + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName)) + continue; + auto [stage, clusterId] = getStageCluster(&op); + insert(&op, stage, clustersMap[clusterId]); + } + return success(); +} + +// TODO: Should this be moved somewhere else? +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +// ============================================================ +// LinearizedIterator Implementation +// ============================================================ + +tt::CoarseSchedule::LinearizedIterator::LinearizedIterator( + scf::ForOp forOp, const CoarseSchedule &schedule, Operation *initialOp) + : forOp(forOp), schedule(&schedule), initialOp(initialOp), atEnd(false), + maxStages(schedule.getNumStages()) { + clusterBegin = schedule.clusters.begin(); + clusterEnd = schedule.clusters.end(); + opIt = forOp.getBody()->without_terminator().begin(); + opEnd = forOp.getBody()->without_terminator().end(); + + // Find the cluster containing initialOp and its stage + auto it = schedule.opToStageAndCluster.find(initialOp); + if (it != schedule.opToStageAndCluster.end()) { + auto [stage, cluster] = it->second; + clusterIt = cluster; + currStageLimit = stage; + // Find initialOp within its cluster + while (opIt != opEnd) { + Operation *op = &*opIt; + if (op == initialOp) { + break; + } + ++opIt; + } + // Move past initialOp to start iteration from the next op + ++opIt; + advanceToNextScheduledOp(); + } else { + atEnd = true; + currentOp = nullptr; + } +} + +void tt::CoarseSchedule::LinearizedIterator::advanceToNextScheduledOp() { + while (true) { + while (opIt != opEnd) { + Operation *op = &*opIt; + auto it = schedule->opToStageAndCluster.find(op); + if (it != schedule->opToStageAndCluster.end()) { + auto [stage, cluster] = it->second; + if (cluster == clusterIt) { + // Check if we've come back to initialOp + if (op == initialOp) { + // Check termination condition + if (currStageLimit >= maxStages) { + atEnd = true; + currentOp = nullptr; + return; + } + } + // Only yield if stage <= currStageLimit + if (stage <= currStageLimit) { + currentOp = op; + return; + } + } + } + ++opIt; + } + // Move to next cluster + ++clusterIt; + opIt = forOp.getBody()->without_terminator().begin(); + + // Wrap around to the beginning if we've reached the end + if (clusterIt == clusterEnd) { + clusterIt = clusterBegin; + // Increment stage limit as we are in the next iteration. + currStageLimit++; + } + } +} + +tt::CoarseSchedule::LinearizedIterator & +tt::CoarseSchedule::LinearizedIterator::operator++() { + if (atEnd) + return *this; + ++opIt; + advanceToNextScheduledOp(); + return *this; +} + +tt::CoarseSchedule::LinearizedIterator +tt::CoarseSchedule::LinearizedIterator::operator++(int) { + LinearizedIterator tmp = *this; + ++(*this); + return tmp; +} + +Operation *tt::CoarseSchedule::LinearizedIterator::operator*() const { + return currentOp; +} + +bool tt::CoarseSchedule::LinearizedIterator::operator==( + const LinearizedIterator &other) const { + if (atEnd && other.atEnd) + return true; + if (atEnd != other.atEnd) + return false; + return currentOp == other.currentOp; +} + +bool tt::CoarseSchedule::LinearizedIterator::operator!=( + const LinearizedIterator &other) const { + return !(*this == other); +} + +void tt::scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule) { + int numStages = schedule.getNumStages(); + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, /*includeArg=*/false, + /*insertIfEarlier=*/true); + } + } +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp new file mode 100644 index 0000000000..2de92d854d --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp @@ -0,0 +1,415 @@ +#include "mlir/IR/Dominance.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; +namespace mlir::triton::gpu { + +//===----------------------------------------------------------------------===// +// scheduleLoops +//===----------------------------------------------------------------------===// + +template bool containsAny(scf::ForOp forOp) { + WalkResult result = forOp.walk([&](Operation *op) { + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return result.wasInterrupted(); +} + +// Return true if the preconditions for pipelining the loop are met. +bool isSafeToPipeline(scf::ForOp forOp) { + // Skip loop with distance > 1. + if (loopHasDistGreaterThanOne(forOp)) + return false; + // Don't pipeline outer loops. + if (isOuterLoop(forOp)) + return false; + // Skip loops with barriers, asserts or prints + if (containsAny(forOp)) + return false; + + return true; +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.getNumStages(); + + // Mapping from the cluster to the cluster before it. + DenseMap dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, + /*includeArg=*/true, + /*insertIfEarlier=*/true); + } else { + CoarseSchedule::ClusterHash clusterHash = + CoarseSchedule::hashCluster(cluster); + if (dist1Cluster.count(clusterHash) == 0) { + dist1Cluster[clusterHash] = + schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, + dist1Cluster[clusterHash]); + schedule.insertDepsOfOp(defOp, stage + 1, + dist1Cluster[clusterHash], + /*includeArg=*/true, + /*includeIfEarlier=*/true); + } + } + } + } + } + } +} + +void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue) { + int numStages = schedule.getNumStages(); + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster; + if (schedule.count(op)) + opCluster = schedule[op].second; + else + opCluster = opToCluster[op]; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +namespace { +bool hasLatenciesAssigned(scf::ForOp forOp, + const DenseMap &opLatency) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + return true; + } + return false; +} + +CoarseSchedule scheduleKeyOps(scf::ForOp forOp, + const DenseMap &opLatency) { + llvm::MapVector opToStage; + // Find terminator for later reference + auto terminator = cast(forOp.getBody()->getTerminator()); + // Determine all operations that have a non-zero latency + SmallVector latOps; + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + latOps.push_back(&op); + } + // If no latency ops, nothing to schedule + if (latOps.empty()) + return CoarseSchedule(0); + + DominanceInfo domInfo(forOp); + // Compute the longest path to the yield for each operation reachable + // from any latency operation. + DenseMap distance; + std::function computeDistance = [&](Operation *op) -> int { + auto it = distance.find(op); + if (it != distance.end()) + return it->second; + // Compute max distance among all users that are inside the loop body + int maxDist = -1; + for (Operation *user : op->getUsers()) { + // Only consider users inside the same block and not the terminator + Operation *inBlockUser = forOp.getBody()->findAncestorOpInBlock(*user); + if (!inBlockUser || inBlockUser == terminator) + continue; + int distUser = computeDistance(inBlockUser); + if (distUser > maxDist) + maxDist = distUser; + } + int lat = 0; + if (opLatency.count(op)) + lat = opLatency.lookup(op); + // If an op has no users (maxDist == -1) but has latency, we include its + // latency otherwise it contributes 0 to the distance. + int d = lat + (maxDist < 0 ? 0 : maxDist); + distance[op] = d; + return d; + }; + + // Compute distances for all latency-starting ops + int maxDistance = 0; + for (Operation *latOp : latOps) { + int d = computeDistance(latOp); + if (d > maxDistance) + maxDistance = d; + } + + // Assign stage to each op reachable from a latency op + for (auto [op, dist] : distance) { + // We only schedule ops that are downstream of a latency op + // (had a non-negative distance due to a latency op). + if (dist >= 0) + opToStage[op] = maxDistance - dist; + } + + auto stages = llvm::make_second_range(opToStage); + int maxStage = *llvm::max_element(stages); + CoarseSchedule schedule(maxStage + 1); + SmallVector clusters(maxStage + 1); + for (int i = 0; i <= maxStage; i++) { + clusters[i] = schedule.clusters.newAtBack(); + } + // Assign ops to the clusters in reverse-stage order; + // ops with higher stage numbers are assigned first. This way we will + // end up with roughly reverse program order in the clusters. + for (auto [op, stage] : opToStage) + schedule.insert(op, stage, clusters[maxStage - stage]); + + // Move `scf.if` ops in the current schedule (forward slice of the latency + // ops) into a new epilogue cluster at the end of the schedule, pushing them + // as close to the end of the loop body as possible. + CoarseSchedule::Cluster epilogue = schedule.clusters.newAtBack(); + for (auto [op, stage] : opToStage) { + auto ifOp = dyn_cast(op); + if (!ifOp) + continue; + // If the `scf.if` op itself is a latency op, skip it. + if (opLatency.contains(ifOp)) + continue; + // Ensure this does not create scheduling conflicts by ensuring the forward + // slice of the `scf.if` does not contain ops that are already scheduled, as + // this will cause the `scf.if` to be scheduled after its dependents. + SetVector slice; + getForwardSlice(ifOp, &slice); + if (llvm::any_of(slice, [&](Operation *op) { return opToStage.count(op); })) + continue; + schedule.insert(ifOp, stage, epilogue); + } + + return schedule; +} + +// Get an initial schedule for the loop. This is the base schedule from which +// the rest of the pass will backward propagate dependencies. +CoarseSchedule getInitialSchedule(scf::ForOp forOp, + const DenseMap &opLatency) { + if (!isSafeToPipeline(forOp)) + return CoarseSchedule(0); + + // If the loop has assigned latencies, use them to determine the initial + // schedule. + if (hasLatenciesAssigned(forOp, opLatency)) + return scheduleKeyOps(forOp, opLatency); + + // If the loop has an existing schedule, use it as the base schedule. + CoarseSchedule schedule; + if (forOp->hasAttr(kWarpSpecializeAttrName) && + succeeded(schedule.deSerialize(forOp))) { + // The loop was partitioned from a warp-specialized loop, meaning it can + // have a partial view of the original loop stages. Re-schedule the loop + // root at the stages of the latency ops to prune unnecessary stages. + auto isLatencyOp = [&](Operation &op) { + return opLatency.count(&op) || + isa(op); + }; + + // If there are no latency ops or all latency ops are in the same stage, we + // don't need to pipeline the loop. Return a new schedule with everything + // assigned to the same stage. + DenseSet latencyStages; + auto ops = forOp.getBody()->without_terminator(); + for (Operation &op : llvm::make_filter_range(ops, isLatencyOp)) { + // FIXME: This should assert all latency ops have an assigned stage. + if (schedule.count(&op)) + latencyStages.insert(schedule[&op].first); + } + if (latencyStages.size() <= 1) { + CoarseSchedule normalized(/*numStages=*/1); + auto cluster = normalized.clusters.newAtFront(); + for (Operation &op : ops) + normalized.insert(&op, 0, cluster); + return normalized; + } + + schedule.shrinkToFit(); + return schedule; + } + + return CoarseSchedule(0); +} + +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.getNumStages(); + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.omitUsesFromAbove = false; + (void)getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + if (!ifsToStage.empty()) { + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insertIfAbsent(ifOp, stage, prologueCluster); + } + } + + // Other IfOps should be pushed to the end. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto ifOp = dyn_cast(op)) { + if (ifsToStage.count(ifOp) == 0) { + schedule.insertIfAbsent(ifOp, numStages - 1, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +void scheduleLoop(scf::ForOp forOp, + const DenseMap &opLatency) { + // Based on the latencies, schedule the key ops to the stages. + CoarseSchedule schedule = getInitialSchedule(forOp, opLatency); + if (schedule.empty()) + return; + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Initial coarse schedule:\n" << forOp << "\n"; + }); + // Schedule the dependencies + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, schedule); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Coarse schedule with prologue and epilogue:\n" << forOp << "\n"; + }); + scheduleDependencies(forOp, schedule); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Coarse schedule with dependencies:\n" << forOp << "\n"; + }); + scheduleDistanceOneDependencies(forOp, schedule); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Coarse schedule with dist 1:\n" << forOp << "\n"; + }); + scheduleRemainingToLastStage(forOp, schedule, afterPrologue); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Final coarse schedule:\n" << forOp << "\n"; + }); + + // Write the schedule to the IR + schedule.serialize(forOp); +} + +/// Schedule the loops based on the latencies assigned to the operations. +void scheduleLoops(ModuleOp moduleOp) { + DenseMap opLatency = deserializeLatencies(moduleOp); + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + if (loops.empty()) + return; + for (auto forOp : loops) { + scheduleLoop(forOp, opLatency); + } +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_TRITONGPUSCHEDULELOOPS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct ScheduleLoops : public impl::TritonGPUScheduleLoopsBase { + using TritonGPUScheduleLoopsBase::TritonGPUScheduleLoopsBase; + + void runOnOperation() override { scheduleLoops(getOperation()); } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 0000000000..d36a115bf0 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,228 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/Triton/Transforms/LoopPeeling.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPIPELINE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static void pipelineWgmma(ModuleOp moduleOp, unsigned numStages) { + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + for (scf::ForOp forOp : loops) { + if (getNumStagesOrDefault(forOp, numStages) >= 1) + mlir::triton::asyncLaunchDots(forOp); + } +} + +static bool hasMMAv5WaitsInLastStage(scf::ForOp forOp, + CoarseSchedule &schedule) { + int maxStage = schedule.getNumStages() - 1; + bool hasMMAv5 = false; + bool hasWaitInLastStage = false; + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op) && + schedule[&op].first == maxStage) { + hasWaitInLastStage = true; + } + if (isa(op)) { + hasMMAv5 = true; + } + } + return hasMMAv5 && hasWaitInLastStage; +} + +static void expandLoops(ModuleOp moduleOp) { + DenseSet peeledMaskOps; + auto processPeeledEpilogueOp = [&](RewriterBase &rewriter, Operation *op, + bool isEpilogue) -> Operation * { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto predOp = dyn_cast(op)) { + if (isEpilogue) { + // Return false for the predicate of the peeled iteration + return mlir::arith::ConstantIntOp::create( + rewriter, predOp.getLoc(), predOp.getResult().getType(), 0); + } + if (predOp.getStage() == predOp.getMaxStage() - 1) { + return mlir::arith::ConstantIntOp::create( + rewriter, predOp.getLoc(), predOp.getResult().getType(), 1); + } + return triton::emitPredicateForStage( + rewriter, predOp.getIv(), predOp.getUb(), predOp.getStep(), + predOp.getMaxStage(), predOp.getStage()) + .getDefiningOp(); + } + if (auto maskOp = dyn_cast(op)) { + if (isEpilogue) { + peeledMaskOps.insert(maskOp); + } + } + return op; + }; + + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + CoarseSchedule schedule; + if (failed(schedule.deSerialize(forOp))) { + continue; + } + + std::vector> finalSchedule = + schedule.createFinalSchedule(forOp); + triton::PipeliningOption options; + options.supportDynamicLoops = true; + options.peelEpilogue = false; + options.predicateFn = wrapInMaskOp; + options.getScheduleFn = + [&](scf::ForOp forOp, + std::vector> &schedule) { + schedule = finalSchedule; + }; + + // Testing feature: allow for unresolved predicate stage ops + // in the loop body. + bool keepPredicateStage = forOp->hasAttr("__test_keep_predicate_stage"); + // TODO: Enable epilogue peeling for warp specialized loops + // Heuristic: only peel epilogue for MMAv5 loops with waits in the last + // stage + bool customEpiloguePeeling = + hasMMAv5WaitsInLastStage(forOp, schedule) && + !forOp->getParentOfType() && + !keepPredicateStage; // do not peel if we are testing the stage + // predication + + if (keepPredicateStage || customEpiloguePeeling) { + options.emitPredicateStageFn = + [](RewriterBase &rewriter, Value inductionVar, Value upperBound, + Value step, uint64_t maxStage, uint64_t stage) { + return triton::gpu::PredicateStageOp::create( + rewriter, inductionVar.getLoc(), inductionVar, upperBound, step, + maxStage, stage); + }; + } + IRRewriter rewriter(forOp); + FailureOr newForOp = + triton::pipelineForLoop(rewriter, forOp, options); + + if (failed(newForOp)) { + continue; + } + forOp = *newForOp; + if (customEpiloguePeeling) { + mlir::triton::peelLoopEpilogue(forOp, processPeeledEpilogueOp); + } + + // Prune all the statically dead mask ops in the epilogue. This is a + // hack, ideally we should do it for all the mask ops, but it is incorrect + // if we have speculatively executed async cp operations that will store to + // shmem even if the mask is false. + for (auto maskOp : peeledMaskOps) { + rewriter.setInsertionPoint(maskOp); + if (isConstantIntValue(maskOp.getPred(), 0)) { + SmallVector results; + for (auto result : maskOp->getResults()) { + auto poisonOp = mlir::ub::PoisonOp::create(rewriter, maskOp->getLoc(), + result.getType()); + results.push_back(poisonOp); + } + maskOp->replaceAllUsesWith(results); + maskOp->erase(); + } + } + peeledMaskOps.clear(); + } + assert(moduleOp.getOps().empty() && + "PredicateStageOp should be resolved after the pipeline expansion"); + assert(verify(moduleOp).succeeded()); + resolveMaskOp(moduleOp); +} + +struct PipelinePass : public impl::TritonGPUPipelineBase { + + using impl::TritonGPUPipelineBase::TritonGPUPipelineBase; + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + // Transform the loop by introducing async operations to prepare it for + // pipeline expansion. + lowerLoops(moduleOp); + if (dumpIntermediateSteps) { + llvm::dbgs() + << "// -----// SoftwarePipeliner internal IR Dump After: LowerLoops\n" + << moduleOp << "\n\n\n"; + } + + // Apply the pipeline expansion. + expandLoops(moduleOp); + if (dumpIntermediateSteps) { + llvm::dbgs() << "// -----// SoftwarePipeliner internal IR Dump After: " + "ExpandLoops\n" + << moduleOp << "\n\n\n"; + } + + // Cleanup the IR from the pipeline attributes. + removePipeliningAttributes(moduleOp); + + pipelineWgmma(moduleOp, numStages); + + // schedule the waits + mlir::triton::updateWaits(getOperation()); + + // Clean up arithmetic before applying the next level of pipelining to + // simplify the IR. + auto arithDialect = + getOperation().getContext()->getLoadedDialect(); + RewritePatternSet patterns(getOperation().getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsGreedily(getOperation(), std::move(patterns)).failed()) + return signalPassFailure(); + + { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp, numStages) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) { + mlir::triton::pipelineTMAStores(forOp); + } + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp new file mode 100644 index 0000000000..2b753f3c6b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -0,0 +1,125 @@ +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +struct TMAStore { + Operation *op; + mlir::TypedValue desc; + mlir::TypedValue src; +}; + +static SmallVector getTMAStores(scf::ForOp forOp) { + SmallVector tmaStores; + + forOp.getBody()->walk([&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + tmaStores.push_back({storeOp, storeOp.getDesc(), storeOp.getSrc()}); + // Don't walk into nested loops. + } else if (isa(op)) { + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + + return tmaStores; +} + +static Value createAlloc(scf::ForOp &forOp, const TMAStore &store) { + OpBuilder builder(forOp); + RankedTensorType ty = store.src.getType(); + auto encoding = + triton::nvidia_gpu::getEncodingFromDescriptor(store.op, ty, store.desc); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()); + Type memdescType = + ttg::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory*/ true); + Value alloc = + ttg::LocalAllocOp::create(builder, store.op->getLoc(), memdescType); + return alloc; +} + +static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store, + Value alloc) { + OpBuilder builder(store.op); + Location loc = store.op->getLoc(); + RankedTensorType ty = store.src.getType(); + + // Put wait before the local_store make the store truly async. We know + // that we are the only user of the CopyLocalToGlobal. + ttng::TMAStoreWaitOp::create(builder, loc, 0); + ttg::LocalStoreOp::create(builder, loc, store.src, alloc); + ttng::FenceAsyncSharedOp::create(builder, loc, false); + auto desc = store.desc; + if (auto storeOp = dyn_cast(store.op)) { + ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc, + storeOp.getIndices(), alloc); + } else if (auto reduceOp = dyn_cast(store.op)) { + ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc, + reduceOp.getIndices(), alloc); + } else { + auto scatterOp = cast(store.op); + ttng::AsyncTMAScatterOp::create(builder, loc, desc, scatterOp.getXOffsets(), + scatterOp.getYOffset(), alloc); + } + + store.op->erase(); +} + +static void lowerTMADescriptorCreation(scf::ForOp forOp) { + // Use max_stage=3 to double buffer the descriptor. + triton::CoarseSchedule schedule(3); + triton::lowerTMADescriptors(forOp, schedule); +} + +bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) { + SmallVector tmaStores = getTMAStores(forOp); + if (tmaStores.empty()) + return false; + + DenseMap storeToAlloc; + DenseMap, Type>, Value> allocs; + for (const TMAStore &store : tmaStores) { + // Reuse allocations for stores of the same shape and types. This allows + // saving shared memory usage. It is valid since we have a wait 0 before + // every local_store. We could pipeline more aggressively if we didn't + // reuse but there is a tradeoff with shared memory usage. + RankedTensorType srcTy = store.src.getType(); + auto key = std::make_pair(srcTy.getShape(), srcTy.getElementType()); + Value &alloc = allocs[key]; + if (!alloc) { + alloc = createAlloc(forOp, store); + } + storeToAlloc[store.op] = alloc; + } + + bool hasDeviceSideTMA = llvm::any_of(tmaStores, [](const TMAStore &store) { + return !triton::isHostSideDescriptor(store.desc); + }); + for (const TMAStore &store : tmaStores) { + createTMAAsyncCopy(forOp, store, storeToAlloc[store.op]); + } + + // Deallocate shared memory buffers. + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + ttng::TMAStoreWaitOp::create(builder, forOp->getLoc(), 0); + for (auto it : storeToAlloc) { + ttg::LocalDeallocOp::create(builder, forOp->getLoc(), it.second); + } + + if (hasDeviceSideTMA) { + // This is a bit coarse as it would multibuffer any descriptor in the loop + // but it likely to not have a big impact. + lowerTMADescriptorCreation(forOp); + } + return true; +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineLowerLoop.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineLowerLoop.cpp new file mode 100644 index 0000000000..7602bb4765 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineLowerLoop.cpp @@ -0,0 +1,32 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINELOWERLOOP +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct TestPipelineLowerLoop + : public impl::TritonGPUTestPipelineLowerLoopBase { + using impl::TritonGPUTestPipelineLowerLoopBase< + TestPipelineLowerLoop>::TritonGPUTestPipelineLowerLoopBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + lowerLoops(m); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp new file mode 100644 index 0000000000..6720d5d96b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp @@ -0,0 +1,769 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-wgmma-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// Returns whether the dot is such that: +// 1. The LHS comes from registers and +// 1.1 The LHS is defined inside the loop +// 1.2. The LHS does not come from another dot +// For these dots, we assume that we cannot rewrite their +// operands until the previous dot has finished +static bool rsDotNeedsWait(Operation *dot, scf::ForOp forOp) { + auto dotOp = dyn_cast(dot); + if (!dotOp) + return false; + auto a = dotOp.getA(); + if (!isa(a.getType())) { + return false; + } + if (forOp.isDefinedOutsideOfLoop(a)) { + return false; + } + if (auto cvt = dyn_cast(a.getDefiningOp())) { + return !isa( + cvt.getSrc().getType().getEncoding()); + } + return true; +} + +/// Find the minimum number of async_commit_group ops between the wait +/// and the associated async_commit_group. This can be safely used as the wait +/// number. +static int minNumInterleavedCommitOps(Operation *waitOp) { + auto countCommitsBetween = [](Operation *op1, Operation *op2) { + int count = 0; + for (auto op = op1; op != op2; op = op->getNextNode()) { + if (isa(op)) + count++; + // Intentionally skip block ops' children. This will give us + // convervatively low number of insert ops. + } + return count; + }; + + int minCommitNumber = INT_MAX; + + // DFS the def chain of the extract op to find the insert op. On each path + // we calculate the number of async_commit. Then we select the minimum number + // of async_commit ops among all the paths. + std::function minOverHistories = + [&](Value val, Operation *sinkOp, int thisHistorySum) -> int { + if (Operation *defOp = val.getDefiningOp()) { + thisHistorySum += countCommitsBetween(defOp->getNextNode(), sinkOp); + minCommitNumber = std::min(minCommitNumber, thisHistorySum); + return minCommitNumber; + } + if (auto arg = mlir::dyn_cast(val)) { + Block *block = arg.getOwner(); + auto forOp = dyn_cast(block->getParentOp()); + + // Failed to track, return 0 conservatively. + if (!forOp) + return 0; + + Operation *firstForInst = &*forOp.getBody()->begin(); + int insertsBetween = countCommitsBetween(firstForInst, sinkOp); + thisHistorySum += insertsBetween; + if (thisHistorySum >= minCommitNumber) + return minCommitNumber; + + // get the value assigned to the argument coming from outside the loop + Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1]; + int min1 = minOverHistories(incomingVal, forOp, thisHistorySum); + + // get the value assigned to the argument coming from the previous + // iteration + Operation *yieldOp = block->getTerminator(); + Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1); + int min2 = minOverHistories(prevVal, yieldOp, thisHistorySum); + return std::min(std::min(min1, min2), minCommitNumber); + } + // Failed to track, return 0 conservatively. + return 0; + }; + + if (waitOp->getNumOperands() != 1) + return 0; + Value val = waitOp->getOperand(0); + // If the value resides in a region other than the region of the wait op, then + // the wait op must be in some nested region. Measure the number of commits + // between the definition value and the parent op. + // TODO: We could measure commits in nested regions along the path if + // necessary. + while (waitOp->getParentRegion() != val.getParentRegion()) + waitOp = waitOp->getParentOp(); + int minCommits = minOverHistories(val, waitOp, 0); + return minCommits; +} + +/// Update wait op number by analyzing the number of async_commit_group ops +/// along all paths. +void mlir::triton::updateWaits(ModuleOp module) { + llvm::SmallSetVector waitOps; + module.walk([&](ttg::AsyncWaitOp waitOp) { + int minNumCommits = minNumInterleavedCommitOps(waitOp); + waitOp.setNum(minNumCommits); + waitOps.insert(waitOp); + }); + tt::combineRedundantWaitOps(waitOps); +} + +// Add the given values as operands of the given wait, and replace all uses of +// the values with the wait. Also adds related MemDesc's to the wait. +// +// Threading %a through the wait transforms +// +// %a = <...> +// (%x', %y') = ttng.async_wait %x, %y +// %b = fn(%a) +// +// into +// +// %a = <...> +// (%x', %y', %a') = ttng.async_wait %x, %y, %a +// %b = fn(%a') +// +// The wait must dominate all uses of the elements of `values`. +// +// In addition to adding each value from `values` to the wait, this function +// also adds some MemDesc's to the wait. The idea is that if you have +// +// %alloc = ttg.local_alloc ... +// %a = ttng.warp_group_dot %alloc +// %a1 = ttng.warp_group_dot_wait %a +// +// then we want the wait to depend on %alloc as well as %a. This extends the +// live range of %alloc, so that it won't be destroyed until after the dot is +// waited on. +// +// Specifically, this function finds all warp_group_dot ops that elements of +// `values` depend on. Then it adds the MemDesc operands of those dots to the +// wait. +static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, + MutableArrayRef values) { + IRRewriter builder(wait.getContext()); + builder.setInsertionPoint(wait); + + // Operands are only added to the wait through this function, so we can have + // the invariant that the wait has no duplicates. This makes things a bit + // easier below. + size_t origNumOperands = wait.getNumOperands(); + SetVector newOperands(wait.getOperands().begin(), + wait.getOperands().end()); + assert(newOperands.size() == origNumOperands && + "Wait op has duplicate operands."); + + newOperands.insert(values.begin(), values.end()); + + // Find memdefs depended on by `values` through async dot ops. + SmallVector asyncDots; + for (Value v : values) { + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.filter = [&](Operation *op) { + if (auto dot = dyn_cast(op)) { + asyncDots.push_back(dot); + return false; + } + return op->getBlock() == wait->getBlock(); + }; + SetVector slice; + (void)getBackwardSlice(v, &slice, options); + } + + for (ttng::WarpGroupDotOp dot : asyncDots) { + for (Value operand : dot.getOperands()) { + if (isa(operand.getType())) { + newOperands.insert(operand); + } + } + } + + // We can't use replaceWithNewOp because we're changing the number of return + // values in the operation. + auto newWait = ttng::WarpGroupDotWaitOp::create( + builder, wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); + + auto dominatedByNewWait = [&](OpOperand &operand) { + auto opInThisBlock = + newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); + return opInThisBlock && newWait->isBeforeInBlock(opInThisBlock); + }; + for (int i = 0; i < origNumOperands; i++) { + Value operand = wait.getResult(i); + if (!isa(operand.getType())) + operand.replaceAllUsesWith(newWait.getResult(i)); + } + for (int i = origNumOperands; i < newOperands.size(); i++) { + Value operand = newWait.getOperand(i); + if (!isa(operand.getType())) + operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); + } + wait->erase(); +} + +// Split the LHS of a RSWGMMADot operation into multiple +// tensors of size MxnewK via SplitOps +SmallVector splitLhs(OpBuilder &builder, + TypedValue lhs, int64_t newK) { + auto loc = lhs.getLoc(); + auto type = lhs.getType(); + auto rank = type.getRank(); + auto shape = to_vector(type.getShape()); + auto nSplits = shape.back() / newK; + assert(nSplits > 1); + // Reshape K == 2x..x2xnewK + shape.pop_back(); + for (int i = 1; i < nSplits; i *= 2) { + shape.push_back(2); + } + shape.push_back(newK); + lhs = tt::ReshapeOp::create(builder, loc, shape, lhs); + // We want to split first the slowest running dim, then the second slowest, + // etc. + auto transOrder = to_vector(llvm::seq(rank - 1)); + transOrder.push_back(shape.size() - 1); + llvm::append_range(transOrder, llvm::reverse(llvm::seq( + rank - 1, (int64_t)shape.size() - 1))); + lhs = tt::TransOp::create(builder, loc, lhs, transOrder); + // We split recursively + SmallVector curr; + SmallVector ret = {lhs}; + for (int i = 1; i < nSplits; i *= 2) { + curr = ret; + ret.clear(); + for (auto v : curr) { + auto split = tt::SplitOp::create(builder, loc, v); + ret.push_back(split.getResult(0)); + ret.push_back(split.getResult(1)); + } + } + + auto mmav3Type = + type.clone(cast(ret.front().getType()).getShape()); + // Convert the LHS to mmav3 layout + for (auto &v : ret) { + v = ttg::ConvertLayoutOp::create(builder, loc, mmav3Type, v); + // These convert_layout ops are noops by construction + assert(isNoop(v.getDefiningOp())); + } + assert(ret.size() == nSplits); + return ret; +} + +// Split the RHS of a RSWGMMADot operation into multiple multiple +// tensors of size newKxN via MemDescSubslice +SmallVector splitRhs(OpBuilder &builder, + TypedValue rhs, int64_t newK) { + auto loc = rhs.getLoc(); + auto type = rhs.getType(); + auto rank = type.getRank(); + auto kDim = rank - 2; + auto nSplits = type.getShape()[kDim] / newK; + auto shape = llvm::to_vector(type.getShape()); + shape[kDim] = newK; + SmallVector offsets(rank, 0); + auto newType = ttg::MemDescType::get( + shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(), + /*isMutable=*/false, type.getAllocShape()); + SmallVector ret; + for (int i = 0; i < nSplits; i++) { + offsets[kDim] = i * newK; + Value newSmem = + ttg::MemDescSubsliceOp::create(builder, loc, newType, rhs, offsets); + ret.push_back(newSmem); + } + return ret; +} + +std::vector splitRSDot(ttng::WarpGroupDotOp dotOp) { + // Splits wgmma(tensor, shmem, acc) into + // wgmma(tensor[:, :K//2], shmem[:K//2, :], acc) + // wgmma(tensor[:, K//2:], shmem[K//2:, :], acc) + // which allows for in-register pipelining of the wgmmas. + // + // Theoretically, it may be beneficial to split even further which allows more + // fine-grained overlapping of the wgmma ops but empirically 2 splits gave the + // best performance. In future this may be something we want to allow the user + // to tune. + if (!isa(dotOp.getA().getType())) { + return {dotOp}; + } + + auto a = cast>(dotOp.getA()); + auto b = cast>(dotOp.getB()); + auto origK = a.getType().getShape().back(); + auto instrK = cast(dotOp.getType().getEncoding()) + .getInstrShape()[2]; + // Nothing to split + if (origK <= instrK) { + return {dotOp}; + } + constexpr int numSplits = 2; + uint32_t newK = origK / numSplits; + + assert(origK % newK == 0 && "origK must be divisible by newK"); + auto builder = OpBuilder(dotOp); + auto loc = dotOp.getLoc(); + auto lhss = splitLhs(builder, a, newK); + auto rhss = splitRhs(builder, b, newK); + assert(lhss.size() == numSplits && "lhs must have the same number of splits"); + assert(rhss.size() == numSplits && "rhs must have the same number of splits"); + + Value useC = dotOp.getUseC(); + Value C = dotOp.getC(); + uint32_t numImpreciseAccLeft = dotOp.getMaxNumImpreciseAcc(); + std::vector dots; + for (int i = 0; i < numSplits; i++) { + // 2**30 is to prevent the subtile from adding + // extra imprecise accumulator, See WGMMA.cpp + auto take = std::min(numImpreciseAccLeft, newK); + uint32_t numImpreciseAcc = (take == newK) ? (1u << 30) : take; + numImpreciseAccLeft -= take; + + auto dot = ttng::WarpGroupDotOp::create( + builder, loc, dotOp.getType(), lhss[i], rhss[i], C, useC, + dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync()); + dots.push_back(dot); + C = dot.getResult(); + useC = {}; + } + dotOp.replaceAllUsesWith(dots.back().getResult()); + dotOp.erase(); + return dots; +} + +// Apply splitRSDot to all dots in the input list. +llvm::MapVector +splitRSDots(const llvm::MapVector &dots) { + llvm::MapVector ret; + for (auto [dot, iterArgIdx] : dots) { + auto newDots = splitRSDot(cast(dot)); + for (auto newDot : newDots) { + ret.insert({newDot, iterArgIdx}); + } + } + return ret; +} + +// Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot, +// needs a wait immediately after it. +// +// In PTX, MMAv3 exists only as an asynchronous op. In Triton, we can represent +// MMAv3 ops as either ttng.warp_group_dot {isAsync=True} or ttng.warp_group_dot +// {isAsync=False}. But even if we use ttng.warp_group_dot {isAsync=True}, the +// conservative thing is to make a dot "effectively synchronous" by inserting a +// `ttng.warp_group_dot_wait {pendings=0}` right after it. +// +// We can omit the wait and create a "properly async" dot if all of the +// following are true. +// +// 1. All operands that touch shared memory are multi-buffered, i.e. can't read +// an incomplete value while it's being written asynchronously by a load. +// 1a. If operand A is in registers, these registers cannot be updated +// inside +// the loop. +// **Exception** if the operand is produced by a preceding WGMMA, +// then this op can be properly async. Either the f16 shortcut is +// possible and the WGMMA's can run back-to-back (see rule 3 below), or +// elementwise truncate is needed, in which case the preceding WGMMA is +// not async and a WarpGroupDotWait is inserted right after, which +// guarantees exclusive access to the operand registers. +// +// 2. If the dot is used by any op in the loop, it must be used under an `if`, +// and will be synced with a `wait 0` at the beginning of the `if` block. +// +// 3. During iteration i, between the start of the loop up until the first +// `ttng.warp_group_dot_wait {pendings=0}` op, the result of the dot from +// iteration i-1 is consumed only by other MMAv3 dots as the `c` operand. +// +// This is safe because the following pseudo-PTX is valid: +// +// %accum = warp_group_dot %a1, %b1, %c1 +// %accum = warp_group_dot %a2, %b2, %accum +// +// That is, the second async dot can use the result of the first one without +// an intervening wait. However, the only operation that can legally read +// %accum before the wait is another warp_group_dot, and this only works for +// the `c` operand, not `a` or `b`. See +// https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence +// (ttng::WarpGroupDotOp corresponds to wgmma.fence followed by one or more +// wgmma.async ops, so our understanding is that the two +// ttng::WarpGroupDotOps don't have to correspond to wgmma.async ops with +// the same shapes as specified in the docs, because there's an intervening +// fence.) +// +// If the op can be properly async, this function returns the index of the dot +// in the loop's iter_args. (Rule (2) above ensures this is well-defined.) +// +static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, + scf::ForOp forOp) { + LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp); + + auto checkOperand = [&](Value operand) { + // We can always make RSGEMM async s long as the RHS can be multi-buffered + if (isa(operand.getType())) { + return true; + } + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescIndex op. Only ConvertLayout and view ops are + // allowed in between. + Value transitiveOperand = operand; + DenseSet visitedBlockArgs; + while (!forOp.isDefinedOutsideOfLoop(transitiveOperand)) { + if (auto *definingOp = transitiveOperand.getDefiningOp()) { + if (isa(definingOp)) { + transitiveOperand = definingOp->getOperand(0); + continue; + } + return isa(definingOp); + } + auto blockArg = cast(transitiveOperand); + // We know that the dotOp is a top level operation in the loop body, and + // we have already checked that transitiveOperand is not defined outside + // the loop, therefore the block arg must be an iter arg of this loop. + assert(dotOp->getParentOp() == forOp); + assert(blockArg.getOwner() == forOp.getBody()); + // If we have already visited this block arg, that means that it + // participates in a cycle containing only permitted operations. The + // initial value therefore originates outside the loop, making this valid. + if (!visitedBlockArgs.insert(blockArg).second) + return true; + transitiveOperand = forOp.getTiedLoopYieldedValue(blockArg)->get(); + } + return true; + }; + + // Rule 1: All shmem operands are multi-buffered. + // We don't have to call checkOperand on getC() because it's always in + // registers, never in shmem. + assert(isa(dotOp.getC().getType().getEncoding())); + if (!checkOperand(dotOp.getA()) || !checkOperand(dotOp.getB())) { + LDBG("Can't make dot async because shmem operands aren't multi-buffered"); + return std::nullopt; + } + + // Rule 2: The dot cannot be unconditionally used by any op in the loop. + // Uses under `if` are allowed, as can be explicitly synced with a `wait 0`. + int iterArgIdx = -1; + Value iterArg = nullptr; + SmallVector> queue; + for (auto &use : dotOp->getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + while (!queue.empty()) { + auto [user, argIdx] = queue.pop_back_val(); + if (user->getParentOp() == forOp) { + // We support noops in between the dot and the yield + if (isNoop(user)) { + for (auto &use : user->getResult(0).getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + continue; + } + if (isa(user)) { + if (iterArg) { + // The dot is used by the loop's yield, but we can't have any other + // uses. + LDBG("Can't make dot async because dot is used by multiple ops in " + "the loop."); + return std::nullopt; + } + iterArgIdx = argIdx; + iterArg = forOp.getRegionIterArg(argIdx); + continue; + } + LDBG("Can't make dot async because dot is unconditionally used in the " + "loop."); + return std::nullopt; + } + if (auto ifOp = dyn_cast(user->getParentOp())) { + if (isa(user)) { + // The result is returned by the if, follow it further. + auto uses = ifOp.getResult(argIdx).getUses(); + for (auto &use : uses) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + } + } else { + return std::nullopt; + } + } + + // The dot result is not used by the loop yield. This could happen if it is + // dead, or if it is only used inside (but not yielded by) an scf::IfOp. + if (!iterArg) + return std::nullopt; + + // Rule 2.1: We don't make the dot async if the accumulator is not fp32. + if (!dotOp.getC().getType().getElementType().isF32()) { + LDBG("Can't make dot async because the accumulator is not fp32"); + return std::nullopt; + } + + // Rule 3a: Check that every use of the dot’s result (iterArg) eventually + // reaches a WarpGroupDotOp (with use index 2), possibly after passing through + // a chain of noops + std::function isTransitivelyWarpGroupDot = + [&](OpOperand &use) -> bool { + Operation *user = use.getOwner(); + if (isa(user)) + return use.getOperandNumber() == 2; + if (isNoop(user)) + return llvm::all_of(user->getResult(0).getUses(), + isTransitivelyWarpGroupDot); + return false; + }; + + if (llvm::all_of(iterArg.getUses(), isTransitivelyWarpGroupDot)) + return iterArgIdx; + + // Rule 3b: Are all users of the dot's result from iteration i-1 after the + // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be + // properly async, but we have to thread its result from iteration i-1 through + // the wait. + auto waitOps = forOp.getBody()->getOps(); + auto firstWaitOpIter = llvm::find_if( + waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; }); + if (firstWaitOpIter != waitOps.end() && + llvm::all_of(iterArg.getUsers(), [&](Operation *user) { + assert(forOp->isAncestor(user)); + while (user->getParentOp() != forOp) { + user = user->getParentOp(); + } + return (*firstWaitOpIter)->isBeforeInBlock(user); + })) { + LDBG("MMAv3 dot can be properly async because it follows a " + "warp_group_dot_wait " + "{pendings=0}.\n" + << " wait: " << *firstWaitOpIter << "\n" + << " dot: " << dotOp); + threadValuesThroughWait(*firstWaitOpIter, {iterArg}); + return iterArgIdx; + } + + LDBG("Can't make dot async because its result from i-1 is used by " + "something other than another MMAv3 dot as the `c` operand."); + return std::nullopt; +} + +// If necessary, insert a dot-wait inside the loop, waiting for the results of +// the properly-async dots from iteration i-1 to complete. (We pipeline to +// depth 2, so there are at most 2 copies of each warp_group_dot in flight at a +// time.) +// +// We can skip inserting the wait if we have a `warp_group_dot_wait +// {pendings=0}` somewhere in the loop. To see why, consider: +// +// warp_group_dot +// warp_group_dot; wait 0 // synchronous dot +// warp_group_dot +// warp_group_dot +// +// In this example, there are three properly-async dots, so we'd normally put +// `wait 3` at the end of the loop, meaning "wait until there are 3 or fewer +// pending async dots". But note that when this iteration of the loop +// completes, there are only *two* pending async dots from this iteration, so +// this wait would do nothing. This is true in general, no matter where the +// `wait 0` appears. +static void insertAsyncWarpGroupDotWaitInLoop( + scf::ForOp forOp, + const llvm::MapVector &properlyAsyncDots) { + if (properlyAsyncDots.empty()) + return; + + if (llvm::any_of(forOp.getBody()->getOps(), + [](auto wait) { return wait.getPendings() == 0; })) { + return; + } + + // Insert waits before the users of the properly async dots other than loop + // yield. + for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) { + DenseMap> blockToUses; + for (auto &use : asyncDot->getUses()) { + if (auto yieldOp = dyn_cast(use.getOwner())) { + continue; + } + + auto block = use.getOwner()->getBlock(); + blockToUses[block].push_back(&use); + } + + for (auto [block, uses] : blockToUses) { + // Insert a wait before the first use in the block + std::sort(uses.begin(), uses.end(), [](OpOperand *lhs, OpOperand *rhs) { + Operation *lhsOp = lhs->getOwner(); + Operation *rhsOp = rhs->getOwner(); + return lhsOp->isBeforeInBlock(rhsOp); + }); + + // If a wgmma uses the same accumulator registers, it will be implicitly + // pipelined by the hardware and doesn't need a wait. + auto firstUse = + std::find_if_not(uses.begin(), uses.end(), [](OpOperand *operand) { + return (isa(operand->getOwner()) && + operand->getOperandNumber() == 2); + }); + if (firstUse == uses.end()) { + continue; + } + + OpBuilder builder((*firstUse)->getOwner()); + auto newWait = ttng::WarpGroupDotWaitOp::create( + builder, asyncDot->getLoc(), ArrayRef{}, 0); + + SmallVector users; + for (; firstUse != uses.end(); ++firstUse) { + users.push_back((*firstUse)->get()); + } + threadValuesThroughWait(newWait, users); + } + } + + for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) { + // If the dot takes the LHS on registers i, we add a wait for the number + // of properly async dots in the loop minus one. + // This makes sure that the dot will wait until itself from the previous + // iteration has completed, as to avoid rewriting the registers. + if (!rsDotNeedsWait(asyncDot, forOp)) + continue; + + OpBuilder builder(asyncDot); + builder.setInsertionPointAfter(asyncDot); + auto newWait = ttng::WarpGroupDotWaitOp::create( + builder, asyncDot->getLoc(), ArrayRef{}, + properlyAsyncDots.size() - 1); + SmallVector waitOperands = {asyncDot->getResult(0)}; + threadValuesThroughWait(newWait, waitOperands); + } + + // Add the wait right after the last properly-async dot. This only needs to + // wait for all properly-async dots from the i-1'th iteration to complete, IOW + // we wait until there are most `asyncDots.size()` dots in flight. + // + // (You might want to put the wait at the end of the loop instead of right + // after the last dot, but there could be a load into shmem between the last + // async dot and the end of the loop, and that could clobber memory being used + // by a dot.) + IRRewriter builder(forOp.getContext()); + auto lastAsyncDot = properlyAsyncDots.back().first; + // If the last dot is an RS dot, we don't need to insert a wait + // as we have already inserted a wait(properlyAsyncDots.size() - 1) + if (rsDotNeedsWait(lastAsyncDot, forOp)) { + return; + } + builder.setInsertionPointAfter(lastAsyncDot); + auto wait = ttng::WarpGroupDotWaitOp::create(builder, lastAsyncDot->getLoc(), + /*inputs=*/ArrayRef{}, + properlyAsyncDots.size()); + + // Thread the results of the async dots through the wait. + SmallVector addlWaitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + addlWaitOperands.push_back(asyncDot->getResult(0)); + } + threadValuesThroughWait(wait, addlWaitOperands); +} + +// Convert MMAv3 ttng::WarpGroupDotOps {isAsync = False} (i.e. Hopper wgmma) +// into ttng::WarpGroupDotOps {isAsync = True} and insert +// ttng::WarpGroupDotWaitOps as necessary. +// +// We assume we have space for each dot to be pipelined to depth 2, i.e. each +// dot op in the loop can have at most 2 warp_group_dot ops in flight at once. +// (Each warp_group_dot op usually corresponds to a series of wgmma.async ops.) +void triton::asyncLaunchDots(scf::ForOp forOp) { + LDBG("Original loop:\n" << *forOp); + + // First, change every MMAv3 ttng.warp_group_dot {isAsync=false} + // into ttng.warp_group_dot {isAsync=true}. + // The rest of this function is concerned with inserting + // ttng.warp_group_dot_wait ops in the appropriate places. + // + // We call those dots that don't need to be followed immediately by a `wait 0` + // "properly async", or sometimes just "async". + // + // For each dot, determine whether it can be properly async, or if it needs a + // sync immediately after. If it can be properly async, we know its only use + // is in the loop's `yield` statement; asyncDots maps the op to its index in + // the yield op. + IRRewriter builder(forOp.getContext()); + llvm::MapVector properlyAsyncDots; + for (auto WarpGroupDotOp : forOp.getBody()->getOps()) { + WarpGroupDotOp.setIsAsync(true); + if (auto iterArgIdx = dotCanBeProperlyAsync(WarpGroupDotOp, forOp)) { + properlyAsyncDots[WarpGroupDotOp] = *iterArgIdx; + } else { + builder.setInsertionPointAfter(WarpGroupDotOp); + auto wait = ttng::WarpGroupDotWaitOp::create( + builder, WarpGroupDotOp.getLoc(), ArrayRef{}, + /*pendings=*/0); + SmallVector waitOperands = {WarpGroupDotOp.getResult()}; + threadValuesThroughWait(wait, waitOperands); + } + } + + if (properlyAsyncDots.empty()) { + LDBG("No properly async dots."); + return; + } + + // Split RS dots into dots with K = 16 (the instruction size of MMAv3) + // If we split them in nSplit dots, we will be able to keep nSplit-1 dots + // in flight at a time. + // We just do it if there is no wait 0 in the loop, as otherwise the split + // just creates unnecessary commits and arrives. + if (llvm::all_of(forOp.getBody()->getOps(), + [](auto wait) { return wait.getPendings() != 0; })) { + properlyAsyncDots = splitRSDots(properlyAsyncDots); + } + + // Next, insert a wait inside the loop. We pipeline to depth 2, so the third + // iteration's set of asynchronous dots (and their corresponding async copies + // from global to shmem) can't start until the first iteration's set has + // completed. + insertAsyncWarpGroupDotWaitInLoop(forOp, properlyAsyncDots); + + // Finally, insert a wait after the loop, waiting for dots from the final + // iteration of the loop. + SmallVector waitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + waitOperands.push_back(forOp.getResult(iterArgIdx)); + } + // Wait until there are 0 outstanding async dot ops. + builder.setInsertionPointAfter(forOp); + auto WarpGroupDotWaitAfterLoop = ttng::WarpGroupDotWaitOp::create( + builder, forOp.getLoc(), ArrayRef{}, 0); + threadValuesThroughWait(WarpGroupDotWaitAfterLoop, waitOperands); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp new file mode 100644 index 0000000000..9d8e42eb77 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -0,0 +1,457 @@ +//===----------------------------------------------------------------------===// +// +// This pass tries to prefetch operands (a and b) of tt.dot. +// Those ConvertLayoutOps will be lowered to shared memory loads. +// +// For example: +// %a: tensor<128x32xf16, #enc> +// scf.for %iv = ... iter_args(%a_arg = %a, ...) { +// %d = tt.dot %a_arg, %b, %c +// ... +// scf.yield %a_next, ... +// } +// +// will be translated to +// +// %a: tensor<128x32xf16, #enc> +// %a_tmp = tensor.subview %a[0, 0] [128, 16] +// %a_prefetch = ttg.local_load %a_tmp +// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) +// { +// %x = tt.dot %a_prefetch_arg, %b, %c +// %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] +// %a_prefetch_next = ttg.local_load %a_tmp_rem +// ... +// scf.yield %next_a, ..., %a_prefetch_next +// } +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-prefetch" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPREFETCH +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +class Prefetcher { + /// cache the ForOp we are working on + scf::ForOp forOp; + /// cache the YieldOp of this ForOp + scf::YieldOp yieldOp; + /// + // TODO: add a hook to infer prefetchWidth + unsigned prefetchWidth = 32; + + /// dots to be prefetched + SetVector dots; + /// dot => dot operand + DenseMap dot2aLoopArg; + DenseMap dot2aHeaderDef; + DenseMap dot2bLoopArg; + DenseMap dot2bHeaderDef; + DenseMap dot2aYield; + DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; + /// operand => defining + DenseMap operand2headPrefetch; + + LogicalResult isForOpOperand(Value v); + + Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK = std::nullopt, + std::optional shapeK = std::nullopt); + + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + +public: + Prefetcher() = delete; + + Prefetcher(scf::ForOp forOp) : forOp(forOp) { + yieldOp = cast(forOp.getBody()->getTerminator()); + } + + LogicalResult initialize(); + + void emitPrologue(); + + scf::ForOp createNewForOp(); +}; + +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[1], ret); + for (int i = 2; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + if (isa(curr.getType())) { + auto retType = RankedTensorType::get( + cast(ret.getType()).getShape(), + cast(curr.getType()).getElementType(), + cast(curr.getDefiningOp()->getOperand(0).getType()) + .getEncoding()); + curr.setType(retType); + } + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK, + std::optional shapeK) { + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; + auto rank = shape.size(); + SmallVector offset(rank, 0); + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) + int64_t kIdx = opIdx == 0 ? rank - 1 : rank - 2; + + offset[kIdx] = isPrologue ? 0 : prefetchWidth; + shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth); + + if (shapeK) + shape[kIdx] = *shapeK; + if (offsetK) + offset[kIdx] = *offsetK; + + Value newSmem = triton::gpu::MemDescSubsliceOp::create( + builder, v.getLoc(), + triton::gpu::MemDescType::get( + shape, elementType, type.getEncoding(), type.getMemorySpace(), + type.getMutableMemory(), type.getAllocShape()), + v, offset); + + auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( + builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + Value prefetchSlice = triton::gpu::LocalLoadOp::create( + builder, v.getLoc(), + RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); + + return prefetchSlice; +} + +LogicalResult Prefetcher::initialize() { + Block *loop = forOp.getBody(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + SmallVector dotsInFor; + for (Operation &op : *loop) + if (auto dotOp = dyn_cast(op)) { + // Only accepts dotOps encoded as Nvidia MMA v2 or AMD MFMA + auto dstMmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + auto dstMfmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstMfmaEnc && (!dstMmaEnc || dstMmaEnc.getVersionMajor() != 2)) + // Don't rewrite if any other type is found. + return failure(); + dotsInFor.push_back(dotOp); + } + + if (dotsInFor.empty()) + return failure(); + + // TODO: segfault (original for still has uses) + // when used in flash attention that has 2 dots in the loop + if (dotsInFor.size() > 1) + return failure(); + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + LDBG("Prefetch src: " << *op); + while (op) { + if (op->getNumOperands() != 1) + break; + if (!op->getResult(0).hasOneUse()) + break; + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + // NYI for other encodings, for example if we have transpose + // in the chain + if (isa(cvt.getType().getEncoding())) + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + if (op) + LDBG("op: " << *op); + } + std::reverse(rets.begin(), rets.end()); + + if (foundConvertFromShared) + return rets; + return {}; + }; + + auto getIncomingOp = [this](Value v) -> Value { + if (auto arg = mlir::dyn_cast(v)) + if (arg.getOwner()->getParentOp() == forOp.getOperation()) + return forOp.getTiedLoopInit(arg)->get(); + return Value(); + }; + + auto getYieldOperand = [this](Value v) -> Value { + auto arg = mlir::cast(v); + unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); + return yieldOp.getOperand(yieldIdx); + }; + + for (triton::DotOp dot : dotsInFor) { + auto aType = dot.getA().getType(); + auto bType = dot.getB().getType(); + auto aEnc = + mlir::cast(aType.getEncoding()); + auto bEnc = + mlir::cast(bType.getEncoding()); + int aKWidth = aEnc.getKWidth(); + int bKWidth = bEnc.getKWidth(); + assert(aKWidth == bKWidth); + + auto kSize = aType.getShape().back(); + + // works better with nvidia tensor cores + unsigned elementWidth = aType.getElementTypeBitWidth(); + if (aKWidth == 0) + prefetchWidth = 256 / elementWidth; + else + prefetchWidth = 8 * aKWidth; + + // Skip prefetching if kSize is less than prefetchWidth + if (kSize < prefetchWidth) + continue; + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + + if (aVals.size() && bVals.size()) { + Value aSmem = aVals.front(); + Value bSmem = bVals.front(); + Value aHeaderDef = getIncomingOp(aSmem); + Value bHeaderDef = getIncomingOp(bSmem); + // Only prefetch loop arg + if (aHeaderDef && bHeaderDef) { + dots.insert(dot); + dot2aVals[dot] = aVals; + dot2bVals[dot] = bVals; + dot2aHeaderDef[dot] = aHeaderDef; + dot2bHeaderDef[dot] = bHeaderDef; + dot2aLoopArg[dot] = aSmem; + dot2bLoopArg[dot] = bSmem; + dot2aYield[dot] = getYieldOperand(aSmem); + dot2bYield[dot] = getYieldOperand(bSmem); + } + } + } + + return success(); +} + +void Prefetcher::emitPrologue() { + OpBuilder builder(forOp); + + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aPrefetched = + generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder); + cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder); + Value bPrefetched = + generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder); + cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder); + + operand2headPrefetch[dot.getA()] = aPrefetched; + operand2headPrefetch[dot.getB()] = bPrefetched; + } +} + +scf::ForOp Prefetcher::createNewForOp() { + OpBuilder builder(forOp); + + SmallVector loopArgs; + for (auto v : forOp.getInitArgs()) + loopArgs.push_back(v); + for (triton::DotOp dot : dots) { + loopArgs.push_back(operand2headPrefetch[dot.getA()]); + loopArgs.push_back(operand2headPrefetch[dot.getB()]); + } + + auto newForOp = + scf::ForOp::create(builder, forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), loopArgs); + + builder.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // The insertion point should be placed before the yield op + auto setInsertionPointBeforeYield = [](OpBuilder &builder, + scf::ForOp newForOp) { + if (newForOp.getBody()->mightHaveTerminator()) { + builder.setInsertionPoint(newForOp.getBody()->getTerminator()); + } else { + builder.setInsertionPointToEnd(newForOp.getBody()); + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + // If we're currently trying to sink a prefetched dot, we need to stop + // sinking it (by resetting the insertion point to the end) if we find + // control flow, or anything that depends on the dot op. + if (op.getNumRegions() > 0) { + setInsertionPointBeforeYield(builder, newForOp); + } + for (auto operand : op.getOperands()) { + if (auto def = operand.getDefiningOp()) { + auto dot = dyn_cast(def); + if (dot && dots.contains(dot)) { + setInsertionPointBeforeYield(builder, newForOp); + } + } + } + Operation *newOp = builder.clone(op, mapping); + auto dot = dyn_cast(&op); + if (dot && dots.contains(dot)) { + Attribute dotEncoding = dot.getType().getEncoding(); + // prefetched dot + Operation *firstDot = builder.clone(*dot, mapping); + if (Value a = operand2headPrefetch.lookup(dot.getA())) + firstDot->setOperand( + 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); + if (Value b = operand2headPrefetch.lookup(dot.getB())) + firstDot->setOperand( + 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); + + // remaining part + int64_t kOff = prefetchWidth; + int64_t kRem = dot.getA().getType().getShape().back() - prefetchWidth; + Operation *prevDot = firstDot; + if (kRem == 0) { + // There is only one dot while prefetchWidth == kSize so delay issuing + // it. Meanwhile, newOp should be set to firstDot to make sure the dot + // result is updated to yield. + builder.setInsertionPoint(prevDot); + newOp = firstDot; + } + + while (kRem != 0) { + // int64_t kShape = largestPow2(kRem); + int64_t kShape = prefetchWidth; + auto insertionPoint = builder.saveInsertionPoint(); + builder.setInsertionPoint(prevDot); + Value aRem = + generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); + Value bRem = + generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); + builder.restoreInsertionPoint(insertionPoint); + newOp = builder.clone(*dot, mapping); + newOp->setOperand(0, aRem); + newOp->setOperand(1, bRem); + newOp->setOperand(2, prevDot->getResult(0)); + prevDot = newOp; + kOff += kShape; + kRem -= kShape; + if (kRem == 0) { + // We want to delay issuing the last dot as long as possible, ideally + // until after the prefetch. To accomplish this, set the insertion + // point above the dot. If we find anything dependent on the dot (at + // the top of this loop), we resume inserting after it. + builder.setInsertionPoint(prevDot); + } + } + } + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + + // prefetch next iteration + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookupOrDefault(v)); + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, + dotEncoding, builder); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, + dotEncoding, builder); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); + } + // Update ops of yield + builder.setInsertionPointToEnd(newForOp.getBody()); + if (!yieldValues.empty()) + scf::YieldOp::create(builder, yieldOp.getLoc(), yieldValues); + return newForOp; +} + +} // anonymous namespace + +struct PrefetchPass : public impl::TritonGPUPrefetchBase { + void runOnOperation() override { + + // Canonicalize convert ops to make the pattern matching easier. + RewritePatternSet cleanUpPatterns(&getContext()); + triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, + &getContext()); + if (mlir::applyPatternsGreedily(getOperation(), std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } + getOperation()->walk([&](scf::ForOp forOp) { + Prefetcher prefetcher(forOp); + + if (prefetcher.initialize().failed()) + return; + + prefetcher.emitPrologue(); + + scf::ForOp newForOp = prefetcher.createNewForOp(); + + // replace the original loop + for (unsigned i = 0; i < forOp->getNumResults(); ++i) + forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); + forOp->erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp new file mode 100644 index 0000000000..3f9b247d05 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -0,0 +1,68 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREDUCEDATADUPLICATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUReduceDataDuplicationPass + : public impl::TritonGPUReduceDataDuplicationBase< + TritonGPUReduceDataDuplicationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcEncoding = srcType.getEncoding(); + if (isa(srcEncoding)) + return; + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!dstDotOp) + return; + if (!cvtNeedsSharedMemory(srcType, dstType)) + return; + auto order = getOrderForMemory(srcType); + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); + auto tmpType = triton::gpu::MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SwizzledSharedEncodingAttr::get( + mod.getContext(), dstDotOp, srcType.getShape(), order, + triton::gpu::getCGALayout(srcEncoding), srcType.getElementType()), + sharedMemorySpace); + auto tmp = triton::gpu::LocalAllocOp::create(builder, cvtOp.getLoc(), + tmpType, cvtOp.getSrc()); + auto newConvert = triton::gpu::LocalLoadOp::create( + builder, cvtOp.getLoc(), dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp new file mode 100644 index 0000000000..a55b64468b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -0,0 +1,1701 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir::triton::gpu { + +#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-remove-layout-conversions" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// Large rematerialization slices can lead to steep compile-time blowups. +// Bail out once the backward slice exceeds this cap and keep the conversion. +constexpr unsigned kMaxRematSliceSize = 256; + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// The current algorithm works by analyzing the IR and doing a one-shot rewrite +// based on the analysis. The algorithm is as follows. +// +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. For each anchor, propagate its layout to all its descendants. +// An op can have multiple ancestors that are anchors, so at this stage an op +// may have multiple layouts associated with it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep, inserting convert-layout ops to resolve conflicts. After this +// stage, each value has only one layout associated with it. +// +// 4. Rewrite the IR by walking the function in dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op, rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { +public: + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + void rewriteAssertOp(AssertOp assertOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Return the original value mapped to the new desired encoding. + Value getRewrittenValue(Value value); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + SetVector opToDelete; + FuncOp funcOp; +}; + +class LayoutRematerialization { +public: + LayoutRematerialization(FuncOp F) : funcOp(F) {} + + // Map the original value to the remat'ed one. + void addRematValue(Value old, Attribute encoding, Value newV); + // Get the remat'ed value in the given encoding, if one already exists and + // is different then the layout conversion root. + Value getRematValue(Value value, Attribute encoding) const { + return rematMapping.lookup({value, encoding}); + } + + void cleanup(); + bool backwardRematerialization(); + void backwardRematerialization(ConvertLayoutOp convertOp); + // TODO: Merge the three hoistConvert*(); functions as they are duplicate code + void hoistConvertDotOperand(); + void hoistConvertDotOperand(ConvertLayoutOp convertOp); + void hoistConvertOnTopOfExtOrBroadcast(); + void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); + void hoistConvertIntoConditionals(); + void hoistConvertIntoConditionals(ConvertLayoutOp convertOp); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp); + + LogicalResult + getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding, + SetVector &slice, + DenseMap &layout, + std::function stopPropagation, + unsigned maxSliceSize = 0); + + LogicalResult getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr, + unsigned maxSliceSize = 0); + +private: + void updateRematMapping(SmallVector> &values); + // Existing tuples of (value, layout) that needs to be updated when recreating + // scf ops. This prevents keeping track of Values that have been delete when + // rewriting slices. + DenseMap mappedValues; + // map of the values remat based on encoding. + DenseMap, Value> rematMapping; + // DenseMap, Operation*> + SetVector opToDelete; + FuncOp funcOp; + DominanceInfo domInfo; + PostDominanceInfo postDomInfo; +}; + +void LayoutRematerialization::addRematValue(Value old, Attribute encoding, + Value newV) { + LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + rematMapping[{old, encoding}] = newV; + mappedValues[old] = encoding; +} + +// Remove unneeded values now that we are done with the rematMapping. +void LayoutRematerialization::cleanup() { + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + if (auto gatherOp = dyn_cast(op)) + return gatherOp.getEfficientLayout(); + + // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be + // anything, so it stops forward-propagation of layouts. We rely on the + // backwards pass to fix it up if necessary. (If we didn't do this, then + // anything following the reshape won't be covered by the forward pass at + // all.) + if (auto reshape = dyn_cast(op)) + return reshape.getAllowReorder(); + + return false; +} + +void LayoutPropagation::initAnchorLayout() { + auto addAnchor = [&](Value v) { + if (auto tensorType = dyn_cast(v.getType())) { + layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); + } + }; + + // Consider function args as anchors. This makes it easier to write tests -- + // you can pass a tensor with an encoding as an arg, instead of explicitly + // calling tt.load. + for (auto arg : funcOp.getArguments()) { + addAnchor(arg); + } + + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + addAnchor(result); + } + } + }); +} + +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!isa(value.getType())) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + Attribute dstEncoding; + if (isa(op)) { + // Try to remove the convert by making the dst encoding match the source + // encoding. + dstEncoding = encoding; + } else { + dstEncoding = inferDstEncoding(op, encoding); + } + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(dstEncoding); + } + if (hasChanged) + changed.push_back(value); + } +} + +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; + } + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + if (user->hasTrait()) { + unsigned opIndex = use.getOperandNumber(); + Value result = user->getResult(opIndex); + setEncoding(result, info, changed, user); + continue; + } + if (auto gatherOp = dyn_cast(user)) { + // Propagate the layout through the indices only, and if the layout does + // not have an efficient layout set. + if (!gatherOp.getEfficientLayout() && + &use == &gatherOp.getIndicesMutable()) { + setEncoding(gatherOp.getResult(), info, changed, user); + continue; + } + } + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + setEncoding(user->getResults(), info, changed, user); + continue; + } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + + LLVM_DEBUG({ + DBGS() << "propagateLayout considering " << currentValue << ", which has " + << info.encodings.size() << " candidate encoding(s):\n"; + for (Attribute encoding : info.encodings) + DBGS() << " " << encoding << "\n"; + DBGS() << "changed: " << changed.size() << "\n"; + }); + + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} + +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); + for (Attribute e : info.encodings) { + if ((isLoadOrStore && isa(e)) || + (!isLoadOrStore && isa(e))) { + encoding = e; + break; + } + } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} + +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} + +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } + +bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && !isa(op->getResultTypes()[0]); +} + +void LayoutPropagation::rewriteRegion(Region ®ion) { + std::deque queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.front(); + queue.pop_front(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = cast(result.getType()).getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else if (auto assertOp = dyn_cast(&op)) { + rewriteAssertOp(assertOp); + } else { + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + cast(operand.get().getType()).getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, cast(newV.getType()).getEncoding()}] = + newV; +} + +Value LayoutPropagation::getRewrittenValue(Value value) { + auto tensorType = dyn_cast(value.getType()); + if (!tensorType) + return value; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + return value; + } + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + return value; + return rewriteMapping.at({value, encodingPicked}); +} + +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = dyn_cast(value.getType())) { + Value rewrittenValue = getRewrittenValue(value); + if (cast(rewrittenValue.getType()).getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = tensorType.cloneWithEncoding(encoding); + Value converted = ConvertLayoutOp::create(rewriter, value.getLoc(), tmpType, + rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} + +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + + Attribute operandEnc; + if (op->getNumOperands() > 0) { + for (auto operand : op->getOperands()) { + auto ty = + dyn_cast(getRewrittenValue(operand).getType()); + if (!ty) + continue; + auto enc = ty.getEncoding(); + if (inferDstEncoding(op, enc) == encoding) { + operandEnc = enc; + break; + } + } + if (!operandEnc) + operandEnc = inferSrcEncoding(op, encoding); + assert(operandEnc); + } + + for (OpOperand &operand : op->getOpOperands()) { + newOp->setOperand(operand.getOperandNumber(), + getValueAs(operand.get(), operandEnc)); + } + + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = dyn_cast(op->getResult(i).getType()); + if (!origType) + continue; + auto newType = origType.cloneWithEncoding(encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} + +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), operands); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = dyn_cast(ret.getType()); + auto newType = origType.cloneWithEncoding(it->second.encodings[0]); + returnTypes.push_back(newType); + } + + auto newWhileOp = + scf::WhileOp::create(rewriter, whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } + + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = cast(ifOp->getResult(i).getType()); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = origType.cloneWithEncoding(encoding); + } + auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} + +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = dyn_cast(yieldType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = dyn_cast(argType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } + } + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { + Attribute srcEncoding; + // Only need to deal with the first operand which is the condition tensor. + Value operand = assertOp->getOperand(0); + auto it = layouts.find(operand); + if (it == layouts.end()) + return; + srcEncoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand, srcEncoding); + assertOp->setOperand(0, newOperand); +} + +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.insert(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = convertOp.getSrc().getType().getEncoding(); + auto it = layouts.find(convertOp.getSrc()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getSrc(), srcEncoding); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = tensorType.cloneWithEncoding(encoding); + auto cvt = ConvertLayoutOp::create(rewriter, op->getLoc(), newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = tensorType.cloneWithEncoding(encoding); + auto cvt = ConvertLayoutOp::create(rewriter, op->getLoc(), newType, + newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa(op) || + op->hasTrait()) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newOp; + } + llvm::report_fatal_error("unexpected op in rewrite"); + return nullptr; +} + +bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (auto gather = dyn_cast(op)) + return !gather.getEfficientLayout(); + + if (isa(op)) + return false; + + return true; +} + +void LayoutRematerialization::updateRematMapping( + SmallVector> &values) { + for (auto [old, newV] : values) { + auto it = mappedValues.find(old); + if (it != mappedValues.end()) { + Attribute encoding = it->second; + auto rematIt = rematMapping.find({old, it->second}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + mappedValues.erase(it); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. + for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } + } + rematMapping[{newV, encoding}] = replacedValue; + mappedValues[newV] = encoding; + } + } +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, + IRMapping &mapping) { + SetVector opsToRewrite; + // Keep track of yield operands that need to be duplicated. + DenseMap> yieldOperandsMap; + // Keep these around to remove them from the slice after our collection pass + // This ensures we don't duplicate them during an for rewrite or causing the + // for/yield to fall out of sync + SetVector valuesWithExistingRemat; + for (Value v : slice) { + auto layoutIt = layout.find(v); + assert(layoutIt != layout.end()); + // If we already have a remat value for this value, use it. + if (Value remat = getRematValue(v, layoutIt->second)) { + mapping.map(v, remat); + valuesWithExistingRemat.insert(v); + continue; + } + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + if (auto ifOp = v.getDefiningOp()) { + unsigned operandIdx = cast(v).getResultNumber(); + opsToRewrite.insert(ifOp.thenYield().getOperation()); + yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx); + opsToRewrite.insert(ifOp.elseYield().getOperation()); + yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); + } + } else { + BlockArgument blockArg = cast(v); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto loopOp = cast(parentOp)) { + opsToRewrite.insert(loopOp.getOperation()); + OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + auto yieldOp = blockArg.getOwner()->getTerminator(); + yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + opsToRewrite.insert(yieldOp); + } + } + } + slice.set_subtract(valuesWithExistingRemat); + opsToRewrite = mlir::topologicalSort(opsToRewrite); + + // replaceAllUsesWith calls delayed until after initial rewrite. + // This is required for slice.count(value) to work mid rewrite. + SmallVector> replacements; + + SmallVector deadOps; + IRRewriter builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = *forOp.getTiedLoopInit(arg); + argMapping.push_back(std::make_pair( + forOp.getTiedLoopResult(&initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + deadOps.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + LLVM_DEBUG({ + DBGS() << "mapping forOp " + << loopBody.getArgument(m.first + numIndVars) << " to " + << loopBody.getArgument(m.second + numIndVars) << '\n'; + }); + // The result is not in the layout/slice, the argument is. + Value oldArg = loopBody.getArgument(m.first + numIndVars); + addRematValue(newForOp.getResult(m.first), layout[oldArg], + newForOp.getResult(m.second)); + addRematValue(oldArg, layout[oldArg], + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + if (auto ifOp = dyn_cast(op)) { + SmallVector newTypes; + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + auto it = layout.find(res); + assert(it != layout.end()); + + auto oldType = cast(res.getType()); + auto newType = oldType.cloneWithEncoding(it->second); + newTypes.push_back(newType); + } + } + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements); + unsigned oldIdx = 0; + unsigned newIdx = ifOp.getNumResults(); + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + // Why can't we use res instead of ifOp.getResult(oldIdx)? + mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + addRematValue(ifOp.getResult(oldIdx), layout[res], + newIfOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + deadOps.push_back(ifOp.getOperation()); + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[op]; + // Sort so that operands are added in the same order as the new scf + // results/arguments. + std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); + for (int operandIdx : operandsToRewrite) { + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + } + scf::YieldOp::create(builder, op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = tensorType.cloneWithEncoding(layout[op->getResult(0)]); + auto cvt = ConvertLayoutOp::create(builder, op->getLoc(), newType, + newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + addRematValue(op->getResult(0), layout[op->getResult(0)], + cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = + cast(old.getType()).cloneWithEncoding(it->second); + newV.setType(newType); + addRematValue(old, it->second, newV); + } + } + // Check mapping and see if there are existing convertOps on the old Argument + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); + opToDelete.insert(convertOp); + + updateRematMapping(replacements); + for (auto &kv : replacements) { + builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + + for (Operation *op : deadOps) + opToDelete.insert(op); +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +LogicalResult LayoutRematerialization::getConvertBackwardSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation, unsigned maxSliceSize) { + // Allow re-using existing conversions for a value. Check dominance of any + // reusable materializations against the root value. This is sufficient + // because the conversions are processed in post-order. + auto getExistingConversion = [&](OpOperand &value, Attribute encoding) { + Value remat = getRematValue(value.get(), encoding); + if (!remat) + return Value(); + // `value` can be replaced with an existing rematerialization if it + // dominates the current use of value. + Operation *user = value.getOwner(); + if (domInfo.properlyDominates(remat, user)) { + return remat; + } + // FIXME: If the current user is a conversion, then we know it will become + // a no-op when its operand is replaced with `remat`, but we need to check + // that its users are all dominated by `remat` so the IR is valid. + // if (isa(user) && remat.getDefiningOp() && + // domInfo.properlyDominates(user, remat.getDefiningOp())) { + // for (Operation *op : user->getUsers()) { + // if (!domInfo.dominates(remat, op)) + // return Value(); + // } + // return remat; + // } + return Value(); + }; + + return mlir::getConvertBackwardSlice(root, slice, rootEncoding, layout, + stopPropagation, getExistingConversion, + maxSliceSize); +} + +LogicalResult LayoutRematerialization::getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &sliceArg, + DenseMap &layoutArg, + std::function stopPropagation, unsigned maxSliceSize) { + // Operate on copies of the input, we do not want to modify them unless we + // have succeeded. + auto slice = sliceArg; + auto layout = layoutArg; + LogicalResult result = getConvertBackwardSlice( + root, rootEncoding, slice, layout, stopPropagation, maxSliceSize); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + sliceArg = std::move(slice); + layoutArg = std::move(layout); + return success(); +} + +bool LayoutRematerialization::backwardRematerialization() { + bool changed = false; + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } else { + changed = true; + } + } + return changed; +} + +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExtOrBroadcast(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertIntoConditionals() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertIntoConditionals(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +static bool isExpensiveMathOp(Operation *op) { + // These operations are either multiple instructions or have throughput + // lower than 16 according to the arithmetic instructions table in: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions + return isa(op); +} + +static int64_t getByteCount(Value result, int64_t minElementCount = 0, + int64_t minBitWidth = 0) { + int64_t elementCount = 0; + int64_t dtypeBitWidth = 0; + if (auto tensorTy = dyn_cast(result.getType())) { + elementCount = tensorTy.getNumElements(); + auto elemType = tensorTy.getElementType(); + if (elemType.isIntOrFloat()) { + dtypeBitWidth = elemType.getIntOrFloatBitWidth(); + } + } + if (elementCount < minElementCount) { + elementCount = minElementCount; + } + if (dtypeBitWidth < minBitWidth) { + dtypeBitWidth = minBitWidth; + } + return (elementCount * dtypeBitWidth) >> 3; +} + +void LayoutRematerialization::backwardRematerialization( + ConvertLayoutOp convertOp) { + // DotOperand is hoisted by hoistDotOperand + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + Value oldV = convertOp.getSrc(); + LDBG("check backward remat with source " << oldV << " encoding " + << targetType.getEncoding()); + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. Make sure it dominates the current conversion. + Value newV = getRematValue(oldV, targetType.getEncoding()); + if (newV && domInfo.properlyDominates(newV, convertOp)) { + // Replace it with the remat'ed value. + convertOp.replaceAllUsesWith(newV); + opToDelete.insert(convertOp); + LDBG("found remat'ed value" << newV); + return; + } + + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, + nullptr, kMaxRematSliceSize); + if (result.failed()) { + LDBG(" getRematerializableSlice failed"); + return; + } + + // 2. Determine whether rematerialisation is beneficial. + + // Identify all operations in the slice + SetVector sliceOps; + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + sliceOps.insert(op); + } + } + + // Compute single-use operations + DenseMap isSingleUse; + std::function isOpSingleUse; + isOpSingleUse = [&](Operation *op) -> bool { + // lookup in memoization array: + auto it = isSingleUse.find(op); + if (it != isSingleUse.end()) { + return it->second; + } + + bool singleUse = true; + + for (Value result : op->getResults()) { + for (Operation *user : result.getUsers()) { + if (user == convertOp) { + continue; + } + if (sliceOps.contains(user)) { + if (!isOpSingleUse(user)) { + singleUse = false; + break; + } + } else { + singleUse = false; + break; + } + } + if (!singleUse) { + break; + } + } + + // insert into memoization array: + isSingleUse[op] = singleUse; + return singleUse; + }; + + // Measure the number of bytes that we're manipulating with the + // ConvertLayoutOp. We pessimistically assume that we round-trip + // through shared memory and that we cannot vectorise sub-register + // loads/stores, so we set a minimum element count of 32 (the warp + // size and number of shared memory banks) and minimum bitwidth of + // 32 (the width per bank of the shared memory load/store unit). + int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32); + + // We measure costs in standardised milli-SM-cycles. The smem load + // and store each cost 8 * convertLayoutBytes, and then we double + // it to account for extra cost due to synchronisation. + int64_t convertLayoutCost = 32 * convertLayoutBytes; + int64_t rematerialisationCost = 0; + + // Evaluate single-use status for every operation in slice + for (Operation *op : sliceOps) { + auto dialect = op->getDialect(); + if (isOpSingleUse(op)) { + // when we rematerialise, this operation does not get duplicated + // so it does not contribute to our cost model: + continue; + } else if (isa(op)) { + // special-case: arith.constant has zero cost + continue; + } else if (isa(op) || isa(op)) { + // optimistically assume L1-cached: + for (Value result : op->getResults()) { + rematerialisationCost += 8 * getByteCount(result); + } + } else if (isa(dialect)) { + // this is an arithmetic operation; we distinguish between cheap + // operations (such as floating point add/mul which can be fused + // as halves of a single-cycle FMA instruction) and expensive + // operations which use the special function unit and/or involve + // multiple instructions. + int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1; + for (Value result : op->getResults()) { + rematerialisationCost += multiplier * getByteCount(result); + } + } else if (isa(op)) { + // Reduce op introduce much cost. + auto reduceOp = dyn_cast(op); + ReduceOpHelper helper(reduceOp); + if (!helper.isAssociative()) { + // We shouldn't rematerize a no associative reduce op if it has multiple + // use chain. + LDBG(" skipped rematerialization due to non-associative reduce in the " + "slice"); + return; + } + rematerialisationCost += helper.getIntraWarpSizeWithUniqueData(); + rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData(); + } + } + + LLVM_DEBUG({ + DBGS() << " convert layout cost: " << convertLayoutCost << "\n"; + DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n"; + }); + + if (rematerialisationCost > convertLayoutCost) { + LDBG(" skipped rematerialization due to higher cost"); + return; + } + + LLVM_DEBUG({ + DBGS() << " remat convert op " << convertOp << '\n'; + for (Value v : slice) + DBGS() << " " << v << '\n'; + }); + + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} + +void LayoutRematerialization::hoistConvertDotOperand() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertDotOperand(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertDotOperand( + ConvertLayoutOp convertOp) { + auto targetType = convertOp.getType(); + // The pass is targeted to MMA dot operands + + auto canBePipelined = [&](ConvertLayoutOp convertOp) { + // FIXME: Check that the parent is a for loop + auto parent = convertOp->getParentOp(); + if (!parent) + return false; + + // Find all the dot-like ops in the for loop that have a dot operand + // encoding on the lhs and check if any of them post-dominates the load + + // cvt + SmallVector dotLikeOps; + parent->walk([&](Operation *op) { + if (!isa(op)) + return; + auto opType = dyn_cast(op->getOperand(0).getType()); + if (!opType) + return; + auto dotEnc = dyn_cast(opType.getEncoding()); + if (!dotEnc) + return; + if (isa(dotEnc.getParent())) + dotLikeOps.push_back(op); + }); + if (dotLikeOps.empty()) + return false; + return llvm::any_of(dotLikeOps, [&](Operation *dot) { + return postDomInfo.postDominates(dot, convertOp); + }); + }; + + // We move convert #dot_operand next to their loads. This is done + // so that it's then easy to pipeline these loads + if (!canBePipelined(convertOp)) + return; + + // We hoist over any operation that can be done without data movement between + // threads We do views and elementwise pure ops for now + auto noDataMovement = [](Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isa( + op) || + isView(op); + }; + // Stop the slice as soon as we find an operation that cannot be done without + // data movement between threads + auto stop = std::not_fn(noDataMovement); + + SetVector slice; + DenseMap layout; + // Set-up the conversion "cache" + LogicalResult result = getConvertBackwardSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop); + if (result.failed()) + return; + + IRMapping mapping; + OpBuilder builder(convertOp.getContext()); + SetVector innerSlice; + for (Value v : slice) { + if (!v.getDefiningOp()) { + LLVM_DEBUG( + { DBGS() << " Block arguments not supported. Got " << v << "\n"; }); + return; + } + + // We expect the leaves of the slice to be Load, DescriptorLoad or + // arith::Constant This could be generalised if necessary + if (!isa(v.getDefiningOp())) { + auto op = v.getDefiningOp(); + if (isa(op) || noDataMovement(op)) { + innerSlice.insert(v); + continue; + } else { + LLVM_DEBUG({ + DBGS() << " Leaves must be Load, DescriptorLoad or Constant. Got " + << v << "\n"; + }); + return; + } + } + Operation *loadOp = v.getDefiningOp(); + builder.setInsertionPointAfter(loadOp); + auto type = dyn_cast(loadOp->getResult(0).getType()); + if (!type) + continue; + auto newType = type.cloneWithEncoding(layout[loadOp->getResult(0)]); + auto newConvertOp = ConvertLayoutOp::create(builder, convertOp.getLoc(), + newType, loadOp->getResult(0)); + mapping.map(loadOp->getResult(0), newConvertOp.getResult()); + } + + if (innerSlice.empty()) { + return; + } + + LLVM_DEBUG({ + DBGS() << " Hoisting " << convertOp << '\n'; + for (Value v : innerSlice) + DBGS() << " " << v << '\n'; + }); + + rewriteSlice(innerSlice, layout, convertOp, mapping); +} + +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( + ConvertLayoutOp convertOp) { + // DotOperand is hoisted by hoistDotOperand + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + + auto isExtOrBroadcastOp = [](Operation *op) { + if (isa(op)) { + return true; + } + if (auto fpToFpOp = dyn_cast(op)) { + auto srcType = cast(fpToFpOp.getOperand().getType()); + return getElementBitWidth(srcType) < + getElementBitWidth(cast(fpToFpOp.getType())); + } + return false; + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, + isExtOrBroadcastOp, kMaxRematSliceSize); + if (result.failed()) + return; + + Operation *extOrBroadcastOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op || !isExtOrBroadcastOp(op)) + continue; + + Attribute srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; + + // If we can rematerialize the rest of the ext slice we can ignore this ext + // as it won't need a convert. + if (succeeded(getRematerializableSlice(op->getOpOperand(0), srcEncoding, + slice, layout, nullptr, + kMaxRematSliceSize))) + continue; + + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOrBroadcastOp != nullptr) + return; + extOrBroadcastOp = op; + } + + if (extOrBroadcastOp == nullptr) + return; + Attribute dstEncoding = layout[extOrBroadcastOp->getResult(0)]; + Attribute srcEncoding = inferSrcEncoding(extOrBroadcastOp, dstEncoding); + if (!srcEncoding) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOrBroadcastOp); + auto tensorType = + cast(extOrBroadcastOp->getOperand(0).getType()); + auto newType = tensorType.cloneWithEncoding(srcEncoding); + auto newConvertOp = ConvertLayoutOp::create( + builder, convertOp.getLoc(), newType, extOrBroadcastOp->getOperand(0)); + Operation *newExtOrBroadcast = builder.clone(*extOrBroadcastOp); + newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); + auto oldExtOrBroadcastType = + cast(extOrBroadcastOp->getResult(0).getType()); + Type newExtOrBroadcastType = + oldExtOrBroadcastType.cloneWithEncoding(dstEncoding); + newExtOrBroadcast->getResult(0).setType(newExtOrBroadcastType); + IRMapping mapping; + mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0)); + slice.remove(extOrBroadcastOp->getResult(0)); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +void LayoutRematerialization::hoistConvertIntoConditionals( + ConvertLayoutOp convertOp) { + // Take the backward slice of tensor dependencies rooted at the conversion, + // stopping at conditionals. This subslice is used to initialize the analysis. + SetVector slice; + DenseMap layout; + auto isIfOp = [](Operation *op) { return isa(op); }; + if (failed(getRematerializableSlice(convertOp.getSrcMutable(), + convertOp.getType().getEncoding(), slice, + layout, isIfOp, kMaxRematSliceSize))) + return; + + // These are the conditional edges above which conversions should be hoisted. + // The value represents the `scf.if` op result and the operand represents the + // edge into one of the branches. + SmallVector> hoistAbove; + + // The list of `scf.if` op results in the slice that are not rematerializable. + // Hoisting is terminated at these values. + SmallVector terminals; + + // This loop recurses through the subslices of the backwards dependencies, so + // re-query the size of `slice`. + for (unsigned i = 0; i != slice.size(); ++i) { + Value v = slice[i]; + auto ifOp = v.getDefiningOp(); + if (!ifOp) + continue; + + Attribute rootLayout = layout.at(v); + unsigned resIdx = cast(v).getResultNumber(); + + // Take the backward slice along each branch. + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + + OpOperand &thenRes = thenYield.getResultsMutable()[resIdx]; + OpOperand &elseRes = elseYield.getResultsMutable()[resIdx]; + + auto newSlice = slice; + auto newLayout = layout; + + LogicalResult thenResult = getRematerializableSlice( + thenRes, rootLayout, newSlice, newLayout, isIfOp, kMaxRematSliceSize); + LogicalResult elseResult = getRematerializableSlice( + elseRes, rootLayout, newSlice, newLayout, isIfOp, kMaxRematSliceSize); + + // If propagation across both edges of this conditional succeeded, then we + // don't need to hoist across it. Merge into the current slice. + if (succeeded(thenResult) && succeeded(elseResult)) { + slice = std::move(newSlice); + layout = std::move(newLayout); + continue; + } + + // If propagation across both edges failed, then this conditional + // terminates backwards rematerialization. + if (failed(thenResult) && failed(elseResult)) { + terminals.push_back(cast(v)); + continue; + } + + // Only hoist into conditionals inside loops. The assumption is that an if + // inside a loop executes fewer than the total number of loop iterations, + // making this hoist profitable. + if (!isa(ifOp->getParentOp())) { + terminals.push_back(cast(v)); + continue; + } + + slice = std::move(newSlice); + layout = std::move(newLayout); + // The layout conversion can be rematerialized along one edge but not the + // other. We can hoist the conversion into the other branch. Push this + // into the subslice list for analysis. + if (succeeded(thenResult)) { + hoistAbove.emplace_back(v, &elseRes); + } else { + hoistAbove.emplace_back(v, &thenRes); + } + } + + // Exit early if there is nothing to do. + if (hoistAbove.empty()) + return; + + // Rematerialize failed hoists right before the condtional, and hoist those + // that succeeded into the branch and then rewrite the slice. + IRMapping mapping; + auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) { + auto tensorType = cast(v.getType()); + auto newType = tensorType.cloneWithEncoding(encoding); + Value newCvt = ConvertLayoutOp::create(b, convertOp.getLoc(), newType, v); + + mapping.map(v, newCvt); + slice.remove(v); + }; + for (Value v : terminals) { + OpBuilder b(v.getContext()); + b.setInsertionPointAfter(v.getDefiningOp()); + hoistRemat(b, v, layout.at(v)); + } + for (auto [result, edge] : hoistAbove) { + OpBuilder b(edge->getOwner()); + hoistRemat(b, edge->get(), layout.at(result)); + } + rewriteSlice(slice, layout, convertOp, mapping); +} + +bool backwardRematerialization(ModuleOp module) { + bool changed = false; + module.walk([&](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + changed |= layoutRemat.backwardRematerialization(); + layoutRemat.cleanup(); + }); + return changed; +} + +void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.hoistConvertOnTopOfExtOrBroadcast(); + layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertIntoConditionals(); + layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertDotOperand(); + layoutRemat.cleanup(); + }); +} +} // namespace + +class TritonGPURemoveLayoutConversionsPass + : public impl::TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { +public: + // Cleanup convert ops. + void cleanupConvertOps() { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (applyPatternsGreedily(m, std::move(cleanUpPatterns)).failed()) { + signalPassFailure(); + } + + LLVM_DEBUG({ + DBGS() << "Module after canonicalizing:\n"; + m.dump(); + }); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + LLVM_DEBUG({ + DBGS() << "Module after propagating layouts forward:\n"; + m.dump(); + }); + + cleanupConvertOps(); + + bool changed = false; + do { + changed = false; + // 2. For remaining convert ops, try to rematerialize the slice of + // producer operation to avoid having to convert. + changed = backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // Cleanup dummy converts created during backward remat. + cleanupConvertOps(); + } while (changed); + // 3. For remaining converts, try to hoist them above cast generating larger + // size types in order to reduce the cost of the convert op. + hoistConvert(m); + LLVM_DEBUG({ + DBGS() << "Module after hoisting converts:\n"; + m.dump(); + }); + + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (applyPatternsGreedily(m, std::move(cleanUpPatterns2)).failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after final cleanups:\n"; + m.dump(); + }); + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp new file mode 100644 index 0000000000..456a40f48d --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -0,0 +1,178 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREORDERINSTRUCTIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static bool willIncreaseRegisterPressure(Operation *op) { + if (isa(op)) + return true; + auto cvt = dyn_cast(op); + if (!cvt) + return false; + if (mlir::isa( + cvt.getType().getEncoding())) + return true; + return false; +} + +// Return true if it has side effects that are either unknown or writes. +static bool hasWriteSideEffect(Operation *op) { + auto effects = getEffectsRecursively(op); + if (!effects) + return false; + return llvm::any_of(*effects, [](MemoryEffects::EffectInstance effect) { + return !isa(effect.getEffect()); + }); +} + +// Return true if there is a write side effect on any path between start and end +// ops. This assumes start dominates end. +static bool crossWriteSideEffectingOp(Operation *start, Operation *end) { + auto ancestor = start->getBlock()->findAncestorOpInBlock(*end); + // Couldn't find an ancestor in the same block, conservatively assume true. + if (!ancestor) + return true; + Operation *nextOp = start->getNextNode(); + while (nextOp) { + if ((hasWriteSideEffect(nextOp))) + return true; + if (nextOp == ancestor) + return false; + nextOp = nextOp->getNextNode(); + } + assert(false && "op doesn't dominate other"); + return true; +} + +class TritonGPUReorderInstructionsPass + : public impl::TritonGPUReorderInstructionsBase< + TritonGPUReorderInstructionsPass> { +public: + TritonGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(users, [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + mlir::DominanceInfo dom(m); + // sink conversion after the last dealloc + // before the first use ancestor in its block + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + auto end = op->getBlock()->end(); + for (; curr != end && &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); + // Sink conversions into loops when they will increase + // register pressure + DenseMap opToMove; + auto moveAfter = [](Operation *lhs, Operation *rhs) { + lhs->moveAfter(rhs); + }; + m.walk([&](Operation *op) { + if (!willIncreaseRegisterPressure(op)) + return; + auto user_begin = op->user_begin(); + auto user_end = op->user_end(); + if (std::distance(user_begin, user_end) != 1) + return; + if (user_begin->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, *user_begin}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + // Move alloc(load) immediately after dependent load + m.walk([&](triton::gpu::LocalAllocOp op) { + if (!op.getSrc()) + return; + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + // Don't hoist alloc if the src is a scalar as this may increase smem + // pressure for no benefits. + if (isa(argOp)) + return; + moveAfter(op, argOp); + }); + // Move transpositions just after their definition + opToMove.clear(); + m.walk([&](triton::TransposeOpInterface op) { + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move `dot` operand so that conversions to opIdx=1 happens after + // conversions to opIdx=0 + m.walk([&](triton::gpu::LocalLoadOp op) { + auto dstEncoding = mlir::dyn_cast( + op.getType().getEncoding()); + if (!dstEncoding) + return; + int opIdx = dstEncoding.getOpIdx(); + if (opIdx != 1) + return; + if (!op->hasOneUse()) + return; + auto dotUser = dyn_cast(*op->user_begin()); + if (!dotUser) + return; + auto AOp = + dotUser.getOperand(0).getDefiningOp(); + if (!AOp) + return; + // Check that the conversion to OpIdx=1 happens before and can be moved + // after the conversion to OpIdx=0. + if (!dom.dominates(op.getOperation(), AOp.getOperation())) + return; + if (crossWriteSideEffectingOp(op, AOp)) + return; + moveAfter(op, AOp); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Utility.cpp new file mode 100644 index 0000000000..a48bc1bfdd --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -0,0 +1,1759 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "ttg-utility" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; +namespace mlir { + +using namespace triton; + +static bool isPassthroughWaitLikeOp(Operation *op) { + return op->hasTrait(); +} + +static FailureOr +rebuildPassthroughWaitLikeOp(OpBuilder &builder, Operation *op, + ValueRange newOperands) { + auto typeInfer = dyn_cast(op); + if (!typeInfer || op->getNumRegions() != 0 || op->getNumSuccessors() != 0) + return failure(); + + SmallVector newTypes; + if (failed(typeInfer.inferReturnTypes( + op->getContext(), op->getLoc(), newOperands, op->getAttrDictionary(), + op->getPropertiesStorage(), op->getRegions(), newTypes))) + return failure(); + + OperationState state(op->getLoc(), op->getName(), newOperands, newTypes, + op->getAttrs()); + return builder.create(state); +} + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + Type eltType, int numWarps) { + if (version == 1) + return {16, 16}; + else if (version == 2) { + auto rank = shape.size(); + SmallVector ret(rank, 1); + ret[rank - 1] = 8; + ret[rank - 2] = 16; + return ret; + } else if (version == 3) { + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + SmallVector validN; + + // MMAv3 with larger instruction shape is preferred. + if (llvm::isa( + eltType) || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + unsigned m = 16; + unsigned mWarps = std::max(shape[0] / m, 1); + unsigned nWarps = std::max(numWarps / mWarps, 1); + unsigned maxN = std::max(shape[1] / nWarps, 8); + for (auto n : validN) { + if (shape[1] % n == 0 && n <= maxN) { + return {m, n, k}; + } + } + + assert(false && "type not supported"); + return {0, 0, 0}; + } else if (version == 5) { + unsigned m = shape[0] >= 128 ? 128 : 64; + // Right now default to distributing along N. TODO: For cases where we have + // dot followed by reduction we need to be able to distribute along M. + // if (numWarps > 4) + // m = 64; + unsigned n = shape[1] >= 256 ? 256 : shape[1]; + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); + return {m, n, k}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +SmallVector +getOrderFromContiguity(const SmallVector &arr) { + SmallVector ret(arr.size()); + std::iota(ret.begin(), ret.end(), 0); + std::reverse(ret.begin(), ret.end()); + std::stable_sort(ret.begin(), ret.end(), + [&](unsigned x, unsigned y) { return arr[x] > arr[y]; }); + return ret; +} + +Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto copy = dyn_cast(op)) + return copy.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} + +unsigned getElementBitWidth(RankedTensorType type) { + auto typeForMem = + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() + : type.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +unsigned getNumElementsPerThread(Operation *op, SmallVector order, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + ArrayRef shapePerCTA) { + Value val = getMemAccessPtr(op); + auto ty = cast(val.getType()); + AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(ty); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + unsigned currPerThread = std::min(alignment, 128 / elemNumBits); + LDBG("elemNumBytes: " << elemNumBytes + << ", divisibility: " << maxMultipleBytes + << ", contig: " << valInfo.getContiguity(order[0]) + << ", alignment: " << alignment); + return currPerThread; +} + +bool isView(Operation *op) { + return isa(op); +} + +bool isNoop(Operation *op) { + if (isa(op)) + return true; + if (auto cvt = dyn_cast(op)) { + // The conversion op is a noop if the conversion layout is trivial + return minimalCvtLayout(cvt.getSrc().getType(), + cvt.getResult().getType()) == LinearLayout::empty(); + } + return false; +} + +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; + + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; + } + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = dyn_cast(type)) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } + } + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; + } + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = mlir::dyn_cast(value)) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = dyn_cast(type)) { + auto layout = tensorTy.getEncoding(); + if (isa(layout)) + return "green"; + else if (isa(layout)) + return "yellow"; + else if (isa(layout)) + return "lightslateblue"; + else if (isa(layout)) + return "orange"; + else if (isa(layout)) + return "orangered"; + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static Attribute inferDstEncoding(triton::ReduceOp op, Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get( + op->getContext(), op.getAxis(), + cast(encoding)); +} + +static Attribute inferDstEncoding(triton::ExpandDimsOp op, Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return {}; + if (op.getAxis() != sliceEncoding.getDim()) + return {}; + return sliceEncoding.getParent(); +} + +static Attribute inferDstEncoding(JoinOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getLhs().getType().getShape(); + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferDefaultJoinOpEncoding(srcEnc, dstEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return {}; +} + +static Attribute inferDstEncoding(SplitOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getSrc().getType().getShape(); + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(srcEnc, dstEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(triton::ReduceOp op, Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return {}; + if (op.getAxis() != sliceEncoding.getDim()) + return {}; + return sliceEncoding.getParent(); +} + +static Attribute inferSrcEncoding(triton::ExpandDimsOp op, Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get( + op->getContext(), op.getAxis(), + cast(encoding)); +} + +static Attribute inferSrcEncoding(JoinOp op, Attribute dstEnc) { + // Split is the inverse of join. + auto shape = op.getResult().getType().getShape(); + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(dstEnc, srcEnc, shape, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(SplitOp op, Attribute dstEnc) { + // Join is the inverse of split. + Attribute srcEnc; + auto shape = op.getOutLHS().getType().getShape(); + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferDefaultJoinOpEncoding(dstEnc, srcEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(GatherOp op, Attribute dstEnc) { + // The index encoding is the same as the output encoding. + return dstEnc; +} + +static Attribute inferTransOpDstEncoding(Attribute srcEnc, + ArrayRef shape, + ArrayRef order) { + // Simply forward to the existing inferTransOpEncoding function. + Attribute retEncoding; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferTransOpEncoding(srcEnc, shape, order, retEncoding, + /*loc=*/{}))) { + return retEncoding; + } + return {}; +} + +static Attribute inferDstEncoding(triton::gpu::Fp4ToFpOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getSrc().getType().getShape(); + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, op.getAxis(), srcEnc, dstEnc, + /*fwdInference*/ true, std::nullopt); + assert(succeeded(result)); + return dstEnc; +} + +static Attribute inferSrcEncoding(triton::gpu::Fp4ToFpOp op, Attribute dstEnc) { + Attribute srcEnc; + auto shape = op.getType().getShape(); + if (succeeded( + dstEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, op.getAxis(), dstEnc, srcEnc, + /*fwdInference*/ false, std::nullopt))) { + return srcEnc; + } + return {}; +} + +static Attribute inferDstEncoding(triton::TransposeOpInterface op, + Attribute encoding) { + return inferTransOpDstEncoding( + encoding, cast(op.getSrc().getType()).getShape(), + op.getOrder()); +} + +static Attribute inferSrcEncoding(triton::TransposeOpInterface op, + Attribute encoding) { + // We want to solve for srcEnc in + // transpose(srcEnc, order) -> dstEnc. + // Given the identity + // transpose(transpose(x, order), inverse(order)) == x, + // we can see this is equivalent to + // transpose(dstEnc, inverse(order)) -> srcEnc. + auto shape = cast(op->getResult(0).getType()).getShape(); + return inferTransOpDstEncoding(encoding, shape, + triton::inversePermutation(op.getOrder())); +} + +static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + bool allowReorder) { + // We don't do anything smart to allow-reorder reshapes here. They are + // handled in OptimizeThreadLocality. + if (allowReorder) + return {}; + + Attribute dstEnc; + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, + /*loc=*/std::nullopt); + assert(succeeded(result)); + return dstEnc; +} + +static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { + return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, + op.getType().getShape(), + op.getAllowReorder()); +} + +static Attribute inferDstEncoding(GatherOp op, Attribute encoding) { + // The output encoding is the same as the index encoding. + // FIXME: This assumes `encoding` is the index encoding, which can be + // different than the source encoding. + return encoding; +} + +static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) { + // The encoding of x given the encoding of y in `reshape(x) -> y` is the same + // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an + // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this + // way. + return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, + op.getSrc().getType().getShape(), + op.getAllowReorder()); +} + +static bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + +Attribute inferSrcEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + // Scan only supports blocked encoding at the moment. + if (!isa(encoding)) + return {}; + } + + if (isa(op)) + return {}; + + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op) || + isPassthroughWaitLikeOp(op)) { + return encoding; + } + + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferSrcEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferSrcEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferSrcEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferSrcEncoding(reshape, encoding); + if (auto gather = dyn_cast(op)) + return inferSrcEncoding(gather, encoding); + if (auto fp4ToFp = dyn_cast(op)) + return inferSrcEncoding(fp4ToFp, encoding); + + return {}; +} + +Attribute inferDstEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + if (!isa(encoding)) + return {}; + } + if (isa(op)) + return {}; + + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op) || + isPassthroughWaitLikeOp(op)) + return encoding; + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferDstEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferDstEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferDstEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferDstEncoding(reshape, encoding); + if (auto gather = dyn_cast(op)) + return inferDstEncoding(gather, encoding); + if (auto fp4ToFp = dyn_cast(op)) + return inferDstEncoding(fp4ToFp, encoding); + + return {}; +} + +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the + // same + if (isSingleValue(op->getOperand(0))) + return false; + // Case 2b: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + auto ptrType = cast(op->getOperand(0).getType()); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::lookupNumWarps(op); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (ptrType.getNumElements() < numWarps * threadsPerWarp) + return false; + return true; +} + +bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { + if (!op) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); + if (isa(op)) + return true; + if (isa( + op)) + return true; + return false; +} + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); + if (auto convert = dyn_cast(op)) { + if (mlir::isa(targetEncoding)) { + auto srcEncoding = convert.getSrc().getType().getEncoding(); + if (targetEncoding != srcEncoding) + return false; + } + return true; + } + + if (auto reshape = dyn_cast(op)) { + auto reshapeDstType = reshape.getType(); + RankedTensorType newDstType = + reshapeDstType.cloneWithEncoding(targetEncoding); + return reshape.getAllowReorder() && !reshape.getEfficientLayout() && + !triton::gpu::isExpensiveView(reshape.getSrc().getType(), + newDstType); + } + return isa(op); +} + +scf::ForOp replaceForOpWithNewSignature( + OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = + scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(), + loop.getUpperBound(), loop.getStep(), operands); + newLoop->setAttrs(loop->getAttrs()); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + return newLoop; +} + +scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + SmallVector> replacements; + auto newForOp = replaceForOpWithNewSignature(rewriter, loop, newIterOperands, + replacements); + for (auto [result, value] : replacements) { + result.replaceAllUsesWith(value); + } + return newForOp; +} + +scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + scf::ForOp newLoop = + replaceForOpWithNewSignature(rewriter, loop, newIterOperands); + // Save the caller from insertion point invalidation. + if (rewriter.getInsertionPoint() == loop->getIterator()) + rewriter.setInsertionPoint(newLoop); + loop.erase(); + return newLoop; +} + +scf::WhileOp replaceWhileOpWithNewSignature( + OpBuilder &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInits()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + + // Result and operand types + SmallVector resultTypes; + SmallVector argsTypesBefore; + for (auto res : loop.getResults()) + resultTypes.push_back(res.getType()); + for (auto type : newResultTypes) + resultTypes.push_back(type); + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + scf::WhileOp newLoop = + scf::WhileOp::create(rewriter, loop.getLoc(), resultTypes, operands); + newLoop->setAttrs(loop->getAttrs()); + + SmallVector bbArgLocsBefore(argsTypesBefore.size(), loop.getLoc()); + SmallVector bbArgLocsAfter(resultTypes.size(), loop.getLoc()); + rewriter.createBlock(&newLoop.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newLoop.getAfter(), {}, resultTypes, bbArgLocsAfter); + + // Copy regions + for (int i = 0; i < loop.getNumRegions(); ++i) + newLoop->getRegion(i).front().getOperations().splice( + newLoop->getRegion(i).front().getOperations().begin(), + loop->getRegion(i).front().getOperations()); + + // Remap arguments + for (auto [oldArg, newArg] : llvm::zip( + loop.getBeforeArguments(), newLoop.getBeforeArguments().take_front( + loop.getBeforeArguments().size()))) + oldArg.replaceAllUsesWith(newArg); + for (auto [oldArg, newArg] : llvm::zip(loop.getAfterArguments(), + newLoop.getAfterArguments().take_front( + loop.getAfterArguments().size()))) + oldArg.replaceAllUsesWith(newArg); + + // Stack the new results + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + + return newLoop; +} + +scf::WhileOp replaceWhileOpWithNewSignature(OpBuilder &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newWhileOp = replaceWhileOpWithNewSignature( + rewriter, loop, newIterOperands, newResultTypes, replacements); + for (auto &kv : replacements) { + std::get<0>(kv).replaceAllUsesWith(std::get<1>(kv)); + } + return newWhileOp; +} + +scf::IfOp replaceIfOpWithNewSignature( + OpBuilder &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ifOp); + + // Create a new loop before the existing one, with the extra operands. + auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + scf::IfOp newIf = scf::IfOp::create(rewriter, ifOp.getLoc(), resultTypes, + ifOp.getCondition()); + newIf->setAttrs(ifOp->getAttrs()); + + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc()); + scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc()); + + for (auto it : llvm::zip(ifOp.getResults(), + newIf.getResults().take_front(ifOp.getNumResults()))) + replacements.push_back(it); + return newIf; +} + +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + scf::YieldOp::create(builder, yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +scf::IfOp replaceIfOpWithNewSignature(OpBuilder &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, newResultTypes, replacements); + for (auto &kv : replacements) + std::get<0>(kv).replaceAllUsesWith(std::get<1>(kv)); + return newIfOp; +} + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping) { + Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = dyn_cast(op->getResult(0).getType()); + auto argType = dyn_cast(newOp->getOperand(0).getType()); + if (!origType || !argType) + return newOp; + auto newType = origType.cloneWithEncoding(argType.getEncoding()); + newOp->getResult(0).setType(newType); + auto typeInfer = dyn_cast(newOp); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } + } + return newOp; +} + +// Check if the convert will be performed by reordering registers. +static bool isFreeConvert(Operation *op) { + auto convertOp = dyn_cast(op); + if (!convertOp) + return false; + return cvtReordersRegisters(convertOp.getSrc().getType(), + convertOp.getType()); +} + +LogicalResult getConvertBackwardSlice( + OpOperand &root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation, + std::function getExistingConversion, + unsigned maxSliceSize) { + DenseSet> seen; + SmallVector> queue; + + auto enqueue = [&](OpOperand &operand, Attribute encoding) { + auto x = std::make_pair(&operand, encoding); + if (!seen.insert(x).second) { + return; // Already enqueued, skip + } + queue.push_back(x); + }; + enqueue(root, rootEncoding); + + auto updateLayout = [&](Value value, Attribute encoding) { + assert((isa(value.getType()))); + slice.insert(value); + Attribute &existing = layout[value]; + if (existing && existing != encoding) + return failure(); + existing = encoding; + return success(); + }; + + while (!queue.empty()) { + auto [currentValueUse, encoding] = queue.back(); + Value currentValue = currentValueUse->get(); + queue.pop_back(); + if (!isa(currentValue.getType())) + continue; + // Skip propagating through for op/while op/ws op results for now. + // TODO: enable this based on needs. + auto defOp = currentValue.getDefiningOp(); + if (isa_and_nonnull(defOp)) + return failure(); + if (failed(updateLayout(currentValue, encoding))) + return failure(); + if (maxSliceSize && slice.size() > maxSliceSize) + return failure(); + + // If there is already an existing conversion to the target layout, we don't + // need to propagate to the operands. + // Note that this is per-use rather than per-value, so if another use fails + // the getExistingConversion check, we may still traverse the operands. + if (getExistingConversion && + getExistingConversion(*currentValueUse, encoding)) { + continue; + } + + if (auto ifOp = currentValue.getDefiningOp()) { + if (stopPropagation && stopPropagation(ifOp)) + continue; + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + + OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx); + OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx); + + enqueue(thenValue, encoding); + enqueue(elseValue, encoding); + + continue; + } + if (auto *definingOp = currentValue.getDefiningOp()) { + // If the op has multiple results we need to update all results layout. + for (Value result : definingOp->getResults()) { + if (result == currentValue || !isa(result.getType())) + continue; + if (failed(updateLayout(result, encoding))) + return failure(); + } + if (isFreeConvert(definingOp)) { + enqueue(definingOp->getOpOperand(0), encoding); + continue; + } + if (canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + if (auto gather = dyn_cast(definingOp)) { + // Specially handle gather since its transfer function only applies + // between its index operand and result. + auto srcEncoding = inferSrcEncoding(gather, encoding); + if (!srcEncoding) + return failure(); + enqueue(gather.getIndicesMutable(), srcEncoding); + continue; + } + for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) { + Attribute srcEncoding; + if (auto upcast = + dyn_cast(definingOp)) { + srcEncoding = upcast.inferSrcEncoding(i, encoding); + } else { + srcEncoding = inferSrcEncoding(definingOp, encoding); + } + if (!srcEncoding) + return failure(); + // If the infered layout matches the original one we don't need to keep + // propagating. + if (auto operandType = + dyn_cast(operand.get().getType())) { + if (srcEncoding == operandType.getEncoding()) + continue; + } + enqueue(operand, srcEncoding); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + enqueue(*initOperand, encoding); + enqueue(yieldOperand, encoding); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = triton::applyPermutation(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = arith::ConstantIntOp::create(b, loc, en.value(), 32); + multiDim[en.index()] = arith::RemSIOp::create(b, loc, remained, dimSize); + remained = arith::DivSIOp::create(b, loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, triton::applyPermutation(multiDim, order), + triton::applyPermutation(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = arith::ConstantIntOp::create(b, loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = arith::ConstantIntOp::create(b, loc, dimShape, 32); + linear = arith::AddIOp::create( + b, loc, arith::MulIOp::create(b, loc, linear, dimSize), dim); + } + } + return linear; +} + +bool isPureUnaryInlineAsm(Operation *op) { + auto inlineAsmOp = dyn_cast(op); + if (!inlineAsmOp) + return false; + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + inlineAsmOp.getPure(); +} + +int getNVIDIAComputeCapability(Operation *module) { + StringAttr targetAttr = + module->getAttrOfType(triton::gpu::AttrTargetName); + assert(targetAttr && "Expected a target attribute on the module operation"); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + +std::optional getAMDArch(Operation *module) { + StringAttr targetAttr = + module->getAttrOfType(triton::gpu::AttrTargetName); + if (!targetAttr) { + LDBG("Expected a target attribute on the module operation"); + return {}; + } + + StringRef ref = targetAttr.strref(); + if (!ref.starts_with("hip:")) { + LDBG("expected target attribute to be prefixed with \"hip:\""); + return {}; + } + + return ref.drop_front(4); // drop the "hip:" +} + +inline ttg::SwizzledSharedEncodingAttr +swizzleDotOperandLike(RankedTensorType type, ttg::CGAEncodingAttr cgaLayout) { + // We want to see if the linear layout has the same order as an mma microtile + // of shape (8, 4*kWidth) or (4*kWidth, 8). If so, we return a + // DotOperandEncodingAttr with a tile of this shape This works because + // SwizzledSharedEncodingAttr::get just looks at the microtile to determine + // the swizzling + + auto *ctx = type.getContext(); + auto layout = ttg::toLinearEncoding(type); + auto order = layout.getThreadOrder(); + auto rank = order.size(); + if (rank < 2) { + return {}; + } + int opIdx; + if (ttg::getOrderForDotOperand(0, rank, /*kContig=*/true) == order) { + opIdx = 0; + } else if (ttg::getOrderForDotOperand(1, rank, /*kContig=*/true) == order) { + opIdx = 1; + } else { + return {}; + } + auto kWidth = layout.getContigPerThread()[order[0]]; + SmallVector microtileShape(rank, 1); + microtileShape[order[0]] = 4 * kWidth; + microtileShape[order[1]] = 8; + // All the LinearLayouts contained within LinearEncoidngAttr have order [0, 1, + // 2, ...] + auto repOrder = to_vector(llvm::seq(rank)); + auto tile = ttg::nvidiaMmaTile(ctx, microtileShape, kWidth, order, repOrder); + if (!divideLeft(layout.getLinearLayout(), tile).has_value()) { + return {}; + } + return ttg::SwizzledSharedEncodingAttr::get( + ctx, opIdx, kWidth, type.getShape(), order, cgaLayout, + type.getElementTypeBitWidth(), false); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are incompatible shared +// encodings, set incompatible to true. +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { + ttg::SwizzledSharedEncodingAttr attr; + incompatible = false; + for (Operation *user : val.getUsers()) { + ttg::SwizzledSharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = + dyn_cast(memDesc.getEncoding()); + if (!tempAttr) + return std::nullopt; + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto dstTy = cast(user->getResult(0).getType()); + + // FIXME This may not be correct for multiple CTA, but getCGALayout is NYI + // for LinearEncodingAttr + auto CGALayout = isa(dstTy.getEncoding()) + ? ttg::getCGALayout(srcTy.getEncoding()) + : ttg::getCGALayout(dstTy.getEncoding()); + + if (auto dot = + dyn_cast(dstTy.getEncoding())) { + auto order = getOrderForMemory(srcTy); + unsigned bitWidth = srcTy.getElementTypeBitWidth(); + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), dot, srcTy.getShape(), order, CGALayout, bitWidth, + /*needTrans=*/false); + } else { + // Try to see if the layout is like an mma microtile + tempAttr = swizzleDotOperandLike(dstTy, CGALayout); + } + if (!tempAttr) + return std::nullopt; + } + // Check that the shared encodings needed by the users are compatible. + if (attr != nullptr && attr != tempAttr) { + incompatible = true; + return std::nullopt; + } + attr = tempAttr; + } + return attr; +} + +static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +static bool skipOperand(Operation *op, unsigned operandNumber) { + if (auto gather = dyn_cast(op)) { + return operandNumber == gather.getXOffsetsMutable().getOperandNumber(); + } + if (auto scatter = dyn_cast(op)) { + return operandNumber == scatter.getXOffsetsMutable().getOperandNumber(); + } + return false; +} + +Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto &opOperand : op->getOpOperands()) { + Value operand = opOperand.get(); + auto tensorType = dyn_cast(operand.getType()); + bool skip = skipOperand(op, opOperand.getOperandNumber()); + if (tensorType && !skip) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(triton::gpu::ConvertLayoutOp::create( + builder, op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(), + newArgs, newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = triton::gpu::ConvertLayoutOp::create( + builder, op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + return newOp; +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = mlir::cast(value); + OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = mlir::cast(value); + // mark condition as live. + markLive(nestedIf.getCondition()); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = mlir::cast(value); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getInitArgs()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + + // The yield operand might live outside the loop, e.g. + // %init = ... + // %x = ... + // %y = for iter_args(%unused = %init) { + // yield %x + // } + // + // In this case, the loop returns %x if it runs 1 or more times, and + // otherwise it returns %init. We cowardly refuse to remove this operand + // from the yield. (We could, but we'd need to prove that the loop runs 0 + // or >=1 times.) + // + // As a special case, if it doesn't matter whether the loop runs 0 or >=1 + // times (because the loop returns the same value in both cases) then we + // can still mark the operand as dead. This occurs in the above example + // when %init is the same as %x. + if (!forOp->isAncestor( + yieldOperand.value().getParentRegion()->getParentOp()) && + yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) + continue; + + deadArg.push_back(yieldOperand.index()); + } + bool changed = false; + // For simplicity we just replace users of the block arg with init value and + // leave the operations and argument removal to dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + changed |= !arg.use_empty(); + rewriter.replaceAllUsesWith(arg, forOp.getTiedLoopInit(arg)->get()); + } + return success(changed); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +ttg::LocalAllocOp findShmemAlloc(Value operand) { + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescIndex op. Only ConvertLayout and MemdescView ops are + // allowed in between. + Value transitiveOperand = operand; + while (isa_and_nonnull( + transitiveOperand.getDefiningOp()) || + isa(transitiveOperand)) { + if (auto blockArg = dyn_cast(transitiveOperand)) { + assert(isa(blockArg.getOwner()->getParentOp()) && + "Block argument must come from a for loop"); + transitiveOperand = + cast(blockArg.getOwner()->getTerminator()) + .getOperand(blockArg.getArgNumber() - 1); + } else { + transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + } + } + if (auto subView = dyn_cast_or_null( + transitiveOperand.getDefiningOp())) { + // Multi-buffered operand + return dyn_cast_or_null( + subView.getSrc().getDefiningOp()); + } else { + // Single bufferred operand that does not require a subview (not loaded in + // the loop) + return dyn_cast_or_null( + transitiveOperand.getDefiningOp()); + } + return nullptr; +} + +SmallVector +getMMAsWithMultiBufferredOperands(scf::ForOp forOp, + SmallVector &mmaOps) { + // The A and B operands of the mmaOp should be multi-buffered + SmallVector eligible; + for (auto mmaOp : mmaOps) { + auto a = findShmemAlloc(mmaOp->getOperand(0)); + auto b = findShmemAlloc(mmaOp->getOperand(1)); + if (a && forOp.isDefinedOutsideOfLoop(a) && b && + forOp.isDefinedOutsideOfLoop(b)) { + eligible.push_back(mmaOp); + } + } + + return eligible; +} + +template +static Operation *findNearestCommonDominatorImpl( + ArrayRef ops, DomInfoT &domInfo, + function_ref isBefore) { + if (ops.size() == 0) { + return nullptr; + } + if (ops.size() == 1) { + return ops[0]; + } + llvm::SmallPtrSet blocks; + for (auto op : ops) { + blocks.insert(op->getBlock()); + } + Block *domBlock = domInfo.findNearestCommonDominator(blocks); + if (domBlock == nullptr) { + return nullptr; + } + SmallVector ancestorOps; + for (auto op : ops) { + ancestorOps.push_back(domBlock->findAncestorOpInBlock(*op)); + } + Operation *dom = ancestorOps[0]; + for (unsigned i = 1; i < ops.size(); i++) { + if (isBefore(ancestorOps[i], dom)) { + dom = ancestorOps[i]; + } + } + return dom; +} + +Operation *findNearestCommonDominator(ArrayRef ops, + DominanceInfo &domInfo) { + return findNearestCommonDominatorImpl( + ops, domInfo, + [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); }); +} + +Operation *findNearestCommonPostDominator(ArrayRef ops, + PostDominanceInfo &domInfo) { + return findNearestCommonDominatorImpl( + ops, domInfo, + [](Operation *a, Operation *b) { return b->isBeforeInBlock(a); }); +} + +void visitNestedOperands(Operation *op, + function_ref visitor) { + op->walk([&](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + if (operand.get().getParentBlock()->getParentOp()->isProperAncestor(op)) + visitor(operand); + } + }); +} + +void visitNestedOperands(Operation *op, function_ref visitor) { + visitNestedOperands(op, [&](OpOperand &operand) { visitor(operand.get()); }); +} + +SetVector getNestedOperands(Operation *op) { + SetVector result; + visitNestedOperands(op, [&](Value operand) { result.insert(operand); }); + return result; +} + +void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) { + // Pad the indices in case new arguments were added. + while (indices.size() != loop.getInitArgs().size()) + indices.push_back(false); + + loop.getBody()->getTerminator()->eraseOperands(indices); + loop.getBody()->eraseArguments([&](BlockArgument arg) { + int idx = arg.getArgNumber(); + return idx != 0 && indices.test(idx - 1); + }); + + llvm::BitVector loopOperandIndices(loop->getNumOperands()); + for (auto [i, operand] : llvm::enumerate(loop.getInitArgsMutable())) { + if (indices.test(i)) + loopOperandIndices.set(operand.getOperandNumber()); + } + loop->eraseOperands(loopOperandIndices); + + // Rewrite the loop to erase results. + OperationState state(loop.getLoc(), loop->getName(), loop->getOperands(), + loop.getInitArgs().getTypes(), loop->getAttrs()); + state.addRegion()->takeBody(loop.getBodyRegion()); + + OpBuilder b(loop); + auto newLoop = cast(b.create(state)); + + // Replace uses of the old loop with the new loop. + unsigned newResultIdx = 0; + for (auto [i, result] : llvm::enumerate(loop.getResults())) { + if (indices.test(i)) { + assert(result.use_empty() && "loop carried value still has uses"); + continue; + } + result.replaceAllUsesWith(newLoop.getResult(newResultIdx++)); + } + + loop.erase(); + loop = newLoop; +} + +} // namespace mlir + +namespace mlir::triton { + +void replaceUsesAndPropagateType( + OpBuilder &builder, Operation *oldUse, Value val, + std::function callback) { + OpBuilder::InsertionGuard guard(builder); + SmallVector opsToDelete; + SmallVector operandsToReplace; + + // Save the operand to replace / delete later (avoid iterator invalidation). + // TODO: can we use an early_inc iterator? + for (OpOperand &use : oldUse->getUses()) { + // Propagate through `ttg.warp_specialize`. + if (auto wsOp = dyn_cast(use.getOwner())) { + for (Region ®ion : wsOp.getPartitionRegions()) + region.getArgument(use.getOperandNumber()).setType(val.getType()); + } + + // Non-subview/trans ops will be replaced by `val`. + if (!use.getOwner()->hasTrait()) { + operandsToReplace.push_back(&use); + continue; + } + + Operation *user = use.getOwner(); + // `subview(old_op)` is replaced by a new `subview(val)`. + builder.setInsertionPoint(user); + Value newVal; + if (auto subview = dyn_cast(user)) { + ttg::MemDescType oldType = subview.getType(); + bool isMutable = cast(val.getType()).getMutableMemory(); + Type newDstType = ttg::MemDescType::get( + oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), + oldType.getMemorySpace(), isMutable); + newVal = ttg::MemDescIndexOp::create(builder, subview.getLoc(), + newDstType, val, subview.getIndex()); + } else if (auto subslice = dyn_cast(user)) { + ttg::MemDescType oldType = subslice.getType(); + bool isMutable = cast(val.getType()).getMutableMemory(); + Type newDstType = ttg::MemDescType::get( + oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), + oldType.getMemorySpace(), isMutable, oldType.getAllocShape()); + newVal = ttg::MemDescSubsliceOp::create( + builder, subslice.getLoc(), newDstType, val, subslice.getOffsets()); + } else if (auto trans = dyn_cast(user)) { + newVal = ttg::MemDescTransOp::create(builder, trans.getLoc(), val, + trans.getOrder()); + } else if (auto reshape = dyn_cast(user)) { + auto shape = reshape.getType().getShape(); + newVal = + ttg::MemDescReshapeOp::create(builder, reshape.getLoc(), val, shape); + } + assert(newVal && "unhandled memdesc view"); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + replaceUsesAndPropagateType(builder, user, newVal); + opsToDelete.push_back(user); + if (callback) { + callback(user, newVal.getDefiningOp()); + } + } + + // Perform late replacement. + for (OpOperand *operand : operandsToReplace) { + if (isPassthroughWaitLikeOp(operand->getOwner())) { + Operation *wait = operand->getOwner(); + builder.setInsertionPointAfter(wait); + auto operands = llvm::to_vector(wait->getOperands()); + operands[operand->getOperandNumber()] = val; + FailureOr newWait = + rebuildPassthroughWaitLikeOp(builder, wait, operands); + if (failed(newWait)) + llvm::report_fatal_error("failed to rebuild passthrough wait-like op " + "during type propagation"); + wait->replaceAllUsesWith((*newWait)->getResults()); + wait->erase(); + } else { + operand->set(val); + } + } + + // Perform late op erasure. + for (Operation *op : opsToDelete) + op->erase(); +} + +ttg::LocalLoadOp +replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old, + TypedValue alloc, + TypedValue token) { + // Remove redundant local_load -> local_alloc + auto allocTy = alloc.getType(); + SmallVector allocsToErase; + for (Operation *user : old.getUsers()) { + if (auto userAlloc = dyn_cast(user)) { + if (allocTy.getEncoding() == userAlloc.getType().getEncoding()) { + replaceUsesAndPropagateType(builder, userAlloc, alloc); + allocsToErase.push_back(userAlloc); + } + } + } + + // If there are some uses that were not local_allocs, we need to create a + // local_load for them. + ttg::LocalLoadOp maybeLocalLoad; + if (std::distance(old.getUsers().begin(), old.getUsers().end()) > + allocsToErase.size()) { + auto loc = old.getOwner()->getLoc(); + maybeLocalLoad = + ttg::LocalLoadOp::create(builder, loc, old.getType(), alloc, token); + old.replaceAllUsesWith(maybeLocalLoad); + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + return maybeLocalLoad; +} + +bool comesFromLoadOrBlockArg(Value v) { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + Operation *def = v.getDefiningOp(); + if (!def) + break; + if (auto cvtOp = dyn_cast(def)) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = dyn_cast(def)) { + v = transOp.getSrc(); + continue; + } + if (def->hasTrait()) { + v = def->getOperand(0); + continue; + } + break; + } + // We also accept block arguments as they appear in many MLIR tests + // If this is problematic we can totally drop them + return isa(v) || + (v.getDefiningOp() && + isa(v.getDefiningOp())); +} + +SmallVector getTiedArgs(Operation *op, int resultIdx) { + if (auto forOp = dyn_cast(op)) { + auto iterArg = forOp.getRegionIterArg(resultIdx); + auto result = forOp.getResult(resultIdx); + auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx); + auto initVal = forOp.getInitArgs()[resultIdx]; + return {iterArg, result, yieldVal, initVal}; + } else if (auto whileOp = dyn_cast(op)) { + auto iterArg = whileOp.getBeforeArguments()[resultIdx]; + auto result = whileOp.getResults()[resultIdx]; + auto yieldVal = whileOp.getConditionOp().getArgs()[resultIdx]; + auto initVal = whileOp.getOperands()[resultIdx]; + auto bodyArg = whileOp.getAfterArguments()[resultIdx]; + return {iterArg, result, yieldVal, initVal, bodyArg}; + } else if (auto ifOp = dyn_cast(op)) { + SmallVector values; + for (auto &block : ifOp.getThenRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + for (auto &block : ifOp.getElseRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + values.push_back(ifOp->getResults()[resultIdx]); + return values; + } + return {}; +} + +LogicalResult verifyBarrierType(Operation *op, + mlir::triton::gpu::MemDescType barrierType) { + auto numCTAs = triton::gpu::lookupNumCTAs(op); + if (!(barrierType.getElementType().isInteger(64) && + barrierType.getRank() == 1 && barrierType.getShape()[0] <= numCTAs)) + return op->emitOpError("barrier allocation must be a descriptor of " + "Nxi64 type with N <= number of CTAs"); + return success(); +} + +std::optional getBoolFromConstant(Value cst) { + auto constantOp = cst.getDefiningOp(); + if (!constantOp) { + return std::nullopt; + } + assert(constantOp.getValue()); + if (auto boolAttr = dyn_cast(constantOp.getValue())) { + return boolAttr.getValue(); + } + return std::nullopt; +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp new file mode 100644 index 0000000000..406e39c0cd --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp @@ -0,0 +1,83 @@ +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUAUTOMATICWARPSPECIALIZATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +} // namespace mlir::triton::gpu + +namespace { +struct AutomaticWarpSpecialization + : triton::gpu::impl::TritonGPUAutomaticWarpSpecializationBase< + AutomaticWarpSpecialization> { + using TritonGPUAutomaticWarpSpecializationBase:: + TritonGPUAutomaticWarpSpecializationBase; + + void runOnOperation() override; +}; + +void multiBufferTMADescriptors(ModuleOp mod, int numStages) { + SetVector descUpdateLoops; + mod.walk([&](scf::ForOp loop) { + if (loop->hasAttr(kWarpSpecializeAttrName)) { + loop.walk([&](triton::MakeTensorDescOp op) { + if (auto forOp = op->getParentOfType()) { + descUpdateLoops.insert(forOp); + } + }); + } + }); + + // +1 to make sure that overlapping of the next desc update and the oldest + // inflight TMA load is safe + const int numDescs = numStages + 1; + // CoarseSchedule's notion of numStages is the maximuim loop-pipelining + // stage + 1, see CoarseSchedule::deSerialize(). So if we want n buffers, + // we need to pass n + 1 as numStages. + triton::CoarseSchedule schedule(numDescs + 1); + + for (auto loop : descUpdateLoops) { + triton::lowerTMADescriptors(loop, schedule); + } +} + +} // namespace + +void AutomaticWarpSpecialization::runOnOperation() { + OpPassManager pm; + pm.addPass(createTritonGPUPartitionScheduling()); + pm.addPass(createNVWSHoistTmemStore()); + pm.addPass(createNVWSInsertAref()); + pm.addPass(createNVWSInsertTmemAref()); + // `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic. + // FIXME: Re-enable integer range analysis once it is fixed. + // pm.addPass(arith::createIntRangeOptimizationsPass()); + pm.addPass(createSCCPPass()); + pm.addPass(createCSEPass()); + pm.addPass(createNVWSLowerAref({numStages})); + pm.addPass(createTritonGPUPartitionLoops()); + pm.addPass(createNVWSLowerWarpGroup()); + pm.addPass(createTritonGPUScheduleLoops()); + if (failed(runPipeline(pm, getOperation()))) + return signalPassFailure(); + + // Multi-buffer TMA descriptors. We cannot rely on SWP to do it, to support + // desc updates in nested loops. + multiBufferTMADescriptors(getOperation(), numStages); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp new file mode 100644 index 0000000000..a1cc0f1dd4 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp @@ -0,0 +1,318 @@ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/ScopeExit.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +namespace ttng = triton::nvidia_gpu; + +//===----------------------------------------------------------------------===// +// relayoutWarps +//===----------------------------------------------------------------------===// + +using RunPipelineFn = function_ref; + +// Take the body of a partition into a new `tt.func`. We can use this to run a +// full compiler pipeline on the partition. +static OwningOpRef takeIntoFunction(ModuleAxisInfoAnalysis &axisInfo, + Region *partition, int numWarps) { + // Forward the module attributes (target, number of threads per warp, etc.) + // onto the container module. + ModuleOp mod = axisInfo.getModuleOp(); + OwningOpRef container = ModuleOp::create(mod.getLoc()); + Block *containerBlock = container->getBody(); + + auto b = OpBuilder::atBlockBegin(containerBlock); + FunctionType funcType = b.getFunctionType(partition->getArgumentTypes(), {}); + auto containerFunc = FuncOp::create(b, mod.getLoc(), "container", funcType); + containerFunc.getBody().takeBody(*partition); + container.get()->setAttrs(mod->getAttrs()); + container.get()->setAttr(AttrNumWarpsName, b.getI32IntegerAttr(numWarps)); + + // Replace `ttg.warp_return` with `tt.return` to make the IR valid. + containerFunc.walk([&](WarpReturnOp op) { + b.setInsertionPoint(op); + ReturnOp::create(b, op.getLoc()); + op.erase(); + }); + + // This should make valid IR. + if (failed(mlir::verify(*container))) + llvm::report_fatal_error("expected partition region to make valid IR"); + + // Attach axis info properties. + auto wsOp = partition->getParentOfType(); + auto *funcInfo = + axisInfo.getFuncData(wsOp->getParentOfType()); + assert(funcInfo && "expected to find function axis info"); + for (auto [i, capture] : + llvm::enumerate(wsOp.getPartitionOp().getExplicitCaptures())) { + AxisInfo info = funcInfo->lookup(capture); + containerFunc.setArgAttr(i, "tt.contiguity", + b.getI64IntegerAttr(info.getContiguity(0))); + containerFunc.setArgAttr(i, "tt.divisibility", + b.getI64IntegerAttr(info.getDivisibility(0))); + containerFunc.setArgAttr(i, "tt.constancy", + b.getI64IntegerAttr(info.getConstancy(0))); + } + + return container; +} + +// Take the partition body out of the container module and function. +static void extractPartitionBody(OwningOpRef container, + Region *partition) { + auto containerFunc = cast(container->lookupSymbol("container")); + + // Rewrite the returns. + containerFunc.walk([](ReturnOp op) { + OpBuilder b(op); + WarpReturnOp::create(b, op.getLoc()); + op.erase(); + }); + + partition->takeBody(containerFunc.getBody()); +} + +// Reset the layouts of operations in a region and re-run layout assignment. +static LogicalResult relayoutWarps(ModuleAxisInfoAnalysis &axisInfo, + Region *partition, int prevNumWarps, + int newNumWarps, RunPipelineFn runPipeline) { + OwningOpRef container = + takeIntoFunction(axisInfo, partition, prevNumWarps); + + // Start by removing all tensor encodings. + mlir::AttrTypeReplacer replacer; + replacer.addReplacement( + [](RankedTensorType ty) { return ty.cloneWithEncoding({}); }); + // But don't remove them from the tensors inside descriptors. + replacer.addReplacement([](TensorDescType ty) -> std::pair { + return {ty, WalkResult::skip()}; + }); + replacer.recursivelyReplaceElementsIn(*container, /*replaceAttrs=*/false, + /*replaceLocs=*/false, + /*replaceTypes=*/true); + + ModuleOp mod = axisInfo.getModuleOp(); + auto target = mod->getAttrOfType(AttrTargetName); + if (!target) + return mlir::emitError(mod.getLoc(), "module missing target specification"); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + + // Enable `convert-triton-to-tritongpu` to rematerialize source layouts for + // TTG dialect operations. They will get cleared later. + OpPassManager pm; + pm.addPass( + createConvertTritonToTritonGPU({target.str(), newNumWarps, threadsPerWarp, + numCTAs, /*enableSourceRemat=*/true})); + pm.addPass(createRelayoutTritonGPU()); + if (failed(runPipeline(pm, *container))) + return failure(); + // Clear source rematerializations by propagating the source layout. + container->walk([](UnrealizedConversionCastOp op) { + op.getResult(0).replaceAllUsesWith(op.getOperand(0)); + op.erase(); + }); + + pm.clear(); + pm.addPass(createTritonGPUCoalesce()); + pm.addPass(createTritonGPURemoveLayoutConversions()); + pm.addPass(createTritonGPUOptimizeThreadLocality()); + pm.addPass(createTritonGPUAccelerateMatmul()); + pm.addPass(createTritonGPURemoveLayoutConversions()); + if (failed(runPipeline(pm, *container))) + return failure(); + + extractPartitionBody(std::move(container), partition); + return success(); +} + +//===----------------------------------------------------------------------===// +// optimizePartitionWarps +//===----------------------------------------------------------------------===// + +// Get the number of i32 registers required to store a tensor. +static unsigned getTensorNumI32Regs(RankedTensorType ty) { + unsigned numElems = getTotalElemsPerThread(ty) * + product(getThreadsPerWarp(ty)) * + product(getWarpsPerCTA(ty)); + unsigned elSize = + isa(ty.getElementType()) ? 64 : ty.getElementTypeBitWidth(); + return numElems * elSize / 32; +} + +static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo, + WarpSpecializeOp wsOp, + RunPipelineFn runPipeline) { + // Extremely rough estimate of the number of registers needed per partition. + // For each partition, get the number of i32 registers used by the largest + // tensor value. + // + // Because the partition region is isolated from above, we could in theory + // compile it to PTX and read the number of registers that got allocated. + SmallVector maxTensorRegs; + for (Region *partition : wsOp.getPartitionRegions()) { + unsigned &tensorRegs = maxTensorRegs.emplace_back(0); + partition->walk([&](Operation *op) { + for (Type type : + llvm::concat(op->getOperandTypes(), op->getResultTypes())) { + if (auto tensor = dyn_cast(type)) + tensorRegs = std::max(tensorRegs, getTensorNumI32Regs(tensor)); + } + }); + // Assume that the largest tensor accounts for half of the registers used + // by a warpgroup. + tensorRegs *= 2; + } + + // Reduce the number of warps used by partitions. For partitions with no + // tensor computations, always reduce them to 1 warp. + // + // We can't use `nvvm.setmaxnreg` because this requires a known value for + // `maxnreg` on the kernel, which is currently controlled by the frontend. + // Thus, assume PTXAS will evenly distribute the total pool of registers + // across all warps. + // + // If the compiler could control that, then we could allow non-uniform + // register distributions, mostly beneficial for single-warp warpgroups that + // just do some artihmetic. + constexpr unsigned nTotalRegs = 1 << 16; // for Blackwell SMs + const unsigned threadsPerWarp = + TritonGPUDialect::getThreadsPerWarp(axisInfo.getModuleOp()); + const unsigned defaultNumWarps = lookupNumWarps(wsOp); + + SmallVector partitionNumWarps = + llvm::to_vector(wsOp.getPartitionNumWarps()); + + // Determine if a partition has a lower limit on the number of warps. + SmallVector minWarpsForPartition(partitionNumWarps.size(), 1); + for (auto [minWarps, region] : + llvm::zip(minWarpsForPartition, wsOp.getPartitionRegions())) { + region->walk([minWarps = &minWarps](Operation *op) { + // Some instructions have critical throughput if have low register usage. + // Make sure there are enough warps for these ops to execute quickly. + if (isa(op)) + *minWarps = 2; + // TMEM ops require at least 4 warps to be able to read all lanes. + else if (isa(op)) + *minWarps = 4; + }); + } + + bool changed; + do { + changed = false; + + // Assuming even distribution of registers, given the total number of warps + // currently allocated, we can guess the number of registers PTXAS will + // distribute to each warp. + // + // For example, given 18 warps and a tensor<128x256xf32> contained in an + // 8-warp partition, we have (nTotalRegs/32/18) = ~113 regs per thread, and + // the tensor requires 128 regs per thread in its partition. In this case, + // nothing can be done. + // + // However, given a tensor<128x128xf32>, this requires only 64 regs per + // thread in 8 warps. If we reduce the size of the warp to 4, the overall + // regs per thread increases to (nTotalRegs/32/14) = ~146 regs per thread, + // while the tensor now requires 128 regs per thread. This works. + // + // The next iteration sees ~170 regs per thread, but the tensor will require + // 256, which is too many. So the algorithm stops at 4 warps. Evidently, if + // there are other partitions that can be reduced, we have to iterate this + // algorithm. + int32_t curTotalNumWarps = std::accumulate( + partitionNumWarps.begin(), partitionNumWarps.end(), defaultNumWarps); + + for (auto [minWarps, numWarps, tensorRegs] : + llvm::zip(minWarpsForPartition, partitionNumWarps, maxTensorRegs)) { + if (numWarps <= minWarps) + continue; + // Check if reducing the number of warps will still fit the tensor. If it + // didn't fit to begin with, it won't fit after shrinking. + unsigned reqRegsPerThread = tensorRegs / threadsPerWarp / (numWarps / 2); + unsigned nextTotalNumWarps = curTotalNumWarps - (numWarps / 2); + unsigned nextRegsPerThread = + nTotalRegs / threadsPerWarp / nextTotalNumWarps; + if (reqRegsPerThread <= nextRegsPerThread) { + numWarps /= 2; + changed = true; + break; + } + } + } while (changed); + + SmallVector estRegUsage(partitionNumWarps.size()); + for (auto [partition, newNumWarps, prevNumWarps, tensorRegs, estRegs] : + llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps, + wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) { + // "Guess" the register usage for each partition. + estRegs = tensorRegs ? 88 : 24; + + // Layouts need to be reassigned if the number of warps changed and there + // are tensor computations. + if (newNumWarps == prevNumWarps || !tensorRegs) + continue; + // We need to reassign layouts. + if (failed(relayoutWarps(axisInfo, partition, prevNumWarps, newNumWarps, + runPipeline))) + return failure(); + } + wsOp.setRequestedRegisters(estRegUsage); + wsOp.setPartitionNumWarps(partitionNumWarps); + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEPARTITIONWARPS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +} // namespace mlir::triton::gpu + +namespace { +struct OptimizePartitionWarps + : triton::gpu::impl::TritonGPUOptimizePartitionWarpsBase< + OptimizePartitionWarps> { + using TritonGPUOptimizePartitionWarpsBase:: + TritonGPUOptimizePartitionWarpsBase; + + void runOnOperation() override; +}; +} // namespace + +void OptimizePartitionWarps::runOnOperation() { + SmallVector wsOps; + getOperation().walk([&](WarpSpecializeOp wsOp) { wsOps.push_back(wsOp); }); + + if (wsOps.empty()) { + return; + } + + ModuleAxisInfoAnalysis axisInfo(getOperation()); + auto runPipelineFn = [&](OpPassManager &pm, ModuleOp container) { + // The module must be directly nested under the current op for `runPipeline` + // to work. + getOperation().push_back(container); + auto remove = llvm::make_scope_exit([&] { container->remove(); }); + return runPipeline(pm, container); + }; + + for (auto wsOp : wsOps) { + if (failed(optimizePartitionNumWarps(axisInfo, wsOp, runPipelineFn))) { + return signalPassFailure(); + } + } +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp new file mode 100644 index 0000000000..69ef8bf55c --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp @@ -0,0 +1,244 @@ +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Use.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +//===----------------------------------------------------------------------===// +// Partition +//===----------------------------------------------------------------------===// + +bool Partition::hasOp(Operation *op) const { + if (!hasPartition(op)) { + return false; + } + auto partitionIds = getPartitionIds(op); + return partitionIds.contains(getIndex()); +} + +void Partition::iterateInputs(scf::ForOp loop, + function_ref callback) const { + for (Operation *op : getOps()) { + visitNestedOperands(op, [&](OpOperand &operand) { + // Ignore implicit captures. + Value value = operand.get(); + std::optional> partitionIds; + if (hasPartition(value.getDefiningOp())) + partitionIds = getPartitionIds(value.getDefiningOp()); + if (value.getParentBlock() != loop.getBody()) + return; + if (auto arg = dyn_cast(value)) { + assert(arg.getOwner() == loop.getBody()); + // Ignore the induction variable. + if (arg == loop.getInductionVar()) + return; + // This value originates from a previous iteration. + assert(llvm::is_contained(loop.getRegionIterArgs(), arg)); + callback(operand); + } else if (!partitionIds || + !llvm::is_contained(*partitionIds, getIndex())) { + // This value originates from a different partition in the same + // iteration. + assert(value.getDefiningOp()->getParentOp() == loop); + callback(operand); + } + }); + } +} + +void Partition::iterateOutputs( + scf::ForOp loop, + function_ref callback) const { + for (Operation *op : getOps()) { + for (OpOperand &use : op->getUses()) { + Operation *owner = loop.getBody()->findAncestorOpInBlock(*use.getOwner()); + if (!owner) { + continue; + } + std::optional> partitionIds; + if (hasPartition(owner)) + partitionIds = getPartitionIds(owner); + if (isa(owner)) { + // This value is used in a subsequent iteration. + callback(owner, use); + } else if (!partitionIds || + !llvm::is_contained(*partitionIds, getIndex())) { + // This value is used in a different partition in the same iteration. + callback(owner, use); + } + } + } +} + +void Partition::iterateDefs( + scf::ForOp loop, function_ref callback) const { + iterateInputs(loop, [&](OpOperand &input) { + auto [def, distance] = getDefinitionAndDistance(loop, input.get()); + if (def && def.getParentBlock() == loop.getBody()) + callback(def, distance); + }); +} + +void Partition::iterateUses( + scf::ForOp loop, + function_ref callback) const { + SmallVector> uses; + iterateOutputs(loop, [&](Operation *owner, OpOperand &use) { + uses.emplace_back(cast(use.get()), &use, 0); + }); + while (!uses.empty()) { + auto [output, use, distance] = uses.pop_back_val(); + Operation *owner = loop.getBody()->findAncestorOpInBlock(*use->getOwner()); + if (!owner) { + continue; + } + if (!isa(owner)) { + callback(output, *use, distance); + continue; + } + BlockArgument arg = loop.getRegionIterArg(use->getOperandNumber()); + for (OpOperand &use : arg.getUses()) + uses.emplace_back(output, &use, distance + 1); + } +} + +//===----------------------------------------------------------------------===// +// PartitionSet +//===----------------------------------------------------------------------===// + +Partition *PartitionSet::addPartition(unsigned stage) { + partitions.push_back(std::make_unique(partitions.size(), stage)); + return partitions.back().get(); +} + +Partition *PartitionSet::getPartition(unsigned idx) { + return partitions[idx].get(); +} + +const Partition *PartitionSet::getPartition(unsigned idx) const { + return partitions[idx].get(); +} + +Partition *PartitionSet::getPartition(Operation *op) { + auto id = getPartitionIds(op); + assert(id.size() == 1); + return getPartition(id[0]); +} + +FailureOr PartitionSet::fromLoop(scf::ForOp loop) { + auto stages = loop->getAttrOfType(kPartitionStagesAttrName); + if (!stages) + return failure(); + + auto tag = loop->getAttrOfType(kWarpSpecializeTagAttrName); + if (!tag) + return failure(); + + PartitionSet result; + result.tag = tag.getInt(); + for (auto [idx, attr] : llvm::enumerate(stages)) { + auto stage = dyn_cast(attr); + if (!stage || stage.getInt() < 0) { + return mlir::emitError(loop.getLoc(), "partition stages attribute '") + << kPartitionStagesAttrName << "' has invalid element " << attr; + } + + result.partitions.push_back( + std::make_unique(idx, stage.getInt())); + } + + SmallVector annotatedOps; + loop->walk([&](Operation *op) { + if (hasPartition(op)) { + annotatedOps.push_back(op); + } + }); + + for (auto op : annotatedOps) { + auto attrs = getPartitionIds(op); + for (auto idx : attrs) { + if (idx < 0 || idx >= result.partitions.size()) + return mlir::emitError(op->getLoc(), "invalid partition index ") << idx; + result.partitions[idx]->addOp(op); + } + } + + return result; +} + +void PartitionSet::dump() const { + for (auto [i, partition] : + llvm::enumerate(llvm::make_pointee_range(partitions))) { + llvm::errs() << "=== PARTITION #" << i << " ===\n"; + for (Operation *op : partition.getOps()) { + op->print(llvm::errs(), OpPrintingFlags().skipRegions()); + llvm::errs() << "\n"; + } + llvm::errs() << "\n"; + } + llvm::errs() << "\n"; +} + +namespace mlir::triton::gpu { + +void setPartition(Operation *op, ArrayRef partitionIds) { + Builder b(op->getContext()); + auto sorted = llvm::to_vector(partitionIds); + llvm::sort(sorted); + op->setAttr(kPartitionAttrName, b.getDenseI32ArrayAttr(sorted)); + for (auto ®ion : op->getRegions()) { + for (auto &block : region.getBlocks()) { + auto terminator = block.getTerminator(); + terminator->setAttr(kPartitionAttrName, b.getDenseI32ArrayAttr(sorted)); + } + } +} + +void setPartitionOutputs(Operation *op, + ArrayRef> partitionOutputsIds) { + if (partitionOutputsIds.empty()) { + op->removeAttr(kPartitionOutputsAttrName); + return; + } + SmallVector attrs; + Builder b(op->getContext()); + for (auto partitionIds : partitionOutputsIds) { + auto sorted = llvm::to_vector(partitionIds); + llvm::sort(sorted); + attrs.push_back(b.getDenseI32ArrayAttr(sorted)); + } + op->setAttr(kPartitionOutputsAttrName, b.getArrayAttr(attrs)); +} + +void setPartition(Operation *op, const SetVector &partitionIds) { + SmallVector partitions(partitionIds.begin(), partitionIds.end()); + setPartition(op, partitions); +} + +void setPartition(Operation *op, Partition *partition) { + SmallVector partitions{partition->getIndex()}; + setPartition(op, partitions); + partition->addOp(op); +} + +void setPartition(Operation *op, const SetVector &partitions) { + SmallVector partitionIds; + for (auto partition : partitions) { + partitionIds.push_back(partition->getIndex()); + partition->addOp(op); + } + setPartition(op, partitionIds); +} + +void setWarpSpecializeTag(Operation *op, int tag) { + Builder b(op->getContext()); + op->setAttr(kWarpSpecializeTagAttrName, b.getI32IntegerAttr(tag)); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionBuilder.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionBuilder.cpp new file mode 100644 index 0000000000..8d18c1fab1 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionBuilder.cpp @@ -0,0 +1,36 @@ +#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +Value PartitionBuilder::intCst(int value, unsigned width) { + return create(value, width); +} + +Value PartitionBuilder::boolCst(bool value) { + return intCst(value, /*width=*/1); +} + +void PartitionBuilder::assignPartition(Operation *op, Partition &partition) { + setPartition(op, &partition); +} + +StageCluster triton::gpu::getStageCluster(Operation *op) { + auto stageAttr = op->getAttrOfType(kLoopStageAttrName); + auto clusterAttr = op->getAttrOfType(kLoopClusterAttrName); + if (!stageAttr || !clusterAttr) + return std::nullopt; + return std::make_pair(stageAttr.getInt(), clusterAttr.getInt()); +} + +void triton::gpu::setStageCluster(OpBuilder &b, Operation *op, + StageCluster stageCluster) { + if (stageCluster) { + op->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(stageCluster->first)); + op->setAttr(kLoopClusterAttrName, + b.getI32IntegerAttr(stageCluster->second)); + } +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp new file mode 100644 index 0000000000..a7a937e734 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp @@ -0,0 +1,545 @@ +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Dialect/NVWS/IR/Dialect.h" +#include "triton/Dialect/NVWS/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h" +#include "llvm/ADT/SCCIterator.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +namespace { + +struct WarpGroupBuilder : public OpBuilder { + WarpGroupBuilder(Block *block, Block::iterator insertPoint, + size_t partitionId) + : OpBuilder(block, insertPoint), partitionId(partitionId) {} + + IRMapping mapping; + size_t partitionId; +}; + +// This is computed per loop and partition +enum class LoopVarCategory { + // The given loop variable is not used by the given partition. For example, + // the use-D flag for MMA is only used by the MMA partition, and thus + // is `Unused` for any other partition. + Unused, + // The given loop variable is used by the given partition. For example, a loop + // index might be used to compute a relevant stage or phase value for the + // given partition. + Used, + // The results of warp_group op are defined to be those of the first + // partition. If the original loop results include a tensor which is computed + // only by a non-default partition, such tensor cannot be returned from the + // first partition and and must be passed through shared memory. The + // corresponding loop variable falls into this category. + // Recognizing this category is necessary for the first partition. For other + // partitions, some loop variables might be assigned this category, but that + // information is not used. + TensorResultFromOtherPartition, +}; + +SetVector getResultPartitionIds(Operation *op, int index) { + return getPartitionOutputs(op)[index]; +} + +SetVector getIfOpResultPartitionIds(scf::IfOp ifOp, Value value) { + for (auto result : ifOp.getResults()) { + if (result == value) { + auto pos = result.getResultNumber(); + return getResultPartitionIds(ifOp, pos); + } + } + llvm_unreachable("value is not a result of if-stmt"); +} + +bool isTensorResultComputedBy(scf::ForOp loop, size_t resultIdx, + const Partition *partition, + const PartitionSet &partitions) { + auto value = loop.getYieldedValues()[resultIdx]; + if (!isa(value.getType())) + return false; + auto defOp = value.getDefiningOp(); + auto partitionIds = getPartitionIds(defOp); + if (auto ifOp = dyn_cast(defOp)) { + partitionIds = getIfOpResultPartitionIds(ifOp, value); + } + return llvm::is_contained(partitionIds, partition->getIndex()); +} + +SmallVector classifyLoopVars(scf::ForOp loop, + const Partition *partition, + const PartitionSet &partitions) { + auto isTensorResultFromOtherPartition = [&](int i) { + for (auto otherPartition : partitions.getPartitions()) { + if (&otherPartition == partition) { + continue; + } + if (isTensorResultComputedBy(loop, i, &otherPartition, partitions)) { + return true; + } + } + return false; + }; + + SmallVector categories(loop.getNumRegionIterArgs()); + for (auto [i, arg] : llvm::enumerate(loop.getRegionIterArgs())) { + auto partitionIds = getResultPartitionIds(loop, i); + if (llvm::is_contained(partitionIds, partition->getIndex())) { + categories[i] = LoopVarCategory::Used; + } else if (isTensorResultFromOtherPartition(i) && + !loop.getResult(i).use_empty()) { + categories[i] = LoopVarCategory::TensorResultFromOtherPartition; + } else { + categories[i] = LoopVarCategory::Unused; + } + } + + return categories; +} + +std::pair, SmallVector>> +getLoopVarIndicesToKeep(scf::ForOp loop, const Partition *partition, + ArrayRef loopVarCategories) { + SmallVector indices; + // The null index means an invalid index, the corresponding loop variable in + // the original loop is removed in the cloned loop + SmallVector> reverseIndices(loop.getNumRegionIterArgs(), + std::nullopt); + for (auto [i, arg] : llvm::enumerate(loop.getRegionIterArgs())) { + if (loopVarCategories[i] == LoopVarCategory::Used) { + reverseIndices[i] = indices.size(); + indices.push_back(i); + } + } + return std::make_pair(indices, reverseIndices); +} + +std::pair, SmallVector>> +getLoopVarIndicesToKeep(scf::ForOp loop, const Partition *partition, + const PartitionSet &partitions) { + auto loopVarCategories = classifyLoopVars(loop, partition, partitions); + return getLoopVarIndicesToKeep(loop, partition, loopVarCategories); +} + +void mapRange(ValueRange fromRange, ValueRange toRange, IRMapping &mapping) { + for (auto [from, to] : llvm::zip(fromRange, toRange)) { + mapping.map(from, to); + } +} + +void cloneOpsInBlock(Block *block, SmallVector &builders, + const PartitionSet &partitions); + +void cloneForOp(scf::ForOp forOp, SmallVector &builders, + const PartitionSet &partitions) { + auto forOpPartitions = getPartitionIds(forOp); + + SmallVector newForOps; + for (int i : forOpPartitions) { + auto &b = builders[i]; + auto partition = partitions.getPartition(i); + auto [newLoopIndices, _] = + getLoopVarIndicesToKeep(forOp, partition, partitions); + auto lb = b.mapping.lookupOrDefault(forOp.getLowerBound()); + auto ub = b.mapping.lookupOrDefault(forOp.getUpperBound()); + auto step = b.mapping.lookupOrDefault(forOp.getStep()); + SmallVector initArgs; + for (auto idx : newLoopIndices) { + initArgs.push_back(b.mapping.lookupOrDefault(forOp.getInitArgs()[idx])); + } + auto newForOp = + scf::ForOp::create(b, forOp.getLoc(), lb, ub, step, initArgs); + newForOp->setAttrs(forOp->getAttrs()); + if (forOp->hasAttr(kPartitionOutputsAttrName)) { + newForOp->removeAttr(kPartitionOutputsAttrName); + } + newForOps.push_back(newForOp); + + b.mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + auto oldIterArgs = forOp.getRegionIterArgs(); + auto newIterArgs = newForOp.getRegionIterArgs(); + for (auto [newIdx, oldIdx] : llvm::enumerate(newLoopIndices)) { + b.mapping.map(oldIterArgs[oldIdx], newIterArgs[newIdx]); + b.mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx)); + } + + b.setInsertionPointToStart(newForOp.getBody()); + } + + cloneOpsInBlock(forOp.getBody(), builders, partitions); + + for (auto [i, newForOp] : llvm::zip(forOpPartitions, newForOps)) { + builders[i].setInsertionPointAfter(newForOp); + newForOp.walk([&](Operation *op) { op->removeAttr(kPartitionAttrName); }); + newForOp->removeAttr(kPartitionStagesAttrName); + } +} + +void cloneIfOp(scf::IfOp ifOp, SmallVector &builders, + const PartitionSet &partitions) { + auto partitionIndices = getPartitionIds(ifOp); + + SmallVector newIfOps; + for (size_t idx : partitionIndices) { + auto &b = builders[idx]; + auto cond = b.mapping.lookupOrDefault(ifOp.getCondition()); + SmallVector newIfResultTypes; + SmallVector newIfResultIndices; + for (auto pos = 0; pos < ifOp.getResultTypes().size(); ++pos) { + auto partitionIds = getResultPartitionIds(ifOp, pos); + if (llvm::is_contained(partitionIds, b.partitionId)) { + newIfResultTypes.push_back(ifOp.getResult(pos).getType()); + newIfResultIndices.push_back(pos); + } + } + auto newIfOp = scf::IfOp::create(b, ifOp.getLoc(), newIfResultTypes, cond, + ifOp.elseBlock() ? true : false); + newIfOp->setAttrs(ifOp->getAttrs()); + if (ifOp->hasAttr(kPartitionOutputsAttrName)) { + newIfOp->removeAttr(kPartitionOutputsAttrName); + } + newIfOps.push_back(newIfOp); + + for (auto [newIdx, oldIdx] : llvm::enumerate(newIfResultIndices)) { + b.mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + } + assert(ifOp.thenBlock()->getNumArguments() == 0); + + b.setInsertionPointToStart(newIfOp.thenBlock()); + } + + cloneOpsInBlock(ifOp.thenBlock(), builders, partitions); + + if (auto elseBlock = ifOp.elseBlock()) { + for (auto [idx, newIfOp] : llvm::zip(partitionIndices, newIfOps)) { + builders[idx].setInsertionPointToStart(newIfOp.elseBlock()); + } + cloneOpsInBlock(elseBlock, builders, partitions); + } + + for (auto [idx, newIfOp] : llvm::zip(partitionIndices, newIfOps)) { + builders[idx].setInsertionPointAfter(newIfOp); + } +} + +void cloneReduceOp(triton::ReduceOp reduceOp, + SmallVector &builders, + const PartitionSet &partitions) { + auto partitionIndices = getPartitionIds(reduceOp); + + SmallVector newReduceOps; + for (size_t idx : partitionIndices) { + auto &b = builders[idx]; + + SmallVector srcs; + for (auto src : reduceOp.getSrcs()) { + srcs.push_back(b.mapping.lookupOrDefault(src)); + } + auto axis = reduceOp.getAxis(); + auto newReduceOp = + triton::ReduceOp::create(b, reduceOp.getLoc(), srcs, axis); + newReduceOp->setAttrs(reduceOp->getAttrs()); + if (reduceOp->hasAttr(kPartitionOutputsAttrName)) { + newReduceOp->removeAttr(kPartitionOutputsAttrName); + } + newReduceOps.push_back(newReduceOp); + + mapRange(reduceOp.getResults(), newReduceOp.getResults(), b.mapping); + + auto ®ion = newReduceOp.getRegion(); + Block *block = ®ion.emplaceBlock(); + for (auto arg : reduceOp.getRegion().getBlocks().front().getArguments()) { + auto newArg = block->addArgument(arg.getType(), arg.getLoc()); + b.mapping.map(arg, newArg); + } + + b.setInsertionPointToStart(block); + } + + cloneOpsInBlock(reduceOp.getBody(), builders, partitions); + + for (auto [idx, newReduceOp] : llvm::zip(partitionIndices, newReduceOps)) { + builders[idx].setInsertionPointAfter(newReduceOp); + } +} + +void cloneOp(Operation *op, SmallVector &builders, + const SetVector &partitionIndices) { + if (op->getNumRegions() != 0) { + llvm::report_fatal_error( + "Ops are expected to be regionless at this point."); + } + + for (size_t idx : partitionIndices) { + auto &builder = builders[idx]; + auto newOp = builder.clone(*op, builder.mapping); + mapRange(op->getResults(), newOp->getResults(), builder.mapping); + } +} + +void cloneOpsInBlock(Block *block, SmallVector &builders, + const PartitionSet &partitions) { + for (auto &op_ : *block) { + auto op = &op_; + + if (auto forOp = dyn_cast(op)) { + cloneForOp(forOp, builders, partitions); + } else if (auto ifOp = dyn_cast(op)) { + cloneIfOp(ifOp, builders, partitions); + } else if (auto reduceOp = dyn_cast(op)) { + cloneReduceOp(reduceOp, builders, partitions); + } else if (auto yieldOp = dyn_cast(op)) { + if (yieldOp.getOperands().empty()) { + continue; + } + // empty yield has no partition annotations + assert(hasPartition(op)); + auto partitionIndices = getPartitionIds(op); + + for (size_t idx : partitionIndices) { + auto &builder = builders[idx]; + SmallVector newOperandIndices; + if (auto forOp = dyn_cast(yieldOp->getParentOp())) { + newOperandIndices = + getLoopVarIndicesToKeep( + forOp, partitions.getPartition(builder.partitionId), + partitions) + .first; + } else { + auto ifOp = cast(yieldOp->getParentOp()); + for (size_t i = 0; i < yieldOp.getOperands().size(); ++i) { + auto ids = getResultPartitionIds(ifOp, i); + if (llvm::is_contained(ids, builder.partitionId)) { + newOperandIndices.push_back(i); + } + } + } + + if (newOperandIndices.empty()) + continue; + + SmallVector newYieldOperands; + for (size_t i : newOperandIndices) { + newYieldOperands.push_back( + builder.mapping.lookupOrDefault(yieldOp.getOperand(i))); + } + + scf::YieldOp::create(builder, op->getLoc(), newYieldOperands); + } + } else { + assert(hasPartition(op)); + auto partitionIndices = getPartitionIds(op); + cloneOp(op, builders, partitionIndices); + } + } +} + +} // namespace + +LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) { + FailureOr partitionsOr = PartitionSet::fromLoop(loop); + if (failed(partitionsOr)) + return failure(); + PartitionSet partitions = std::move(*partitionsOr); + + // Only the root node should have consumers at this point. + for (const Partition &partition : partitions.getPartitions()) { + bool failed = false; + auto callback = [&](OpResult output, OpOperand &use, unsigned distance) { + auto partitionIds = getPartitionIds(use.getOwner()); + if (llvm::is_contained(partitionIds, partition.getIndex())) + return; + + // check if consumer partition set is a subset of the producer partitions + auto defOpPartitionIds = getPartitionIds(output.getDefiningOp()); + bool isValidSubset = std::all_of( + partitionIds.begin(), partitionIds.end(), [&](int consumerId) { + return llvm::is_contained(defOpPartitionIds, consumerId); + }); + + if (isValidSubset) + return; // Valid: consumer ⊆ producer + + failed = true; + InFlightDiagnostic diag = + mlir::emitWarning(output.getLoc(), "non-root partition #") + << partition.getIndex() << " has direct SSA consumer"; + + for (auto partitionId : partitionIds) { + diag.attachNote(use.getOwner()->getLoc()) + << "use at distance " << distance << " in partition #" + << partitionId << " here"; + } + }; + partition.iterateUses(loop, callback); + if (failed) + return failure(); + } + + // There is nothing to do if the loop has 1 or fewer partitions. + if (llvm::size(partitions.getPartitions()) <= 1) + return success(); + + auto numPartitions = partitions.getNumPartitions(); + auto defaultPartition = partitions.getPartition((int)0); + auto loopVarCategories = classifyLoopVars(loop, defaultPartition, partitions); + auto [loopVarIndices, newResultIndices] = + getLoopVarIndicesToKeep(loop, defaultPartition, loopVarCategories); + + ImplicitLocOpBuilder topBuilder(loop.getLoc(), loop); + SmallVector tensorResultAllocs(loop.getNumRegionIterArgs()); + for (auto [i, res] : llvm::enumerate(loop.getResults())) { + if (loopVarCategories[i] == + LoopVarCategory::TensorResultFromOtherPartition) { + auto ty = cast(res.getType()); + auto memdesc = MemDescType::get( + ty.getShape(), ty.getElementType(), getSharedEncoding(ty), + SharedMemorySpaceAttr::get(ty.getContext()), /*mutable=*/true); + tensorResultAllocs[i] = LocalAllocOp::create(topBuilder, memdesc); + } + } + + SmallVector resultTypes; + for (auto i : loopVarIndices) { + resultTypes.push_back(loop.getResultTypes()[i]); + } + + SmallVector numWarps(numPartitions, lookupNumWarps(loop)); + auto wgOp = nvws::WarpGroupOp::create(topBuilder, resultTypes, numWarps, + numPartitions); + + SmallVector builders; + for (Region ®ion : wgOp.getPartitionRegions()) { + auto partitionId = builders.size(); + auto &block = region.emplaceBlock(); + builders.push_back(WarpGroupBuilder(&block, block.end(), partitionId)); + } + + SmallVector opsToErase; + for (auto &op_ : *loop->getBlock()) { + auto op = &op_; + if (!hasPartition(op)) + continue; + assert(hasWarpSpecializeTag(op)); + if (*getWarpSpecializeTag(op) != partitions.getTag()) + continue; + if (op == loop) { + cloneForOp(loop, builders, partitions); + opsToErase.push_back(loop); + } else { + cloneOp(op, builders, getPartitionIds(op)); + opsToErase.push_back(op); + } + } + + for (auto [b, region, partition] : llvm::zip( + builders, wgOp.getPartitionRegions(), partitions.getPartitions())) { + if (!llvm::is_contained(getPartitionIds(loop), b.partitionId)) { + nvws::WarpGroupYieldOp::create(b, wgOp.getLoc(), SmallVector{}); + continue; + } + auto newForOp = *region.front().getOps().begin(); + auto outputs = newForOp.getResults(); + + if (b.partitionId == 0) { + nvws::WarpGroupYieldOp::create(b, wgOp.getLoc(), outputs); + } else { + // Tensor results computed by non-default partitions are communicated back + // via SMEM. + // The calls to getLoopVarIndicesToKeep and isTensorResultComputedBy + // below are unnecessary if we can encode the partition index and the + // corresponding result tensor index of newForOp in + // LoopVarCategory::TensorResultFromOtherPartition. In the absence of such + // language support, we end up computing the same information multiple + // times. + auto [_, reverseIndices] = + getLoopVarIndicesToKeep(loop, &partition, partitions); + for (size_t i = 0; i < loop.getNumRegionIterArgs(); ++i) { + if (loopVarCategories[i] == + LoopVarCategory::TensorResultFromOtherPartition && + isTensorResultComputedBy(loop, i, &partition, partitions)) { + assert(reverseIndices[i] && "A valid index is expected."); + auto result = newForOp.getResult(*reverseIndices[i]); + LocalStoreOp::create(b, wgOp.getLoc(), result, tensorResultAllocs[i]); + } + } + nvws::WarpGroupReturnOp::create(b, wgOp.getLoc()); + } + } + + topBuilder.setInsertionPointAfter(wgOp); + + for (auto [i, res] : llvm::enumerate(loop.getResults())) { + if (res.use_empty()) + continue; + + if (loopVarCategories[i] == + LoopVarCategory::TensorResultFromOtherPartition) { + auto ty = cast(loop.getResult(i).getType()); + auto output = LocalLoadOp::create(topBuilder, ty, tensorResultAllocs[i]); + LocalDeallocOp::create(topBuilder, tensorResultAllocs[i]); + res.replaceAllUsesWith(output); + } else if (llvm::any_of(res.getUsers(), [&](Operation *user) { + return !hasPartition(user) || + (isa(user) && hasWarpSpecializeTag(user)); + })) { + // If some users are in the root partition (no partition attribute) or + // used by another warp-specialized loop, we need to replace their uses + // with the corresponding result from the warp group operation + assert(newResultIndices[i] && "A valid index is expected."); + res.replaceAllUsesWith(wgOp.getResult(*newResultIndices[i])); + } + } + + for (auto op : llvm::reverse(opsToErase)) + op->erase(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUPARTITIONLOOPS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +} // namespace mlir::triton::gpu + +namespace { +struct PartitionLoops + : triton::gpu::impl::TritonGPUPartitionLoopsBase { + using TritonGPUPartitionLoopsBase::TritonGPUPartitionLoopsBase; + + void runOnOperation() override; +}; +} // namespace + +void PartitionLoops::runOnOperation() { + // Collect for loops to warp specialize. This pass expects the loop to already + // be annotated with partitions. + SmallVector loops; + getOperation().walk([&](scf::ForOp loop) { + if (loop->hasAttrOfType(kPartitionStagesAttrName)) + loops.push_back(loop); + }); + + for (scf::ForOp loop : loops) { + if (failed(partitionLoop(loop))) + return signalPassFailure(); + } +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp new file mode 100644 index 0000000000..3b2de017b5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp @@ -0,0 +1,1605 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Partition.h" +#include "triton/Dialect/TritonGPU/Transforms/PartitionSchedulingUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +namespace mlir::triton::gpu { + +// This pass assigns partitions to ops within each warp specialized loop. +// +// Ops are first categorized as either "data" ops (which operate on tiles of +// data, for example load/store/mma ops) or "non-data" ops (for example index +// calculations). +// +// A dataflow graph representation of the program is constructed: every edge in +// the graph represents an MLIR value, and every node represents an MLIR +// operation or block argument. +// +// Initially all nodes for "data" ops are assigned to a new partition. A set of +// heuristics is then applied to every edge that crosses partitions (connects a +// pair of nodes assigned to different partitions). When a heuristic matches, +// the two partitions are merged into a single partition. This is done up until +// a fixed point is reached. A second set of heuristics is run on every +// pair of partitions, merging them until a fixed point is reached. +// +// After the heuristics have been applied, all data ops are assigned to a +// single partition. These partition assignments are then propagated to all +// "non-data" ops. This pulls all of the necessary index calculations etc. into +// the partitions that require them (possibly multiple). +// +// Finally the partition assignments in the dataflow graph are serialized to +// attributes, and the temporary data structure is discarded. + +#define GEN_PASS_DEF_TRITONGPUPARTITIONSCHEDULING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-partition-scheduling" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +using namespace mlir; +using namespace triton; +using namespace partition_scheduling_detail; + +namespace tt = triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +using Partition = partition_scheduling_detail::Partition; // resolve ambiguity + +template bool node_isa(Node *node) { + return node->isOp() && isa(node->getOp()); +} + +std::unique_ptr buildGraph(Operation *region) { + DenseMap nodes; + DenseMap, InputPort> operands; + SmallVector> values; + + std::function visitOperation = + [&](Node *graph, Operation *op) { + if (auto funcOp = dyn_cast(op)) { + auto node = graph->addNode(op, 0, 0); + nodes[op] = node; + for (size_t idx = 0; idx < funcOp.getNumArguments(); idx++) { + auto argNode = node->addNode(funcOp.getArgument(idx), 0, 1); + values.push_back(std::make_pair(OutputPort(argNode, 0), + funcOp.getArgument(idx))); + } + for (auto ®ion : op->getRegions()) + for (auto &block : region) + for (auto &op : block) + visitOperation(node, &op); + + } else if (auto forOp = dyn_cast(op)) { + auto node = graph->addNode(op, 3, 0); + nodes[op] = node; + + // lb / ub / step + operands[std::make_pair(op, 0)] = InputPort(node, 0); + operands[std::make_pair(op, 1)] = InputPort(node, 1); + operands[std::make_pair(op, 2)] = InputPort(node, 2); + + // iter args / results + auto ind_var = node->addNode(forOp.getInductionVar(), 0, 1); + node->addDefines(ind_var); + values.push_back( + std::make_pair(OutputPort(ind_var, 0), forOp.getInductionVar())); + size_t idx = 0; + for (auto iter_arg : forOp.getRegionIterArgs()) { + auto iter_arg_node = node->addNode(iter_arg, 2, 1); + node->addDefines(iter_arg_node); + values.push_back( + std::make_pair(OutputPort(iter_arg_node, 0), iter_arg)); + values.push_back(std::make_pair(OutputPort(iter_arg_node, 0), + forOp.getResult(idx))); + idx++; + } + + // init iter args + { + size_t idx = 0; + for (auto operand : forOp.getInitArgs()) { + auto iter_arg_node = node->getDefines()[idx + 1]; + operands[std::make_pair(op, idx + 3)] = + InputPort(iter_arg_node, 0); + idx++; + } + } + + for (auto ®ion : op->getRegions()) + for (auto &block : region) + for (auto &op : block) + visitOperation(node, &op); + + } else if (auto ifOp = dyn_cast(op)) { + auto node = graph->addNode(op, 1, 0); + nodes[op] = node; + + // cond + operands[std::make_pair(op, 0)] = InputPort(node, 0); + + // results + for (auto result : ifOp.getResults()) { + auto result_node = node->addNode(result, 2, 1); + node->addDefines(result_node); + values.push_back( + std::make_pair(OutputPort(result_node, 0), result)); + } + + for (auto ®ion : op->getRegions()) + for (auto &block : region) + for (auto &op : block) + visitOperation(node, &op); + + } else if (auto reduceOp = dyn_cast(op)) { + + auto node = graph->addNode(op, 1, 1); + nodes[op] = node; + + // input + operands[std::make_pair(op, 0)] = InputPort(node, 0); + + // result + assert(reduceOp.getResults().size() == 1); + auto result = reduceOp.getResults().front(); + values.push_back(std::make_pair(OutputPort(node, 0), result)); + + for (auto ®ion : op->getRegions()) + for (auto &block : region) + for (auto &op : block) + visitOperation(node, &op); + + } else if (isa(op)) { + + if (auto forOp = dyn_cast(op->getParentOp())) { + // map operands to yield in a for op to the iter arg nodes + auto for_node = nodes[op->getParentOp()]; + for (size_t idx = 0; idx < op->getNumOperands(); idx++) { + auto block_arg_node = + for_node->getDefines()[idx + 1]; // skip iter arg + operands[std::make_pair(op, idx)] = InputPort(block_arg_node, 1); + } + + } else if (auto ifOp = dyn_cast(op->getParentOp())) { + // map operands to yield in an if op to the if results + auto if_node = nodes[op->getParentOp()]; + for (size_t idx = 0; idx < op->getNumOperands(); idx++) { + auto result_node = if_node->getDefines()[idx]; + operands[std::make_pair(op, idx)] = InputPort( + result_node, + (op->getParentRegion() == &ifOp.getThenRegion()) ? 0 : 1); + } + } else { + assert(false && "unsupported"); + } + + } else if (isa(op)) { + // omit + + } else { + auto node = + graph->addNode(op, op->getNumOperands(), op->getNumResults()); + nodes[op] = node; + for (size_t idx = 0; idx < op->getNumOperands(); idx++) + operands[std::make_pair(op, idx)] = InputPort(node, idx); + for (const auto &result : op->getResults()) + values.push_back(std::make_pair( + OutputPort(node, result.getResultNumber()), result)); + } + }; + + auto graph = std::make_unique(region); + visitOperation(graph->getRoot(), region); + + for (auto [outputPort, value] : values) { + for (auto &use : value.getUses()) { + auto op = use.getOwner(); + auto key = std::make_pair(op, use.getOperandNumber()); + if (operands.find(key) != operands.end()) { + auto inputPort = operands[key]; + Node::addEdge(outputPort, inputPort); + } + } + } + + return graph; +} + +SmallVector initialDataValues(Graph *graph) { + SmallVector values; + graph->walk([&](Node *node) { + if (node->isOp()) { + auto op = node->getOp(); + if (isa(op)) { + node->setDataValue(0); + values.push_back({node, 0}); + } + if (isa(op)) { + node->setDataValue(0); + values.push_back({node, 0}); + node->setDataValue(1); + values.push_back({node, 1}); + } + if (isa(op)) { + node->setDataValue(0); + values.push_back({node, 0}); + } + // if it is manually tagged with data attribute, + // all outputs are treated as data values + if (op->hasAttr("data")) { + for (size_t i = 0; i < node->getNumOutputs(); i++) { + node->setDataValue(i); + values.push_back({node, i}); + } + } + } + }); + return values; +} + +void propagateDataValues(const SmallVector &values) { + SmallVector stack = values; + DenseSet seen; + seen.insert(values.begin(), values.end()); + + auto add = [&](OutputPort value) { + value.getNode()->setDataValue(value.getIdx()); + if (seen.find(value) == seen.end()) { + stack.push_back(value); + seen.insert(value); + } + }; + + while (!stack.empty()) { + auto value = stack.back(); + stack.pop_back(); + for (auto use : value.getNode()->getOutputsFromPort(value.getIdx())) { + auto use_node = use.getNode(); + for (size_t idx = 0; idx < use_node->getNumOutputs(); idx++) { + OutputPort new_value{use_node, idx}; + add(new_value); + } + } + } +} + +void initialPartitionAssignment(Graph *graph) { + graph->walk([&](Node *node) { + if (node->isData() && !node->hasPartition()) { + auto partition = graph->addPartition(); + node->setPartition(partition); + } + }); +} + +SmallVector getCrossingEdges(Graph *graph) { + SmallVector edges; + for (auto &partition : graph->getPartitions()) + for (auto node : partition->getNodes()) + for (auto edge : node->getOutEdges()) { + if (!edge.crossesPartitions()) + continue; + edges.push_back(edge); + } + return edges; +} + +SmallVector getOutCrossingEdges(Partition *partition) { + SmallVector edges; + for (auto node : partition->getNodes()) + for (auto edge : node->getOutEdges()) { + if (!edge.crossesPartitions()) + continue; + edges.push_back(edge); + } + return edges; +} + +void deserializeManualPartitions(Operation *region, Graph *graph) { + std::map manual_partitions; + graph->walk([&](Node *node) { + if (node->isOp()) { + auto op = node->getOp(); + if (op->hasAttr(kPartitionAttrName)) { + auto partitionIds = + cast(op->getAttr(kPartitionAttrName)) + .asArrayRef(); + for (auto id : partitionIds) { + if (manual_partitions.find(id) == manual_partitions.end()) { + auto partition = graph->addPartition(); + partition->addFlag(Flags::MANUAL); + manual_partitions[id] = partition; + LLVM_DEBUG({ + llvm::errs() << "deserialize manual partition:"; + partition->dump(); + }); + } + node->addPartition(manual_partitions[id]); + } + } + } + }); +} + +bool isNone(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags == Flags::NONE || flags == Flags::MANUAL; +} + +bool isOnlyNone(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags == Flags::NONE; +} + +bool isView(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::VIEW; +} + +bool isManual(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::MANUAL; +} + +bool isLoad(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::LOAD; +} + +bool isStore(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::STORE; +} + +bool isMMA(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::MMA; +} + +bool isTMEM(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::TMEM; +} + +bool isSFU(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return flags & Flags::SFU; +} + +bool isCostlySFU(Node *node) { + auto partition = node->getPartition(); + auto flags = partition->getFlags(); + return (flags & Flags::SFU) && partition->getCost() > 256; +} + +bool isForIterArg(Node *node) { + if (node->isOp()) + return false; + auto blockArg = dyn_cast(node->getValue()); + if (!blockArg) + return false; + return isa(blockArg.getOwner()->getParentOp()); +} + +bool isIfResult(Node *node) { + if (node->isOp()) + return false; + auto result = dyn_cast(node->getValue()); + if (!result) + return false; + return isa(result.getOwner()); +} + +SmallVector>> heuristics = { + // load followed by local alloc in same partition + {"load_local_alloc", + [](Edge edge) { + if (!node_isa(edge.getToNode())) { + return false; + } + + if (node_isa( + edge.getFromNode())) { + // require layouts to match for TMA load + alloc + auto load = edge.getFromNode()->getOp(); + auto alloc = cast(edge.getToNode()->getOp()); + return getSharedEncoding(load) == alloc.getType().getEncoding(); + } + + if (node_isa(edge.getFromNode())) { + return true; + } + + return false; + }}, + + // sequence of view ops in same partition + // Note: view ops guaranteed to have been duplicated so there + // is one use/def for each + {"view_sequence", + [](Edge edge) { + auto from = getNodeFlags(edge.getFromNode()); + auto to = getNodeFlags(edge.getToNode()); + return (from & Flags::VIEW) && (to & Flags::VIEW); + }}, + + // merge view op partition with producer if it involves fewer + // elements than merging with the consumer of the view partition + {"view_producer", + [](Edge edge) { + if (!isView(edge.getToNode())) { + return false; + } + auto from = getNodeFlags(edge.getFromNode()); + auto to = getNodeFlags(edge.getToNode()); + if (!(to & Flags::VIEW)) { + return false; + } + + auto view_partition = edge.getToNode()->getPartition(); + auto out_edges = getOutCrossingEdges(view_partition); + assert(out_edges.size() == 1); + auto out_edge = out_edges[0]; + + auto in_size = edge.getSize(); + auto out_size = out_edge.getSize(); + + return in_size > out_size; + }}, + + // merge remaining view op partitions with consumer + // as that involves fewer elements being communicated via aref + {"view_consumer", + [](Edge edge) { + if (!isView(edge.getFromNode())) { + return false; + } + auto from = getNodeFlags(edge.getFromNode()); + auto to = getNodeFlags(edge.getToNode()); + if (!(from & Flags::VIEW)) { + return false; + } + return true; + }}, + + // for op iter arg placed in same partition as op that produces + // its value in the loop body (if it is not a token) + {"for_op_iter_arg", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + if (from->getParent() != to->getParent()) + // skip if not both in the loop body + return false; + if (!isForIterArg(to)) + // skip is not to an iter arg + return false; + if (isa(to->getValue().getType())) + // skip if a token type + return false; + return true; + }}, + + // for op iter arg placed in same partition as op that consumes + // its value (if it is a token) + {"for_op_iter_arg_token", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + if (!isForIterArg(from)) + // skip if not from an iter arg + return false; + if (!isa(from->getValue().getType())) + // skip if not a token + return false; + return true; + }}, + + // if op result placed in same partition as MMA op that produces it (if it + // is a token) + {"if_op_result_token", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + if (!isMMA(from)) { + // skip if not from an MMA + } + if (!isIfResult(to)) + // skip if not to an if op result + return false; + if (!isa(to->getValue().getType())) + // skip if not a token + return false; + return true; + }}, + + // merge expensive SFU ops with their dependencies (except MMA, STORE and + // other SFU) + {"sfu_consumer", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return isCostlySFU(to) && !isMMA(from) && !isLoad(from) && !isSFU(from); + }}, + + // straight sequence of NONE ops merges together + {"sequence", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + if (from->getNumOutDataEdges() > 1 || to->getNumInDataEdges() > 1) + return false; + return isNone(from) && isNone(to); + }}, + + // straight sequence of NONE op to SFU op merges together + {"sequence_sfu", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + if (from->getNumOutDataEdges() > 1 || to->getNumInDataEdges() > 1) + return false; + return isNone(from) && isSFU(to); + }}, + + // TMEM load merges with consumer + // FIXME: limit to single consumer? + {"tmem_load", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return node_isa(from); + }}, + + // TMEM and STORE groups merge + {"tmem_store", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return isTMEM(from) && isStore(to); + }}, + + // NONE/cheap SFU merges with consumer (except LOAD, MMA or costly SFU) + {"none_consumer", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return (isNone(from) || (isSFU(from) && !isCostlySFU(from))) && + !isNone(to) && !isMMA(to) && !isLoad(to) && !isCostlySFU(to); + }}, + + // NONE merges with costly producer (except LOAD or MMA) + // This will prefer to merge NONE nodes into costly groups, rather than + // non-costly groups + // e.g. in the two SFU groups of attention kernels + {"none_producer_costly", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return isNone(to) && !isNone(from) && !isMMA(from) && !isLoad(from) && + from->getPartition()->getCost() > 256; + }}, + + // NONE merges with producer (except LOAD or MMA) + {"none_producer", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return isNone(to) && !isNone(from) && !isMMA(from) && !isLoad(from); + }}, + + // merge connected STORE partitions together + // these are both using tt.descriptor_store and have a dataflow edge + // between, so avoid communicating between partitions via aref + {"connected_store", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return isStore(from) && isStore(to); + }}, + + // merge connected NONE partitions together + {"connected_none", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return isOnlyNone(from) && isOnlyNone(to); + }}, + + // merge connected NONE and MANUAL partitions together + {"connected_none_manual", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return (isOnlyNone(from) && isManual(to)) || + (isOnlyNone(to) && isManual(from)); + }}, + + // merge connected partitions together if edge between is expensive + // TODO: this might be better expressed as a horizontal rule, + // that aims to keep shmem usage under the limit + {"connected", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return !isLoad(from) && !isLoad(to) && !isMMA(from) && !isMMA(to) && + edge.getSize() > 16384; // FIXME: seemingly arbitrary size... + }}, + + // store group not used by an mma/dot op should be merged + {"load_epilog", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + if (!isLoad(from)) + return false; + + SmallVector stack; + DenseSet seen; + stack.push_back(from); + seen.insert(from); + + while (!stack.empty()) { + auto node = stack.back(); + stack.pop_back(); + if (isMMA(node) || (node->isOp() && isa(node->getOp()))) { + return false; + } else { + for (auto edge : node->getOutEdges()) { + if (!seen.contains(edge.getToNode())) { + stack.push_back(edge.getToNode()); + seen.insert(edge.getToNode()); + } + } + } + } + + return true; + }}, +}; + +SmallVector>> constraints = { + // don't merge manual partitions + {"manual", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return !(isManual(from) && isManual(to)); + }}, + + // don't merge partitions with tmem ops into mma partitions + {"tmem_mma", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return !((isMMA(from) && isTMEM(to)) || (isMMA(to) && isTMEM(from))); + }}, + + // don't merge tmem alloc (non-token form) into mma partition + {"tmem_alloc", + [](Edge edge) { + auto from = edge.getFromNode(); + auto to = edge.getToNode(); + return !(node_isa(from) && isMMA(to)); + }}, +}; + +DenseSet getTMEMAllocs(Partition *partition) { + // look for all tmem allocs used by the partition + DenseSet result; + for (auto node : partition->getNodes()) { + if (!node->isOp()) + continue; + Operation *alloc = nullptr; + if (auto load = dyn_cast(node->getOp())) { + alloc = load.getOperand(0).getDefiningOp(); + } + if (auto store = dyn_cast(node->getOp())) { + alloc = store.getOperand(0).getDefiningOp(); + } + if (alloc) { + assert(isa(alloc)); + result.insert(alloc); + } + } + return result; +} + +SmallVector< + std::pair>> + partition_heuristics = { + // merge mma partitions + {"mma", + [](Partition *a, Partition *b) { + auto a_is_mma = (a->getFlags() == Flags::MMA); + auto b_is_mma = (b->getFlags() == Flags::MMA); + return a_is_mma && b_is_mma; + }}, + + // merge load partitions + {"load", + [](Partition *a, Partition *b) { + auto a_is_load = (a->getFlags() == Flags::LOAD); + auto b_is_load = (b->getFlags() == Flags::LOAD); + return a_is_load && b_is_load; + }}, + + // merge none with store partitions + {"none", + [](Partition *a, Partition *b) { + auto a_is_none = (a->getFlags() == Flags::NONE); + auto b_is_none = (b->getFlags() == Flags::NONE); + auto a_is_store = (a->getFlags() & Flags::STORE); + auto b_is_store = (b->getFlags() & Flags::STORE); + return (a_is_none && b_is_store) || (a_is_store && b_is_none); + }}, + + // merge TMEM partitions together, if they use the same tmem alloc + // aref does not support tmem with more than 2 partitions + // and the tmem_alloc'd memory can maximally be used by an MMA + // partition and a TMEM partition + {"tmem", + [](Partition *a, Partition *b) { + auto a_is_tmem = (a->getFlags() & Flags::TMEM); + auto b_is_tmem = (b->getFlags() & Flags::TMEM); + if (!a_is_tmem || !b_is_tmem) + return false; + auto allocs_a = getTMEMAllocs(a); + auto allocs_b = getTMEMAllocs(b); + // if the sets are overlapping, alloc is used by both TMEM partitions + for (auto alloc_a : allocs_a) + if (allocs_b.contains(alloc_a)) + return true; + return false; + }}, +}; + +void mergePartitions(Graph *graph, std::string funcName, + VisualizationInfo &vis_info) { + LLVM_DEBUG({ llvm::errs() << "#### applying heuristics...\n"; }); + + // initial worklist is list of all edges that cross partitions + auto crossingEdges = getCrossingEdges(graph); + bool changed = false; + do { + changed = false; + LLVM_DEBUG({ + llvm::errs() << "\n" + << crossingEdges.size() << " crossing edges remaining\n"; + }); + + for (auto [name, apply] : heuristics) { + for (auto it = crossingEdges.begin(); it != crossingEdges.end();) { + auto edge = *it; + + // remove edges that no longer cross partitions from the worklist + if (!edge.crossesPartitions()) { + it = crossingEdges.erase(it); + continue; + } + + if (apply(edge)) { + // check if applying the heuristic will satisfy the constraints + bool ok = true; + for (auto [name, constraint] : constraints) { + if (!constraint(edge)) { + ok = false; + break; + } + } + if (!ok) { + it++; + continue; + } + + LLVM_DEBUG({ + llvm::dbgs() << "\napply heuristic \"" << name << "\"\n"; + llvm::dbgs() << edge.getFromNode()->getLabel() << " -> " + << edge.getToNode()->getLabel() << "\n"; + llvm::dbgs() << "partitions " << edge.getFromNode()->getPartition() + << " -> " << edge.getToNode()->getPartition() << "\n"; + llvm::dbgs() << "flags " + << edge.getFromNode()->getPartition()->getFlags() + << " -> " + << edge.getToNode()->getPartition()->getFlags() + << "\n"; + }); + + // merge the partitions + auto from_partition = edge.getFromNode()->getPartition(); + auto to_partition = edge.getToNode()->getPartition(); + Partition::merge(from_partition, to_partition); + + visualize(funcName, "merge-step", std::string("merge: rule ") + name, + graph, vis_info); + crossingEdges.erase(it); + + changed = true; + break; + } + + it++; + } + if (changed) + break; + } + } while (changed); + + visualize(funcName, "merge-step", "edge based merge complete", graph, + vis_info); + + { + // look at every pair of partitions and check if they should be merged + auto merge_partitions_step = [&]() { + SmallVector all_partitions; + for (auto partition : graph->getPartitions()) + all_partitions.push_back(partition); + for (auto [name, apply] : partition_heuristics) { + for (auto partitionA : all_partitions) { + for (auto partitionB : all_partitions) { + if (partitionA == partitionB) + continue; + if (apply(partitionA, partitionB)) { + LLVM_DEBUG({ + llvm::errs() << "\nmerge \"" << name << "\" ----\n"; + partitionA->dump(); + partitionB->dump(); + }); + Partition::merge(partitionA, partitionB); + visualize(funcName, "merge-step", + std::string("merge: rule ") + name, graph, vis_info); + return false; + } + } + } + } + return true; + }; + + while (true) { + if (merge_partitions_step()) + break; + } + } + + visualize(funcName, "merge-step", "partition based merge complete", graph, + vis_info); + + LLVM_DEBUG({ llvm::errs() << "\n#### heuristics done\n"; }); +} + +void propagatePartitions(Graph *graph, std::string funcName, + VisualizationInfo &vis_info) { + visualize(funcName, "propagate", "before propagate", graph, vis_info); + + // propagate partitions to parent ops + SmallVector leaves; + + graph->walk([&](Node *node) { + // node is a leaf if it has a region, + // and none of the ops in the region are leaves + bool is_leaf = !node->getNodes().empty(); + for (auto &child : node->getNodes()) { + if (!child->getNodes().empty()) { + is_leaf = false; + break; + } + } + if (is_leaf) + leaves.push_back(node); + }); + + bool changed = true; + while (changed) { + for (auto leaf : leaves) { + // partitions for leaf are union of partitions of all ops contained in + // the leaf + SetVector partitions; + for (auto &node : leaf->getNodes()) + partitions.insert(node->getPartitions().begin(), + node->getPartitions().end()); + leaf->addPartitions(partitions); + + // propagate to parent nodes + auto node = leaf->getParent(); + while (node) { + // include union of partitions of ops in the parent + for (auto &child : node->getNodes()) + partitions.insert(child->getPartitions().begin(), + child->getPartitions().end()); + node->addPartitions(partitions); + node = node->getParent(); + } + } + + // propagate partitions to non-data nodes + { + SmallVector nodes; + // include nodes with regions + graph->walk([&](Node *node) { + if (!node->getNodes().empty()) + nodes.push_back(node); + }); + // include data nodes + for (auto &partition : graph->getPartitions()) + for (auto &node : partition->getNodes()) + if (node->isData()) + nodes.push_back(node); + + changed = false; + for (auto node : nodes) { + SmallVector stack; + DenseSet seen; + auto partitions = node->getPartitions(); + stack.push_back(node); + seen.insert(node); + + while (!stack.empty()) { + auto node = stack.back(); + stack.pop_back(); + + auto propagate = [&](Edge edge, Node *node) { + if (!node || node->isData()) + return; + auto numPartitionsBefore = node->getPartitions().size(); + node->addPartitions(partitions); + auto numPartitionsAfter = node->getPartitions().size(); + changed |= (numPartitionsBefore != numPartitionsAfter); + if (seen.count(node) == 0) { + stack.push_back(node); + seen.insert(node); + } + }; + + for (auto edge : node->getInEdges()) + propagate(edge, edge.getFromNode()); + } + } + } + } + + visualize(funcName, "propagate", "after propagate", graph, vis_info); + + // propagate partitions to non-data nodes (forward) + { + SmallVector nodes; + // get nodes that have no partition assigned + graph->walk([&](Node *node) { + if (!node->hasPartition()) + nodes.push_back(node); + }); + + changed = false; + while (!nodes.empty()) { + // try propagating partitions forward to nodes with no partition + int start_size = nodes.size(); + bool changed = false; + for (auto node : nodes) { + for (auto edge : node->getInEdges()) { + if (!edge.getFromNode()) + continue; + if (edge.getFromNode()->hasPartition()) { + for (auto partition : edge.getFromNode()->getPartitions()) + node->setPartition(partition); + changed = true; + } + } + } + // remove all nodes that now have a partition + nodes.erase( + std::remove_if(nodes.begin(), nodes.end(), + [](Node *node) { return node->hasPartition(); }), + nodes.end()); + int end_size = nodes.size(); + if (start_size == end_size) { + // no change -> exit + break; + } + } + } + + visualize(funcName, "propagate", "propagate forward", graph, vis_info); + + // propagate partitions of tt.reduce into its body + graph->walk([&](Node *node) { + if (node->isOp() && isa(node->getOp())) { + auto partitions = node->getPartitions(); + node->walk( + [&](Node *child_node) { child_node->addPartitions(partitions); }); + } + }); + + visualize(funcName, "propagate", "propagate reduce", graph, vis_info); + + // Corner case: tmem store following tmem alloc should be in a warp + // partition with 4 warps (i.e. a non-mma partition) + // This fixes the case where in a tmem alloc + initial store that feeds into + // an mma, the store is propagated the partition of the mma. It should instead + // have the same partition as the alloc + SmallVector patched_nodes; + + graph->walk([&](Node *node) { + if (node->isData() || !node->isOp() || + !isa(node->getOp())) { + return; + } + + Node *alloc = nullptr; + for (auto edge : node->getInEdges()) { + if (edge.getToIdx() == 1) { // token edge + alloc = edge.getFromNode(); + break; + } + } + if (!alloc || !alloc->isOp() || !isa(alloc->getOp())) + return; + + // pick the first non-mma partition + // does nothing if the only partitions are mma + auto partitions = alloc->getPartitions(); + for (auto partition : partitions) { + if (partition->getFlags() & MMA) + continue; + node->setPartition(partition); + patched_nodes.push_back(node); + break; + } + }); + + visualize(funcName, "propagate", "tmem store corner case", graph, vis_info); + + // propagate partitions for patched up nodes to non-data nodes + for (auto node : patched_nodes) { + SmallVector stack; + DenseSet seen; + auto partitions = node->getPartitions(); + stack.push_back(node); + seen.insert(node); + + while (!stack.empty()) { + auto node = stack.back(); + stack.pop_back(); + + for (auto edge : node->getInEdges()) { + if (edge.isDataValue()) + continue; + auto fromNode = edge.getFromNode(); + if (!fromNode) + continue; + fromNode->addPartitions(partitions); + + if (seen.count(edge.getFromNode()) == 0) { + stack.push_back(fromNode); + seen.insert(fromNode); + } + } + } + } +} + +void duplicateCheapOps(Graph *graph, std::string funcName, + VisualizationInfo &vis_info) { + visualize(funcName, "duplicate", "before duplicate cheap ops", graph, + vis_info); + + // for each partition: + // look at all crossing edges leaving the partition + // do a depth first search through NONE nodes, if we hit the same partition + // assign all nodes on that path to the partition + for (auto partition : graph->getPartitions()) { + + auto crossingEdges = getOutCrossingEdges(partition); + + for (auto edge : crossingEdges) { + // only handle start nodes with a single partition + if (edge.getFromNode()->getPartitions().size() != 1) + continue; + auto startPartition = edge.getFromNode()->getPartition(); + + // only handle nodes with a single partition + auto start = edge.getToNode(); + if (start->getPartitions().size() != 1) + continue; + auto partition = start->getPartition(); + + auto isCandidate = [](Node *node) { + return (getNodeFlags(node) == Flags::NONE || + getNodeFlags(node) == Flags::SFU); + }; + + if (!isCandidate(edge.getToNode())) + continue; + + auto update = [&]() { + std::map parentMap; + + SmallVector stack; + stack.push_back(start); + DenseSet seen; + + while (!stack.empty()) { + auto node = stack.back(); + stack.pop_back(); + if (!seen.contains(node)) { + seen.insert(node); + for (auto edge : node->getOutEdges()) { + auto child = edge.getToNode(); + if (!seen.contains(child)) { + if (child->getPartitions().size() != 1 || !isCandidate(child)) { + // do nothing + } else if (child->getPartition() == partition) { + parentMap.emplace(child, node); + stack.push_back(child); + } else if (child->getPartition() == startPartition) { + // found a path, set all nodes on the path to the partition + node->addPartition(startPartition); + while (parentMap.find(node) != parentMap.end()) { + node = parentMap[node]; + node->addPartition(startPartition); + } + + visualize(funcName, "duplicate", "duplicate cheap ops", graph, + vis_info); + + return; + } + } + } + } + } + }; + update(); + } + } + + visualize(funcName, "duplicate", "duplicate cheap ops done", graph, vis_info); +} + +void serialize(size_t idx, Operation *region, Graph *graph) { + + SetVector alreadyWritten; + + auto context = graph->getRoot()->getOp()->getContext(); + Builder b(context); + + // annotate loop with index + region->setAttr(kWarpSpecializeTagAttrName, b.getI32IntegerAttr(idx)); + + auto setPartitionsAttr = [&](Operation *op, Node *node) { + // not for func op + if (isa(op)) + return; + + // Note: we may have multiple nodes per op, so we merge the partition + // ids for all nodes of the op + SetVector partitionIds; + if (alreadyWritten.contains(op)) { + // if we already serialized a node to this op, merge those partition ids + // with the node being serialized + partitionIds = getPartitionIds(op); + } + alreadyWritten.insert(op); + for (auto partition : node->getPartitions()) + partitionIds.insert(*partition->id); + auto partitionIdsList = partitionIds.takeVector(); + std::sort(partitionIdsList.begin(), partitionIdsList.end()); + auto partitionsAttr = b.getDenseI32ArrayAttr(partitionIdsList); + op->setAttr(kPartitionAttrName, partitionsAttr); + + // set same paritions in yield ops + if (auto forOp = dyn_cast(op)) { + cast(forOp.getBody()->getTerminator()) + ->setAttr(kPartitionAttrName, partitionsAttr); + } else if (auto ifOp = dyn_cast(op)) { + ifOp.thenYield()->setAttr(kPartitionAttrName, partitionsAttr); + if (!ifOp.getElseRegion().empty()) { + ifOp.elseYield()->setAttr(kPartitionAttrName, partitionsAttr); + } + } + }; + + auto setPartitionOutputsAttr = [&](Operation *op, size_t idx, size_t size, + Node *node) { + llvm::SmallVector partitionAttrs; + if (op->hasAttr(kPartitionOutputsAttrName)) { + // get existing partitions + for (auto attr : + op->getAttrOfType(kPartitionOutputsAttrName)) { + partitionAttrs.push_back(attr); + } + assert(partitionAttrs.size() == size); + } else { + // initialize to no partitions + for (size_t i = 0; i < size; i++) + partitionAttrs.push_back(b.getDenseI32ArrayAttr({})); + } + + // update partitions for this output + SmallVector partitions; + for (auto partition : node->getPartitions()) + partitions.push_back(*partition->id); + std::sort(partitions.begin(), partitions.end()); + partitionAttrs[idx] = b.getDenseI32ArrayAttr(partitions); + op->setAttr(kPartitionOutputsAttrName, + ArrayAttr::get(context, partitionAttrs)); + }; + + graph->walk([&](Node *node) { + if (node->isOp()) { + setPartitionsAttr(node->getOp(), node); + + if (auto ret = dyn_cast(node->getOp())) { + // result of a reduce + auto reduce = node->getParent()->getOp(); + setPartitionOutputsAttr(reduce, 0, 1, node); + } + + } else { + auto value = node->getValue(); + if (auto blockArg = dyn_cast(value)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + // nothing for func ops + } else if (auto forOp = dyn_cast(parentOp)) { + if (blockArg.getArgNumber() == 0) { + // nothing for induction variable + } else { + // for op iter args + setPartitionOutputsAttr(parentOp, blockArg.getArgNumber() - 1, + forOp.getResultTypes().size(), node); + } + } else { + assert(false); + } + } else if (auto result = dyn_cast(value)) { + auto op = result.getOwner(); + if (isa(op)) { + // do nothing (handled by block arg) + } else if (auto ifOp = dyn_cast(op)) { + // result of an if + setPartitionOutputsAttr(op, result.getResultNumber(), + ifOp.getResultTypes().size(), node); + } else { + assert(false); + } + } else { + assert(false); + } + } + }); + + // set stages + SmallVector stages; + for (auto &partition : graph->getPartitions()) { + auto id = *partition->id; + while (id >= stages.size()) + stages.push_back(b.getI32IntegerAttr(0)); + stages[id] = b.getI32IntegerAttr(partition->getStage()); + } + region->setAttr(kPartitionStagesAttrName, b.getArrayAttr(stages)); +} + +void duplicateViewOps(Graph *graph) { + // Ensure all view ops (e.g. broadcast/expand dims) have a single user, + // by duplicating nodes where necessary + + SmallVector viewOps; + + graph->walk([&](Node *node) { + if (node->isData() && node->isOp() && isViewOp(node->getOp())) + viewOps.push_back(node); + }); + + while (!viewOps.empty()) { + auto node = viewOps.pop_back_val(); + auto op = node->getOp(); + + assert(op->getResults().size() == 1); + + auto outEdges = node->getOutEdges(); + + bool first = true; + for (auto edge : outEdges) { + if (!first) { + auto newNode = node->getParent()->addNode(op, op->getNumOperands(), + op->getNumResults()); + + // remove old edge + Node::removeEdge(edge); + + // add new edge + OutputPort outputPort(newNode, 0); + OutputPort inputPort(edge.getToNode(), edge.getToIdx()); + Node::addEdge(outputPort, inputPort); + + // add operands of new node + for (auto inEdge : node->getInEdges()) { + Node::addEdge(inEdge.getFrom(), + InputPort(newNode, inEdge.getToIdx())); + } + + // copy data values + for (auto idx = 0; idx < op->getNumResults(); idx++) { + if (node->isDataValue(idx)) { + newNode->setDataValue(idx); + } + } + } + first = false; + } + } +} + +void assignPartitionIds(Graph *graph) { + size_t idx = 0; + + SmallVector store_partitions; + SmallVector mma_partitions; + SmallVector load_partitions; + SmallVector other_partitions; + + for (auto partition : graph->getPartitions()) { + if (partition->getFlags() & Flags::STORE) + store_partitions.push_back(partition); + else if (partition->getFlags() & Flags::MMA) + mma_partitions.push_back(partition); + else if (partition->getFlags() & Flags::LOAD) + load_partitions.push_back(partition); + else + other_partitions.push_back(partition); + } + + for (auto partition : other_partitions) { + partition->id = idx; + idx++; + } + for (auto partition : store_partitions) { + partition->id = idx; + idx++; + } + // ensure MMA and LOAD partitions are never the same as the default + // partition + if (idx == 0) + idx++; + for (auto partition : mma_partitions) { + partition->id = idx; + idx++; + } + for (auto partition : load_partitions) { + partition->id = idx; + idx++; + } +} + +void assignPartitionsForOpsWithNoUse(Graph *graph) { + // nodes with no partition placed in same partition as other ops in the + // region or default partition if none. Note: we can't just use partitions + // of parent op, as this includes things like tmem tokens + Partition *defaultPartition = nullptr; + for (auto partition : graph->getPartitions()) + if (partition->id && *partition->id == 0) + defaultPartition = partition; + graph->walk([&](Node *node) { + if (node->getPartitions().empty()) { + bool done = false; + auto parent = node->getParent(); + if (parent && parent->isOp()) { + for (auto &otherNode : parent->getNodes()) { + if (node == otherNode.get()) + continue; + if (otherNode->isOp() && otherNode->hasPartition()) { + node->addPartitions(otherNode->getPartitions()); + done = true; + } + } + } + if (!done) { + if (defaultPartition == nullptr) { + // default partition doesn't exist, create one + defaultPartition = graph->addPartition(); + defaultPartition->id = 0; + } + node->setPartition(defaultPartition); + } + } + }); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct PartitionScheduling + : public impl::TritonGPUPartitionSchedulingBase { + using TritonGPUPartitionSchedulingBase::TritonGPUPartitionSchedulingBase; + + void runOnOperation() override { + // find ops to partition + SmallVector ops; + getOperation().walk([&](scf::ForOp op) { + if (op->hasAttr(kWarpSpecializeAttrName)) + ops.push_back(op); + }); + + // run partitioner on each op + size_t idx = 0; + for (auto op : ops) { + analyze(idx, op); + cloneMultiPartitionDataOps(op); + idx++; + } + } + +private: + void analyze(size_t idx, Operation *op) { + using namespace partition_scheduling_detail; + + auto func = op->getParentOfType(); + + VisualizationInfo vis_info; + auto key = func.getSymName().str() + "_" + std::to_string(idx); + + auto graph = buildGraph(op); + visualize(key, "input", "input", graph.get(), vis_info); + auto initValues = initialDataValues(graph.get()); + propagateDataValues(initValues); + visualize(key, "input", "after data values", graph.get(), vis_info); + duplicateViewOps(graph.get()); + visualize(key, "input", "after duplicate view ops", graph.get(), vis_info); + deserializeManualPartitions(op, graph.get()); + visualize(key, "input", "final", graph.get(), vis_info); + + initialPartitionAssignment(graph.get()); + visualize(key, "initial", "initial partitions", graph.get(), vis_info); + mergePartitions(graph.get(), key, vis_info); + visualize(key, "merge", "merged", graph.get(), vis_info); + propagatePartitions(graph.get(), key, vis_info); + visualize(key, "propagate", "propagated", graph.get(), vis_info); + + assignPartitionIds(graph.get()); + visualize(key, "assign-partition-ids", "assign partition ids", graph.get(), + vis_info); + // Handle case where ops with no uses (like llvm.intr.assume) get no + // partition assigned + assignPartitionsForOpsWithNoUse(graph.get()); + visualize(key, "assign-no-use", "assign no use", graph.get(), vis_info); + propagatePartitions(graph.get(), key, vis_info); + visualize(key, "propagate", "propagated", graph.get(), vis_info); + // Optimization: looks for paths of NONE ops with low cost, from one + // partition, through another partition, and back to the same partition. + // Duplicates these to avoid the aref involved (i.e. assign to both + // partitions) + duplicateCheapOps(graph.get(), key, vis_info); + visualize(key, "final", "final", graph.get(), vis_info); + + LLVM_DEBUG({ + llvm::errs() << "\nfinal partitions:\n"; + for (auto &partition : graph->getPartitions()) + partition->dump(); + }); + + serialize(idx, op, graph.get()); + } + + void cloneMultiPartitionDataOps(Operation *region) { + // FIXME: this transformation runs after the partition scheduling is + // complete It clones "data" ops with multiple partitions assigned, as + // insert-aref pass cannot currently handly these. E.g. an op assigned to + // partitions 0,1 will be cloned into two ops, one in partition 0 and the + // other in partition 1 and all uses are updated correctly. + + using namespace partition_scheduling_detail; + + // build data flow graph to find all data ops + DenseSet dataOps; + { + auto graph = buildGraph(region); + auto initValues = initialDataValues(graph.get()); + propagateDataValues(initValues); + graph->walk([&](Node *node) { + if (node->isOp() && node->isData()) + dataOps.insert(node->getOp()); + }); + } + + // for each partition, find all data ops that are in that partition, + // and in another partition + for (auto partition : getPartitionIds(region)) { + SetVector partitionSet; + partitionSet.insert(partition); + + SmallVector ops; + region->walk([&](Operation *op) { + auto partitions = getPartitionIds(op); + if (partitions.contains(partition) && partitions.size() > 1 && + dataOps.contains(op)) + ops.push_back(op); + }); + + SmallVector oldOps; + SetVector newOps; + DenseMap mapping; + for (auto op : ops) { + auto newOp = OpBuilder(op).clone(*op); + setPartition(newOp, partitionSet); + oldOps.push_back(op); + newOps.insert(newOp); + mapping[newOp] = op; + mapping[op] = newOp; + } + + // rewrite operands + // if op that produces operand of new op is has a duplicated op, + // rewrite the operand to use that op + for (auto newOp : newOps) { + for (auto &operand : newOp->getOpOperands()) { + auto value = operand.get(); + if (isa(value)) { + auto result = cast(value); + auto producerOp = result.getOwner(); + if (mapping.contains(producerOp)) { + auto newProducerOp = mapping[producerOp]; + auto newValue = + newProducerOp->getResult(result.getResultNumber()); + auto idx = operand.getOperandNumber(); + newOp->setOperand(idx, newValue); + } + } + } + } + + // rewrite results + for (auto newOp : newOps) { + auto oldOp = mapping[newOp]; + for (auto &use : oldOp->getUses()) { + auto user = use.getOwner(); + assert(user); + auto userPartitions = getPartitionIds(user); + // skip if use is not in same partition as new op + if (userPartitions != partitionSet) + continue; + // update the use to use the new op + auto result = cast(use.get()); + auto idx = result.getResultNumber(); + use.set(newOp->getResult(idx)); + } + } + + // remove dead code + bool done = false; + while (!done) { + done = true; + auto op = oldOps.begin(); + for (; op != oldOps.end(); op++) { + if ((*op)->getUses().empty()) { + (*op)->erase(); + oldOps.erase(op); + done = false; + break; + } + } + } + } + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionSchedulingUtility.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionSchedulingUtility.cpp new file mode 100644 index 0000000000..3fdf3a92f9 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionSchedulingUtility.cpp @@ -0,0 +1,370 @@ +#include "triton/Dialect/TritonGPU/Transforms/PartitionSchedulingUtility.h" +#include "mlir/Support/LLVM.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include +#include + +namespace mlir::triton::gpu::partition_scheduling_detail { + +llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, Flags flags) { + std::vector strs; + if (flags == Flags::NONE) { + strs.push_back("NONE"); + } else { + if (flags & Flags::MANUAL) + strs.push_back("MANUAL"); + if (flags & Flags::LOAD) + strs.push_back("LOAD"); + if (flags & Flags::STORE) + strs.push_back("STORE"); + if (flags & Flags::MMA) + strs.push_back("MMA"); + if (flags & Flags::TMEM) + strs.push_back("TMEM"); + if (flags & Flags::SFU) + strs.push_back("SFU"); + if (flags & Flags::VIEW) + strs.push_back("VIEW"); + } + for (size_t i = 0; i < strs.size(); i++) { + if (i != 0) + stream << "|"; + stream << strs[i]; + } + return stream; +} + +Flags getNodeFlags(Node *node) { + if (node->isOp()) { + auto op = node->getOp(); + + // if it is manually tagged with a node type + if (op->hasAttr("store")) + return Flags::STORE; + + if (isa(op)) + return Flags::LOAD; + if (isa(op)) + return Flags::STORE; + if (isa(op) || op->hasAttr("mma")) + return Flags::MMA; + if (isa(op)) + return Flags::TMEM; + if (isa(op)) + return Flags::SFU; + if (isViewOp(op)) + return Flags::VIEW; + } + return Flags::NONE; +} + +size_t computeCost(Operation *op) { + if (auto mma = dyn_cast(op)) { + auto a = mma.getA(); + auto b = mma.getB(); + auto a_shape = a.getType().getShape(); + auto b_shape = b.getType().getShape(); + assert(a_shape.size() == 2); + assert(b_shape.size() == 2); + auto M = a_shape[0]; + auto N = b_shape[0]; + auto K = a_shape[1]; + auto cycles = M * N * K / 8192; + return cycles; + } + + if (isa(op)) { + int elementCount = 0; + for (Type type : op->getResultTypes()) { + if (auto tensorTy = dyn_cast(type)) + elementCount += tensorTy.getNumElements(); + } + return elementCount; + } + + return 0; +} + +void Partition::add(Node *node) { + auto node_flags = getNodeFlags(node); + + // Note: only set view flag for partition, + // if it consists of all view ops + // FIXME: have a set kinds of flag to make this generic? + bool all_view = true; + if (!nodes.empty() && !(flags & Flags::VIEW)) + all_view = false; + if (!(node_flags & Flags::VIEW)) + all_view = false; + + nodes.insert(node); + + flags |= node_flags; + if (!all_view) + flags = static_cast(flags & ~Flags::VIEW); + + if (node->hasCost()) + cost += node->getCost(); +} + +void Partition::merge(Partition *lhs, Partition *rhs) { + assert(lhs != rhs); + + // Should never be merging MANUAL partitions + assert(!((lhs->getFlags() & Flags::MANUAL) && + (rhs->getFlags() & Flags::MANUAL))); + + // Always keep the MANUAL partition, + // and prefer emptying the NONE partition + if (lhs->getFlags() & Flags::MANUAL || rhs->getFlags() == Flags::NONE) + std::swap(lhs, rhs); + + auto nodes = lhs->getNodes(); + for (auto node : nodes) { + node->setPartition(rhs); + } + + // remove the now empty partition + lhs->graph->erasePartition(lhs); +} + +void Partition::dump() const { + llvm::errs() << "Partition@" << this << " {\n" + << " id=" << id << "\n" + << " size=" << nodes.size() << "\n" + << " cost=" << cost << "\n" + << " flags=" << flags << "\n" + << "}\n"; +} + +bool Edge::isDataValue() const { + if (!from.getNode()) + return false; + return from.getNode()->isDataValue(from.getIdx()); +} + +bool Edge::crossesPartitions() const { + if (!isDataValue()) + return false; + if (!from.getNode()->hasPartition() || !to.getNode()->hasPartition()) + return false; + // FIXME: only considers edges between nodes assigned to single partitions + // as crossing a boundary + if (from.getNode()->getPartitions().size() != 1 || + to.getNode()->getPartitions().size() != 1) + return false; + return from.getNode()->getPartition() != to.getNode()->getPartition(); +} + +Type Edge::getType() const { + auto fromNode = from.getNode(); + if (fromNode->isOp()) + return fromNode->getOp()->getResult(from.getIdx()).getType(); + return fromNode->getValue().getType(); +} + +size_t Edge::getSize() const { + auto type = getType(); + + if (auto tensor = dyn_cast(type)) { + size_t size = 1; + for (auto x : tensor.getShape()) + size *= x; + return size; + } + + if (auto memdesc = dyn_cast(type)) { + size_t size = 1; + for (auto x : memdesc.getShape()) + size *= x; + return size; + } + + return 1; +} + +void visualize(std::string key, std::string filename, std::string title, + Graph *graph, VisualizationInfo &info) { + + if (!tools::getBoolEnv("TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT")) + return; + + const auto dump_data_only = + tools::getBoolEnv("TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY"); + const auto dump_loop_only = + tools::getBoolEnv("TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY"); + + static std::map keys; + if (keys.find(key) == keys.end()) { + keys[key] = 0; + } + auto idx = keys[key]; + keys[key]++; + + std::stringstream path; + path << "graph-" << key << "-" << std::setfill('0') << std::setw(4) << idx + << "-" << filename << ".dot"; + + std::error_code err; + llvm::raw_fd_ostream dot(path.str(), err); + assert(!err); + + dot << "digraph G {\n"; + dot << "label = \"" << title << "\";\n"; + dot << "labelloc=\"t\";\n"; + dot << "labeljust=\"c\";\n"; + + DenseMap node_ids; + + auto getPartitionId = [&](Partition *partition) { + if (info.partition_ids.count(partition) == 0) + info.partition_ids[partition] = info.partition_ids.size(); + return info.partition_ids[partition]; + }; + + auto getPartitionColor = [&](Partition *partition) { + if (info.partition_colors.count(partition) == 0) { + size_t color = info.partition_colors.size() + 1; + color = (color % 12) + 1; + info.partition_colors[partition] = + std::string("/set312/") + std::to_string(color); + } + return info.partition_colors[partition]; + }; + + // add nodes + std::function visitNodes = [&](Node *graph) { + for (auto &node_obj : graph->getNodes()) { + auto node = node_obj.get(); + + if (dump_data_only && !node->isData() && !node->containsData()) + // skip if dumping data nodes only, and this op is non-data or doesn't + // contain a data node + continue; + if (dump_loop_only && !node->inLoopBody() && !node->containsLoopBody()) + // skip if dumping loop body nodes only + continue; + + node_ids[node] = node_ids.size(); + + if (!node->getNodes().empty()) + dot << "subgraph cluster_cx" << node_ids[node] << " {\n" + << "label=\"\"\n"; + dot << "x" << node_ids[node] << "[shape=plaintext, "; + if (node->isData()) + dot << "color=blue, "; + dot << "label=<"; + dot << ""; + if (node->getNumInputs() > 1) { + dot << ""; + for (size_t idx = 0; idx < node->getNumInputs(); idx++) + dot << ""; + dot << ""; + } + dot << ""; + + if (node->hasCost()) { + dot << " 0) + dot << " COLSPAN=\"" << colspan << "\""; + dot << ">"; + dot << "cost:" << node->getCost(); + dot << ""; + } + + if (node->getNumOutputs() > 1) { + dot << ""; + for (size_t idx = 0; idx < node->getNumOutputs(); idx++) + dot << ""; + dot << ""; + } + dot << "
" << idx << "
getNumInputs(), node->getNumOutputs()); + if (colspan > 0) + dot << " COLSPAN=\"" << colspan << "\""; + dot << ">"; + + dot << ""; + if (node->hasPartition()) { + for (auto partition : node->getPartitions()) { + auto name = std::to_string(getPartitionId(partition)); + dot << ""; + } + } + dot << "
" + << name << "{" << partition->getCost() << "}" + << "[" << partition->getFlags() << "]" << node->getLabel(); + if (node->isData()) + dot << " [" << getNodeFlags(node) << "]"; + dot << "
"; + dot << "
" << idx << "
>];\n"; + if (!node->getNodes().empty()) { + visitNodes(node); + dot << "}\n"; + } + } + }; + visitNodes(graph->getRoot()); + + // add edges + std::function visitEdges = [&](Node *node) { + size_t idx = 0; + for (auto inputPorts : node->getOutputs()) { + OutputPort outputPort{node, idx}; + for (auto inputPort : inputPorts) { + Edge edge(outputPort, inputPort); + if (node_ids.count(outputPort.getNode()) == 0 || + node_ids.count(inputPort.getNode()) == 0) + continue; + dot << "x" << node_ids[outputPort.getNode()]; + dot << ":"; + if (outputPort.getNode()->getNumOutputs() == 1) + dot << "inout"; + else + dot << "out" << outputPort.getIdx(); + dot << " -> "; + dot << "x" << node_ids[inputPort.getNode()]; + dot << ":"; + if (inputPort.getNode()->getNumInputs() == 1) + dot << "inout"; + else + dot << "in" << inputPort.getIdx(); + std::vector attrs; + if (edge.isDataValue()) { + if (edge.getFromNode()->getPartitions().size() > 1 || + edge.getToNode()->getPartitions().size() > 1) + // invalid edge, should only have one partition + attrs.push_back("color=\"green\""); + else if (edge.crossesPartitions()) + attrs.push_back("color=\"red\""); + else + attrs.push_back("color=\"blue\""); + auto size = edge.getSize(); + if (size != 1) { + attrs.push_back("label=\"" + std::to_string(size) + "\""); + } + } + if (!attrs.empty()) { + dot << "["; + for (auto attr = attrs.begin(); attr != attrs.end(); attr++) { + if (attr != attrs.begin()) { + dot << ","; + } + dot << *attr; + } + dot << "]"; + } + dot << ";\n"; + } + idx++; + } + for (auto &node : node->getNodes()) + visitEdges(node.get()); + }; + visitEdges(graph->getRoot()); + + dot << "}\n"; +} + +} // namespace mlir::triton::gpu::partition_scheduling_detail diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonInstrument/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/CMakeLists.txt new file mode 100644 index 0000000000..6b39e076d6 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonInstrumentIR + Dialect.cpp + FunctionBuilder.cpp + Ops.cpp + Utility.cpp + + DEPENDS + TritonInstrumentTableGen + + LINK_LIBS PUBLIC + MLIRIR + TritonIR + TritonGPUIR +) diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Dialect.cpp new file mode 100644 index 0000000000..d00906f30f --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Dialect.cpp @@ -0,0 +1,17 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" + +#include "triton/Dialect/TritonInstrument/IR/Dialect.cpp.inc" +using namespace mlir::triton::instrument; + +void TritonInstrumentDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonInstrument/IR/Ops.cpp.inc" + >(); + addInterfaces(); +} diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp new file mode 100644 index 0000000000..0c50ebd805 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -0,0 +1,2110 @@ +#include "triton/Dialect/TritonInstrument/IR/FunctionBuilder.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir::triton::instrument { + +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; +namespace tti = mlir::triton::instrument; + +namespace { + +namespace BarrierBits { +constexpr unsigned phaseBit = 0; +constexpr unsigned initCountLsb = 1; +constexpr unsigned currentCountLsb = 9; +constexpr unsigned countBitWidth = 8; +constexpr unsigned countMask = (1u << countBitWidth) - 1; +} // namespace BarrierBits + +namespace WaitingBits { +constexpr unsigned bitsPerThread = 2; +constexpr unsigned flagBit = 0; +constexpr unsigned phaseBit = 1; + +constexpr uint32_t makeInterleavedMask(unsigned bit) { + uint32_t mask = 0; + for (unsigned i = 0; i < tti::NUM_THREADS; ++i) + mask |= 1u << (bitsPerThread * i + bit); + return mask; +} + +constexpr uint32_t flagMask = makeInterleavedMask(flagBit); +constexpr uint32_t phaseMask = makeInterleavedMask(phaseBit); +} // namespace WaitingBits + +// Information about the optional assert message and tensor type to check. +struct AssertInfo { + StringRef message; + Type type; +}; + +static uint64_t expandActiveMask(uint64_t activeMask) { + uint64_t expanded = 0; + for (unsigned i = 0; i < tti::NUM_THREADS; ++i) { + if (activeMask & (1ull << i)) + expanded |= + 1ull << (WaitingBits::bitsPerThread * i + WaitingBits::flagBit); + } + return expanded; +} + +Value createCmpIntTensorScalar( + ImplicitLocOpBuilder &b, Value tensor, Value scalar, + arith::CmpIPredicate predicate = arith::CmpIPredicate::eq) { + auto tensorTy = cast(tensor.getType()); + Value splat = triton::SplatOp::create(b, tensorTy, scalar); + return arith::CmpIOp::create(b, predicate, tensor, splat); +} + +Value createBitwiseOrReduce(ImplicitLocOpBuilder &b, Value tensor, int axis) { + OpBuilder::InsertionGuard guard(b); + auto tensorType = cast(tensor.getType()); + auto reduceOp = triton::ReduceOp::create(b, std::vector{tensor}, axis); + auto ®ion = reduceOp.getRegion(); + auto &block = region.emplaceBlock(); + block.addArguments({tensorType.getElementType(), tensorType.getElementType()}, + {b.getLoc(), b.getLoc()}); + b.setInsertionPointToStart(&block); + auto result = + arith::OrIOp::create(b, block.getArgument(0), block.getArgument(1)); + triton::ReduceReturnOp::create(b, std::vector{result}); + return reduceOp->getResult(0); +} + +FuncOp getOrCreateFunction( + ModuleOp module, const std::string &name, llvm::ArrayRef argTypes, + ManglingArgs specializationArgs, int numWarps, Type assertType, + std::function buildBody) { + ManglingArgs manglingArgs; + manglingArgs.append(argTypes); + manglingArgs.append(specializationArgs); + if (assertType) { + manglingArgs.append(assertType); + } + std::string funcName = manglingArgs.mangle(name, numWarps); + if (auto existing = module.lookupSymbol(funcName)) { + return existing; + } + + OpBuilder moduleBuilder(module.getContext()); + moduleBuilder.setInsertionPointToStart(module.getBody()); + Location loc = module.getLoc(); + SmallVector resultTypes = {}; + if (assertType) { + resultTypes.push_back(assertType); + } + auto funcType = moduleBuilder.getFunctionType(argTypes, resultTypes); + FuncOp func = FuncOp::create(moduleBuilder, loc, funcName, funcType); + func.setVisibility(SymbolTable::Visibility::Private); + func->setAttr(ttg::AttrNumWarpsName, + moduleBuilder.getI32IntegerAttr(numWarps)); + Block *entryBlock = func.addEntryBlock(); + OpBuilder bodyBuilder = OpBuilder::atBlockBegin(entryBlock); + ImplicitLocOpBuilder fb(loc, bodyBuilder); + buildBody(fb, entryBlock); + return func; +} + +// Create a call to a function with body given by `buildBody`. +// If the function does not exist, it will be created, otherwise the +// existing function will be used. +// If `assertInfo` is provided, the function should return a tensor of +// the given type and the result of the function will be asserted. +void createCallToCachedFunction( + ImplicitLocOpBuilder &b, const std::string &name, ArrayRef args, + std::optional assertInfo, ManglingArgs specializationArgs, + std::function buildBody) { + ModuleOp module = b.getInsertionPoint()->getParentOfType(); + int numWarps = ttg::lookupNumWarps(b.getInsertionPoint()->getParentRegion()); + SmallVector argTypes = llvm::to_vector( + llvm::map_range(args, [](Value v) { return v.getType(); })); + Type assertType = assertInfo ? assertInfo->type : nullptr; + triton::FuncOp func = + getOrCreateFunction(module, name, argTypes, specializationArgs, numWarps, + assertType, buildBody); + SmallVector resultTypes = {}; + if (assertInfo) { + resultTypes.push_back(assertInfo->type); + } + auto callOp = triton::CallOp::create(b, func.getName(), resultTypes, args); + if (assertInfo) { + Value result = callOp->getResult(0); + StringRef message = b.getStringAttr(assertInfo->message); + tti::ExperimentalAssertInThreadOp::create(b, result, message, false); + } +} + +Value createBufferDescriptor(ImplicitLocOpBuilder &b, Value offsetI32, + Value lengthI32) { + auto i64Type = b.getI64Type(); + Value offsetI64 = arith::ExtUIOp::create(b, i64Type, offsetI32); + Value lengthI64 = arith::ExtUIOp::create(b, i64Type, lengthI32); + Value shiftAmount = arith::ConstantIntOp::create(b, 32, 64); + Value lengthShifted = arith::ShLIOp::create(b, lengthI64, shiftAmount); + return arith::OrIOp::create(b, lengthShifted, offsetI64); +} + +uint32_t getMemDescLength(Value buf) { + auto memDescType = cast(buf.getType()); + if (isa(memDescType.getEncoding())) { + unsigned elSize = memDescType.getElementType().getIntOrFloatBitWidth() / 8; + return static_cast(product(memDescType.getShape()) * elSize); + } + if (isa(memDescType.getMemorySpace())) { + return ttng::getTmemAllocSizes(memDescType).numCols; + } + llvm_unreachable("Unsupported memory space for memdesc"); +} + +std::tuple createIfBlock(ImplicitLocOpBuilder &b, + Value cnd) { + // #prevBlock + // if (condition) { + // #ifBlock + // } + // #thenBlock + Block *prevBlock = b.getInsertionBlock(); + Block::iterator insertPoint = b.getInsertionPoint(); + Block *ifBlock = prevBlock->splitBlock(insertPoint); + + // Split a block after the call. + Block *thenBlock = ifBlock->splitBlock(ifBlock->begin()); + b.setInsertionPointToEnd(ifBlock); + cf::BranchOp::create(b, thenBlock); + b.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(b, cnd, ifBlock, ValueRange{}, thenBlock, + ValueRange{}); + b.setInsertionPointToStart(thenBlock); + + return {prevBlock, ifBlock, thenBlock}; +} + +Value convertAndBroadcast(ImplicitLocOpBuilder &b, Value tensor, int dim, + RankedTensorType dstType) { + auto loc = b.getLoc(); + ArrayRef shape = dstType.getShape(); + auto tensorType = cast(tensor.getType()); + auto encoding = cast(dstType.getEncoding()); + RankedTensorType resultType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + auto slicedEncoding = + ttg::SliceEncodingAttr::get(b.getContext(), dim, encoding); + tensor = ttg::ConvertLayoutOp::create( + b, tensorType.cloneWithEncoding(slicedEncoding), tensor); + tensor = tti::expandOuterSlicedDim(b, loc, tensor); + tensor = triton::BroadcastOp::create(b, resultType, tensor); + return tensor; +} + +Value createConvertLayout(ImplicitLocOpBuilder &b, Value tensor, + Attribute encoding) { + auto tensorType = cast(tensor.getType()); + auto dstType = tensorType.cloneWithEncoding(encoding); + return ttg::ConvertLayoutOp::create(b, dstType, tensor); +} + +Value expandAliases(ImplicitLocOpBuilder &b, Value bufferMask, + Value aliasMatrix, RankedTensorType aliasMatrixType) { + assert(aliasMatrixType.getRank() == 2 && + "Alias matrix expected to be rank-2"); + auto bufferMaskType = cast(bufferMask.getType()); + Value bufMaskMatrix = + convertAndBroadcast(b, bufferMask, /*dim=*/1, aliasMatrixType); + Value aliasingMask = arith::AndIOp::create(b, aliasMatrix, bufMaskMatrix); + Value aliasVector = createBitwiseOrReduce(b, aliasingMask, /*axis=*/0); + return createConvertLayout(b, aliasVector, bufferMaskType.getEncoding()); +} + +Value createOneHot(ImplicitLocOpBuilder &b, int size, int index, + Attribute encoding) { + auto loc = b.getLoc(); + auto type = RankedTensorType::get({size}, b.getI32Type(), encoding); + Value arange = + triton::MakeRangeOp::create(b, type, /*start=*/0, /*end=*/size); + Value indexTensor = + tti::createConstIntTensor(b, loc, index, type, /*isSigned=*/false); + return arith::CmpIOp::create(b, arith::CmpIPredicate::eq, arange, + indexTensor); +} + +Value createColumnMask(ImplicitLocOpBuilder &b, int column, + RankedTensorType tensorType) { + auto encoding = cast(tensorType.getEncoding()); + auto columnEncoding = tti::getSingleDimSliceEncoding(encoding, /*dim=*/1); + Value oneHot = + createOneHot(b, tensorType.getShape()[1], column, columnEncoding); + return convertAndBroadcast(b, oneHot, /*dim=*/0, tensorType); +} + +Value createMultiColumnMask(ImplicitLocOpBuilder &b, uint64_t columnMask, + RankedTensorType tensorType) { + auto loc = b.getLoc(); + auto i1TensorType = + cast(tensorType.cloneWith(std::nullopt, b.getI1Type())); + Value maskTensor = tti::createConstIntTensor(b, loc, 0, i1TensorType); + for (int i = 0; i < 64; ++i) { + if (columnMask & (1ULL << i)) { + Value columnMaskTensor = createColumnMask(b, i, tensorType); + maskTensor = arith::OrIOp::create(b, maskTensor, columnMaskTensor); + } + } + return maskTensor; +} + +Value adjustIntegerWidth(ImplicitLocOpBuilder &b, Value value, + IntegerType targetType) { + auto srcType = cast(value.getType()); + if (srcType.getWidth() == targetType.getWidth()) + return value; + if (srcType.getWidth() < targetType.getWidth()) + return arith::ExtUIOp::create(b, targetType, value); + return arith::TruncIOp::create(b, targetType, value); +} + +Value createThreadColumnMask(ImplicitLocOpBuilder &b, Value threadMask, + RankedTensorType tensorType) { + auto loc = b.getLoc(); + auto encoding = cast(tensorType.getEncoding()); + auto sliceEncoding = tti::getSingleDimSliceEncoding(encoding, /*dim=*/1); + int columns = tensorType.getShape()[1]; + + RankedTensorType rangeType = + RankedTensorType::get({columns}, b.getI32Type(), sliceEncoding); + Value range = triton::MakeRangeOp::create(b, rangeType, 0, columns); + + auto elemType = cast(tensorType.getElementType()); + RankedTensorType rangeElemType = + RankedTensorType::get({columns}, elemType, sliceEncoding); + Value rangeElem = range; + if (elemType.getWidth() != 32) + rangeElem = arith::ExtUIOp::create(b, rangeElemType, range); + + Value indices = convertAndBroadcast(b, rangeElem, /*dim=*/0, tensorType); + + Value threadMaskElem = adjustIntegerWidth(b, threadMask, elemType); + Value maskTensor = triton::SplatOp::create(b, tensorType, threadMaskElem); + + Value shifted = arith::ShRUIOp::create(b, maskTensor, indices); + Value one = tti::createConstIntTensor(b, loc, 1, tensorType); + Value bits = arith::AndIOp::create(b, shifted, one); + Value zero = tti::createConstIntTensor(b, loc, 0, tensorType); + return arith::CmpIOp::create(b, arith::CmpIPredicate::ne, bits, zero); +} + +Value createColumnMask(ImplicitLocOpBuilder &b, Value column, + RankedTensorType tensorType) { + auto loc = b.getLoc(); + auto encoding = cast(tensorType.getEncoding()); + auto sliceEncoding = tti::getSingleDimSliceEncoding(encoding, /*dim=*/1); + auto colType = RankedTensorType::get({tensorType.getShape()[1]}, + b.getI32Type(), sliceEncoding); + Value range = triton::MakeRangeOp::create(b, colType, /*start=*/0, + /*end=*/tensorType.getShape()[1]); + Value columnTensor = triton::SplatOp::create(b, colType, column); + Value mask1D = + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, range, columnTensor); + return convertAndBroadcast(b, mask1D, /*dim=*/0, tensorType); +} + +} // namespace + +void FunctionBuilder::createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, + int thread, Value phase, Value pred, + Operation *insertPoint) { + + if (auxData.barriers.empty() || auxData.waiting.empty()) { + return; + } + if (!pred) { + pred = arith::ConstantIntOp::create(b, 1, 1); + } + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value waitingVal = auxData.waiting.at(insertPoint).value; + auto waitingType = + cast(auxData.waiting.at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, threadVal, phase, + pred, barriersVal, waitingVal}; + createCallToCachedFunction( + b, "set_waiting", args, + /*assertInfo=*/std::nullopt, {barriersType, waitingType}, + [barriersType, waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value baseThread = entryBlock->getArgument(2); + Value phase = entryBlock->getArgument(3); + Value pred = entryBlock->getArgument(4); + + Value barriers = entryBlock->getArgument(5); + Value waitingPtr = entryBlock->getArgument(6); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value waiting = tti::createLoadScratchMemory(fb, fb.getLoc(), + waitingPtr, waitingType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); + + Value bitsPerThread = + arith::ConstantIntOp::create(fb, WaitingBits::bitsPerThread, 32); + Value flagBit = + arith::ConstantIntOp::create(fb, WaitingBits::flagBit, 32); + Value phaseBit = + arith::ConstantIntOp::create(fb, WaitingBits::phaseBit, 32); + Value one = arith::ConstantIntOp::create(fb, 1, 32); + Value minusOne = arith::ConstantIntOp::create(fb, -1, 32); + + Value baseTimesBits = + arith::MulIOp::create(fb, baseThread, bitsPerThread); + Value flagShift = arith::AddIOp::create(fb, baseTimesBits, flagBit); + Value phaseShift = arith::AddIOp::create(fb, baseTimesBits, phaseBit); + + Value flagMaskScalar = arith::ShLIOp::create(fb, one, flagShift); + Value phaseMaskScalar = arith::ShLIOp::create(fb, one, phaseShift); + Value combinedMask = + arith::OrIOp::create(fb, flagMaskScalar, phaseMaskScalar); + Value clearMaskScalar = + arith::XOrIOp::create(fb, combinedMask, minusOne); + + Value flagMaskTensor = + triton::SplatOp::create(fb, waitingType, flagMaskScalar); + Value clearMaskTensor = + triton::SplatOp::create(fb, waitingType, clearMaskScalar); + Value phaseShiftTensor = + triton::SplatOp::create(fb, waitingType, phaseShift); + + Value clearedWaiting = + arith::AndIOp::create(fb, waiting, clearMaskTensor); + Value withFlag = + arith::OrIOp::create(fb, clearedWaiting, flagMaskTensor); + + Value phaseScalar = arith::AndIOp::create(fb, phase, one); + Value phaseTensor = + triton::SplatOp::create(fb, waitingType, phaseScalar); + Value phaseBits = + arith::ShLIOp::create(fb, phaseTensor, phaseShiftTensor); + Value pendingWaiting = arith::OrIOp::create(fb, withFlag, phaseBits); + + auto condType = cast(barriersEqBar.getType()); + Value predTensor = triton::SplatOp::create(fb, condType, pred); + Value cond = arith::AndIOp::create(fb, barriersEqBar, predTensor); + + Value newWaiting = + arith::SelectOp::create(fb, cond, pendingWaiting, waiting); + tti::createStoreScratchMemory(fb, fb.getLoc(), waitingPtr, newWaiting, + waitingType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createClearWaitingCall(ImplicitLocOpBuilder &b, + Value mbar, int thread, Value pred, + Operation *insertPoint) { + if (auxData.barriers.empty() || auxData.waiting.empty()) { + return; + } + if (!pred) { + pred = arith::ConstantIntOp::create(b, 1, 1); + } + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value waitingVal = auxData.waiting.at(insertPoint).value; + auto waitingType = + cast(auxData.waiting.at(insertPoint).type); + + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, threadVal, + pred, barriersVal, waitingVal}; + createCallToCachedFunction( + b, "clear_waiting", args, + /*assertInfo=*/std::nullopt, {barriersType, waitingType}, + [barriersType, waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value baseThread = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); + + Value barriers = entryBlock->getArgument(4); + Value waitingPtr = entryBlock->getArgument(5); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value waiting = tti::createLoadScratchMemory(fb, fb.getLoc(), + waitingPtr, waitingType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); + + Value bitsPerThread = + arith::ConstantIntOp::create(fb, WaitingBits::bitsPerThread, 32); + Value flagBit = + arith::ConstantIntOp::create(fb, WaitingBits::flagBit, 32); + Value phaseBit = + arith::ConstantIntOp::create(fb, WaitingBits::phaseBit, 32); + Value one = arith::ConstantIntOp::create(fb, 1, 32); + Value minusOne = arith::ConstantIntOp::create(fb, -1, 32); + + Value baseTimesBits = + arith::MulIOp::create(fb, baseThread, bitsPerThread); + Value flagShift = arith::AddIOp::create(fb, baseTimesBits, flagBit); + Value phaseShift = arith::AddIOp::create(fb, baseTimesBits, phaseBit); + + Value flagMaskScalar = arith::ShLIOp::create(fb, one, flagShift); + Value phaseMaskScalar = arith::ShLIOp::create(fb, one, phaseShift); + Value combinedMask = + arith::OrIOp::create(fb, flagMaskScalar, phaseMaskScalar); + Value clearMaskScalar = + arith::XOrIOp::create(fb, combinedMask, minusOne); + + Value clearMaskTensor = + triton::SplatOp::create(fb, waitingType, clearMaskScalar); + Value clearedWaiting = + arith::AndIOp::create(fb, waiting, clearMaskTensor); + + Value newWaiting = + arith::SelectOp::create(fb, barriersEqBar, clearedWaiting, waiting); + + tti::createStoreScratchMemory(fb, fb.getLoc(), waitingPtr, newWaiting, + waitingType); + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, + int activeMask, + Value pred, + Operation *insertPoint) { + if (auxData.waiting.empty() || auxData.barrierStates.empty()) { + return; + } + if (!pred) { + pred = arith::ConstantIntOp::create(b, 1, 1); + } + int64_t expandedActiveMask = expandActiveMask(activeMask); + Value expandedActiveMaskVal = + arith::ConstantIntOp::create(b, expandedActiveMask, 32); + Value waitingVal = auxData.waiting.at(insertPoint).value; + auto waitingType = + cast(auxData.waiting.at(insertPoint).type); + Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; + auto barrierStatesType = + cast(auxData.barrierStates.at(insertPoint).type); + SmallVector args = {expandedActiveMaskVal, pred, waitingVal, + barrierStatesVal}; + AssertInfo assertInfo{ + "Deadlock detected: all active threads are waiting on mbarriers", + b.getI1Type()}; + createCallToCachedFunction( + b, "check_all_active_waiting", args, assertInfo, + {waitingType, barrierStatesType}, + [waitingType, barrierStatesType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value expandedActiveMaskVal = entryBlock->getArgument(0); + Value pred = entryBlock->getArgument(1); + + Value waitingPtr = entryBlock->getArgument(2); + Value barrierStatesPtr = entryBlock->getArgument(3); + + Value waiting = tti::createLoadScratchMemory(fb, fb.getLoc(), + waitingPtr, waitingType); + Value barrierStates = tti::createLoadScratchMemory( + fb, fb.getLoc(), barrierStatesPtr, barrierStatesType); + + Value flagMaskTensor = tti::createConstIntTensor( + fb, fb.getLoc(), WaitingBits::flagMask, waitingType); + Value phaseMaskTensor = tti::createConstIntTensor( + fb, fb.getLoc(), WaitingBits::phaseMask, waitingType); + + Value flags = arith::AndIOp::create(fb, waiting, flagMaskTensor); + Value phases = arith::AndIOp::create(fb, waiting, phaseMaskTensor); + Value shiftOneTensor = + tti::createConstIntTensor(fb, fb.getLoc(), 1, waitingType); + Value phasesAligned = + arith::ShRUIOp::create(fb, phases, shiftOneTensor); + + Value phasesComplement = + arith::XOrIOp::create(fb, phasesAligned, flagMaskTensor); + Value waitingPhase0 = + arith::AndIOp::create(fb, flags, phasesComplement); + Value waitingPhase1 = arith::AndIOp::create(fb, flags, phasesAligned); + + Value oneState = + tti::createConstIntTensor(fb, fb.getLoc(), 1, barrierStatesType); + Value barrierPhase = arith::AndIOp::create(fb, barrierStates, oneState); + Value phaseIsOne = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + barrierPhase, oneState); + + Value effectiveWaiting = arith::SelectOp::create( + fb, phaseIsOne, waitingPhase1, waitingPhase0); + Value waitingOr = + createBitwiseOrReduce(fb, effectiveWaiting, /*axis=*/0); + + auto waitingOrTy = waitingOr.getType(); + Value waitingMasked = + arith::AndIOp::create(fb, waitingOr, expandedActiveMaskVal); + Value eq = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + waitingMasked, expandedActiveMaskVal); + + Value vTrue = arith::ConstantOp::create( + fb, eq.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); + Value ok = arith::XOrIOp::create(fb, eq, vTrue); + Value predicatedOk = arith::SelectOp::create(fb, pred, ok, vTrue); + triton::ReturnOp::create(fb, predicatedOk); + }); +} + +void FunctionBuilder::createInitBarrierStateCall(ImplicitLocOpBuilder &b, + Value mbar, int count, + Operation *insertPoint) { + + if (auxData.barriers.empty() || auxData.barrierStates.empty()) { + return; + } + Value countVal = arith::ConstantIntOp::create(b, count, 32); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; + auto barrierStatesType = + cast(auxData.barrierStates.at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, countVal, barriersVal, + barrierStatesVal}; + createCallToCachedFunction( + b, "init_barrier_state", args, + /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, + [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value count = entryBlock->getArgument(2); + + Value barriers = entryBlock->getArgument(3); + Value statesPtr = entryBlock->getArgument(4); + + Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, + barrierStatesType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); + + Value countMask = + arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); + Value maskedCount = arith::AndIOp::create(fb, count, countMask); + Value countTensor = + triton::SplatOp::create(fb, barrierStatesType, maskedCount); + + Value shiftOneTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::initCountLsb, barrierStatesType); + Value shiftNineTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::currentCountLsb, barrierStatesType); + + Value initField = + arith::ShLIOp::create(fb, countTensor, shiftOneTensor); + Value currentField = + arith::ShLIOp::create(fb, countTensor, shiftNineTensor); + Value newState = arith::OrIOp::create(fb, initField, currentField); + + Value updated = arith::SelectOp::create(fb, mask, newState, states); + tti::createStoreScratchMemory(fb, fb.getLoc(), statesPtr, updated, + barrierStatesType); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, + Value mbar, int count, + Value pred, + Operation *insertPoint) { + + if (auxData.barriers.empty() || auxData.barrierStates.empty()) { + return; + } + if (!pred) { + pred = arith::ConstantIntOp::create(b, 1, 1); + } + Value countVal = arith::ConstantIntOp::create(b, count, 32); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; + auto barrierStatesType = + cast(auxData.barrierStates.at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, countVal, + pred, barriersVal, barrierStatesVal}; + AssertInfo assertInfo{ + "Barrier arrive underflow: current count would become negative", + barrierStatesType.cloneWith(std::nullopt, b.getI1Type())}; + createCallToCachedFunction( + b, "verify_barrier_arrive", args, assertInfo, + {barriersType, barrierStatesType}, + [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value count = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); + + Value barriers = entryBlock->getArgument(4); + Value statesPtr = entryBlock->getArgument(5); + + Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, + barrierStatesType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); + + Value zero32 = + tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); + Value maskFF = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::countMask, barrierStatesType); + Value shiftNineTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::currentCountLsb, barrierStatesType); + + Value currentCount = + arith::ShRUIOp::create(fb, states, shiftNineTensor); + currentCount = arith::AndIOp::create(fb, currentCount, maskFF); + + Value countMask = + arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); + Value maskedCount = arith::AndIOp::create(fb, count, countMask); + Value arriveCount = + triton::SplatOp::create(fb, barrierStatesType, maskedCount); + + Value newCurrent = arith::SubIOp::create(fb, currentCount, arriveCount); + Value newCurrentMasked = + arith::SelectOp::create(fb, mask, newCurrent, zero32); + Value nonNegative = arith::CmpIOp::create(fb, arith::CmpIPredicate::sge, + newCurrentMasked, zero32); + Value vTrue = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(nonNegative.getType())); + Value predicatedNonNegative = + arith::SelectOp::create(fb, pred, nonNegative, vTrue); + + triton::ReturnOp::create(fb, predicatedNonNegative); + }); +} + +void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, + Value mbar, int count, + Value pred, + Operation *insertPoint) { + + if (auxData.barriers.empty() || auxData.barrierStates.empty()) { + return; + } + if (!pred) { + pred = arith::ConstantIntOp::create(b, 1, 1); + } + Value countVal = arith::ConstantIntOp::create(b, count, 32); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value barrierStatesVal = auxData.barrierStates.at(insertPoint).value; + auto barrierStatesType = + cast(auxData.barrierStates.at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, countVal, + pred, barriersVal, barrierStatesVal}; + createCallToCachedFunction( + b, "update_barrier_state", args, + /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, + [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value count = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); + + Value barriers = entryBlock->getArgument(4); + Value statesPtr = entryBlock->getArgument(5); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, + barrierStatesType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); + + Value zero32 = + tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); + Value one32 = + tti::createConstIntTensor(fb, fb.getLoc(), 1, barrierStatesType); + Value maskFF = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::countMask, barrierStatesType); + Value shiftOneTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::initCountLsb, barrierStatesType); + Value shiftNineTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::currentCountLsb, barrierStatesType); + + Value phase = arith::AndIOp::create(fb, states, one32); + Value initCount = arith::ShRUIOp::create(fb, states, shiftOneTensor); + initCount = arith::AndIOp::create(fb, initCount, maskFF); + Value currentCount = + arith::ShRUIOp::create(fb, states, shiftNineTensor); + currentCount = arith::AndIOp::create(fb, currentCount, maskFF); + + Value countMask = + arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); + Value maskedCount = arith::AndIOp::create(fb, count, countMask); + Value arriveCount = + triton::SplatOp::create(fb, barrierStatesType, maskedCount); + + Value newCurrent = arith::SubIOp::create(fb, currentCount, arriveCount); + Value newCurrentMasked = + arith::SelectOp::create(fb, mask, newCurrent, currentCount); + + Value zeroCond = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + newCurrentMasked, zero32); + zeroCond = arith::AndIOp::create(fb, zeroCond, mask); + Value zeroCondI32 = + arith::ExtUIOp::create(fb, barrierStatesType, zeroCond); + Value newPhase = arith::XOrIOp::create(fb, phase, zeroCondI32); + Value newCurrentValue = + arith::SelectOp::create(fb, zeroCond, initCount, newCurrentMasked); + + Value initField = arith::ShLIOp::create(fb, initCount, shiftOneTensor); + Value currentField = + arith::ShLIOp::create(fb, newCurrentValue, shiftNineTensor); + Value newState = arith::OrIOp::create(fb, newPhase, initField); + newState = arith::OrIOp::create(fb, newState, currentField); + + Value updated = arith::SelectOp::create(fb, mask, newState, states); + tti::createStoreScratchMemory(fb, fb.getLoc(), statesPtr, updated, + barrierStatesType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, + Value buf, uint32_t length, + uint64_t threadMask, + Value pred, MemType memType, + Operation *insertPoint) { + + if (auxData.buffers[(int)memType].empty() || + auxData.writeVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadMaskVal = arith::ConstantIntOp::create(b, threadMask, 64); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value writeVisibilityVal = + auxData.writeVisibility[(int)memType].at(insertPoint).value; + auto writeVisibilityType = cast( + auxData.writeVisibility[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, + threadMaskVal, buffersVal, writeVisibilityVal}; + createCallToCachedFunction( + b, "set_write_visibility", args, + /*assertInfo=*/std::nullopt, + {buffersType, writeVisibilityType, (int)memType}, + [buffersType, writeVisibilityType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + auto elemType = cast(writeVisibilityType.getElementType()); + Value threadMaskElem = adjustIntegerWidth(fb, threadMaskVal, elemType); + Value threadMaskTensor = + triton::SplatOp::create(fb, writeVisibilityType, threadMaskElem); + Value newVisibility = arith::SelectOp::create( + fb, buffersEqBuf, threadMaskTensor, writeVisibility); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + newVisibility, writeVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createSetReadVisibilityCall(ImplicitLocOpBuilder &b, + Value buf, uint32_t length, + uint64_t threadMask, + Value pred, MemType memType, + Operation *insertPoint) { + + if (auxData.buffers[(int)memType].empty() || + auxData.readVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadMaskVal = arith::ConstantIntOp::create(b, threadMask, 64); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value readVisibilityVal = + auxData.readVisibility[(int)memType].at(insertPoint).value; + auto readVisibilityType = cast( + auxData.readVisibility[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, + threadMaskVal, buffersVal, readVisibilityVal}; + createCallToCachedFunction( + b, "set_read_visibility", args, + /*assertInfo=*/std::nullopt, + {buffersType, readVisibilityType, (int)memType}, + [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, + readVisibilityType); + auto elemType = cast(readVisibilityType.getElementType()); + Value threadMaskElem = adjustIntegerWidth(fb, threadMaskVal, elemType); + Value threadBit = + triton::SplatOp::create(fb, readVisibilityType, threadMaskElem); + Value threadColumnMask = + createThreadColumnMask(fb, threadMaskVal, readVisibilityType); + Value readVisibilityOrThreadBit = + arith::OrIOp::create(fb, readVisibility, threadBit); + Value bufAndThread = + arith::AndIOp::create(fb, buffersEqBuf, threadColumnMask); + Value newVisibility = arith::SelectOp::create( + fb, bufAndThread, readVisibilityOrThreadBit, readVisibility); + tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + newVisibility, readVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, + Value buf, uint32_t length, + Value pred, MemType memType, + Operation *insertPoint) { + if (auxData.buffers[(int)memType].empty() || + auxData.writeTracking[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value writeTrackingVal = + auxData.writeTracking[(int)memType].at(insertPoint).value; + auto writeTrackingType = cast( + auxData.writeTracking[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, buffersVal, + writeTrackingVal}; + createCallToCachedFunction( + b, "clear_write_tracking", args, + /*assertInfo=*/std::nullopt, + {buffersType, writeTrackingType, (int)memType}, + [buffersType, writeTrackingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value buffers = entryBlock->getArgument(3); + Value writeTrackingPtr = entryBlock->getArgument(4); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value writeTracking = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = + convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, writeTrackingType); + Value zero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, writeTrackingType); + Value newTracking = + arith::SelectOp::create(fb, buffersEqBuf, zero, writeTracking); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, + newTracking, writeTrackingType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, + Value buf, uint32_t length, + Value pred, MemType memType, + Operation *insertPoint) { + if (auxData.buffers[(int)memType].empty() || + auxData.readVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value readVisibilityVal = + auxData.readVisibility[(int)memType].at(insertPoint).value; + auto readVisibilityType = cast( + auxData.readVisibility[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, buffersVal, + readVisibilityVal}; + createCallToCachedFunction( + b, "clear_read_visibility", args, + /*assertInfo=*/std::nullopt, + {buffersType, readVisibilityType, (int)memType}, + [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value buffers = entryBlock->getArgument(3); + Value readVisibilityPtr = entryBlock->getArgument(4); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, + readVisibilityType); + Value zero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); + Value newVisibility = + arith::SelectOp::create(fb, buffersEqBuf, zero, readVisibility); + tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + newVisibility, readVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, + Value buf, uint32_t length, + Value pred, MemType memType, + Operation *insertPoint) { + + if (auxData.buffers[(int)memType].empty() || + auxData.readTracking[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value readTrackingVal = + auxData.readTracking[(int)memType].at(insertPoint).value; + auto readTrackingType = cast( + auxData.readTracking[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, pred, buffersVal, + readTrackingVal}; + createCallToCachedFunction( + b, "clear_read_tracking", args, + /*assertInfo=*/std::nullopt, + {buffersType, readTrackingType, (int)memType}, + [buffersType, readTrackingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value buffers = entryBlock->getArgument(3); + Value readTrackingPtr = entryBlock->getArgument(4); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value readTracking = tti::createLoadScratchMemory( + fb, fb.getLoc(), readTrackingPtr, readTrackingType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = + convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, readTrackingType); + Value zero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readTrackingType); + Value newTracking = + arith::SelectOp::create(fb, buffersEqBuf, zero, readTracking); + tti::createStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, + newTracking, readTrackingType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, + Value mbar, int thread, + Value pred, MemType memType, + Operation *insertPoint) { + if (auxData.barriers.empty() || + auxData.writeVisibility[(int)memType].empty() || + auxData.writeTracking[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value writeVisibilityVal = + auxData.writeVisibility[(int)memType].at(insertPoint).value; + auto writeVisibilityType = cast( + auxData.writeVisibility[(int)memType].at(insertPoint).type); + Value writeTrackingVal = + auxData.writeTracking[(int)memType].at(insertPoint).value; + auto writeTrackingType = cast( + auxData.writeTracking[(int)memType].at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadVal, barriersVal, writeVisibilityVal, + writeTrackingVal}; + createCallToCachedFunction( + b, "track_visible_writes", args, + /*assertInfo=*/std::nullopt, + {barriersType, writeVisibilityType, writeTrackingType, (int)memType}, + [barriersType, writeVisibilityType, + writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + Value writeTrackingPtr = entryBlock->getArgument(6); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + Value writeTracking = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); + barriersEqBar = convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, + writeTrackingType); + Value threadI64 = + arith::ExtUIOp::create(fb, fb.getI64Type(), threadVal); + Value one64 = arith::ConstantIntOp::create(fb, 1, 64); + Value threadBitScalar = arith::ShLIOp::create(fb, one64, threadI64); + Value threadBit = + triton::SplatOp::create(fb, writeVisibilityType, threadBitScalar); + Value visibleWrites = + arith::AndIOp::create(fb, writeVisibility, threadBit); + visibleWrites = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + visibleWrites, threadBit); + visibleWrites = convertAndBroadcast(fb, visibleWrites, /*dim=*/1, + writeTrackingType); + Value barAndVisible = + arith::AndIOp::create(fb, barriersEqBar, visibleWrites); + Value writeTrackingOne = + tti::createConstIntTensor(fb, fb.getLoc(), 1, writeTrackingType); + Value newTracking = arith::SelectOp::create( + fb, barAndVisible, writeTrackingOne, writeTracking); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, + newTracking, writeTrackingType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, + Value mbar, int thread, + Value pred, MemType memType, + Operation *insertPoint) { + + if (auxData.barriers.empty() || + auxData.readVisibility[(int)memType].empty() || + auxData.readTracking[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value readVisibilityVal = + auxData.readVisibility[(int)memType].at(insertPoint).value; + auto readVisibilityType = cast( + auxData.readVisibility[(int)memType].at(insertPoint).type); + Value readTrackingVal = + auxData.readTracking[(int)memType].at(insertPoint).value; + auto readTrackingType = cast( + auxData.readTracking[(int)memType].at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadVal, barriersVal, readVisibilityVal, + readTrackingVal}; + createCallToCachedFunction( + b, "track_visible_reads", args, + /*assertInfo=*/std::nullopt, + {barriersType, readVisibilityType, readTrackingType, (int)memType}, + [barriersType, readVisibilityType, + readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + Value readTrackingPtr = entryBlock->getArgument(6); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + Value readTracking = tti::createLoadScratchMemory( + fb, fb.getLoc(), readTrackingPtr, readTrackingType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); + barriersEqBar = + convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, readTrackingType); + Value threadColumnMask = + createColumnMask(fb, threadVal, readVisibilityType); + Value readVisibilityZero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); + Value visibleReads = arith::SelectOp::create( + fb, threadColumnMask, readVisibility, readVisibilityZero); + visibleReads = createBitwiseOrReduce(fb, visibleReads, /*axis=*/1); + visibleReads = + convertAndBroadcast(fb, visibleReads, /*dim=*/1, readTrackingType); + Value readTrackingOrVisible = + arith::OrIOp::create(fb, readTracking, visibleReads); + Value newTracking = arith::SelectOp::create( + fb, barriersEqBar, readTrackingOrVisible, readTracking); + tti::createStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, + newTracking, readTrackingType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createTransferVisibleWritesCall( + ImplicitLocOpBuilder &b, Value mbar, uint64_t threadMask, Value pred, + MemType memType, Operation *insertPoint) { + + if (auxData.barriers.empty() || + auxData.writeVisibility[(int)memType].empty() || + auxData.writeTracking[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadMaskVal = arith::ConstantIntOp::create(b, threadMask, 64); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value writeVisibilityVal = + auxData.writeVisibility[(int)memType].at(insertPoint).value; + auto writeVisibilityType = cast( + auxData.writeVisibility[(int)memType].at(insertPoint).type); + Value writeTrackingVal = + auxData.writeTracking[(int)memType].at(insertPoint).value; + auto writeTrackingType = cast( + auxData.writeTracking[(int)memType].at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadMaskVal, barriersVal, writeVisibilityVal, + writeTrackingVal}; + createCallToCachedFunction( + b, "transfer_visible_writes", args, + /*assertInfo=*/std::nullopt, + {barriersType, writeVisibilityType, writeTrackingType, (int)memType}, + [barriersType, writeVisibilityType, + writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + Value writeTrackingPtr = entryBlock->getArgument(6); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + Value writeTracking = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); + barriersEqBar = convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, + writeTrackingType); + Value zeroTracking = + tti::createConstIntTensor(fb, fb.getLoc(), 0, writeTrackingType); + Value trackingBuffers = arith::SelectOp::create( + fb, barriersEqBar, writeTracking, zeroTracking); + trackingBuffers = + createBitwiseOrReduce(fb, trackingBuffers, /*axis=*/1); + trackingBuffers = createConvertLayout( + fb, trackingBuffers, writeVisibilityType.getEncoding()); + auto trackingBuffersType = + cast(trackingBuffers.getType()); + Value trackingBuffersOne = + tti::createConstIntTensor(fb, fb.getLoc(), 1, trackingBuffersType); + trackingBuffers = arith::CmpIOp::create( + fb, arith::CmpIPredicate::eq, trackingBuffers, trackingBuffersOne); + auto elemType = cast(writeVisibilityType.getElementType()); + Value threadMaskElem = adjustIntegerWidth(fb, threadMaskVal, elemType); + Value threadMaskTensor = + triton::SplatOp::create(fb, writeVisibilityType, threadMaskElem); + Value zeroVisibility = + tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); + Value trackingThreadBit = arith::SelectOp::create( + fb, trackingBuffers, threadMaskTensor, zeroVisibility); + Value newVisibility = + arith::OrIOp::create(fb, writeVisibility, trackingThreadBit); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + newVisibility, writeVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createTransferVisibleReadsCall( + ImplicitLocOpBuilder &b, Value mbar, uint64_t threadMask, Value pred, + MemType memType, Operation *insertPoint) { + + if (auxData.barriers.empty() || + auxData.readVisibility[(int)memType].empty() || + auxData.readTracking[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadMaskVal = arith::ConstantIntOp::create(b, threadMask, 64); + Value barriersVal = auxData.barriers.at(insertPoint).value; + auto barriersType = + cast(auxData.barriers.at(insertPoint).type); + Value readVisibilityVal = + auxData.readVisibility[(int)memType].at(insertPoint).value; + auto readVisibilityType = cast( + auxData.readVisibility[(int)memType].at(insertPoint).type); + Value readTrackingVal = + auxData.readTracking[(int)memType].at(insertPoint).value; + auto readTrackingType = cast( + auxData.readTracking[(int)memType].at(insertPoint).type); + uint32_t length = getMemDescLength(mbar); + Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {mbarOffset, lengthVal, pred, + threadMaskVal, barriersVal, readVisibilityVal, + readTrackingVal}; + createCallToCachedFunction( + b, "transfer_visible_reads", args, + /*assertInfo=*/std::nullopt, + {barriersType, readVisibilityType, readTrackingType, (int)memType}, + [barriersType, readVisibilityType, + readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value mbarOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadMaskVal = entryBlock->getArgument(3); + Value barriers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + Value readTrackingPtr = entryBlock->getArgument(6); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + Value readTracking = tti::createLoadScratchMemory( + fb, fb.getLoc(), readTrackingPtr, readTrackingType); + Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); + Value barriersEqBar = + createCmpIntTensorScalar(fb, barriers, descriptor); + barriersEqBar = + convertAndBroadcast(fb, barriersEqBar, /*dim=*/0, readTrackingType); + Value readTrackingZero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readTrackingType); + Value trackingBar = arith::SelectOp::create( + fb, barriersEqBar, readTracking, readTrackingZero); + trackingBar = createBitwiseOrReduce(fb, trackingBar, /*axis=*/1); + trackingBar = + convertAndBroadcast(fb, trackingBar, /*dim=*/1, readVisibilityType); + Value readVisibilityOrTracking = + arith::OrIOp::create(fb, readVisibility, trackingBar); + Value threadColumnMask = + createThreadColumnMask(fb, threadMaskVal, readVisibilityType); + Value newVisibility = arith::SelectOp::create( + fb, threadColumnMask, readVisibilityOrTracking, readVisibility); + tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + newVisibility, readVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createVerifyWriteVisibilityCall( + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, + StringRef operandName, Value pred, MemType memType, + Operation *insertPoint) { + if (auxData.buffers[(int)memType].empty() || + auxData.writeVisibility[(int)memType].empty() || + (auxData.hasNonTrivialAliasing[(int)memType] && + auxData.aliasMatrices[(int)memType].empty())) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value writeVisibilityVal = + auxData.writeVisibility[(int)memType].at(insertPoint).value; + auto writeVisibilityType = cast( + auxData.writeVisibility[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + std::string message = "Buffer being accessed has outstanding writes."; + if (!operandName.empty()) + message += " Operand: " + operandName.str(); + AssertInfo assertInfo{message, + buffersType.cloneWith(std::nullopt, b.getI1Type())}; + Type aliasMatrixTypeBase; + auto buildVerifyWriteBody = [&](bool useAlias) { + return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + Value aliasMatrix = useAlias ? entryBlock->getArgument(6) : Value(); + + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + if (useAlias) { + buffersEqBuf = + expandAliases(fb, buffersEqBuf, aliasMatrix, + cast(aliasMatrixTypeBase)); + } + Value writeVisibilityZero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); + Value bufVisibility = arith::SelectOp::create( + fb, buffersEqBuf, writeVisibility, writeVisibilityZero); + Value noOneIsWriting = arith::CmpIOp::create( + fb, arith::CmpIPredicate::eq, bufVisibility, writeVisibilityZero); + Value threadI64 = arith::ExtUIOp::create(fb, fb.getI64Type(), threadVal); + Value threadMask = + triton::SplatOp::create(fb, writeVisibilityType, threadI64); + Value buffersEqBufExt = + arith::ExtUIOp::create(fb, writeVisibilityType, buffersEqBuf); + Value bufferThreadBit = + arith::ShLIOp::create(fb, buffersEqBufExt, threadMask); + Value bufferHasVisibility = + arith::AndIOp::create(fb, bufVisibility, bufferThreadBit); + bufferHasVisibility = arith::CmpIOp::create( + fb, arith::CmpIPredicate::eq, bufferHasVisibility, bufferThreadBit); + Value writeVisible = + arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); + + Value vTrue = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(writeVisible.getType())); + Value predicatedWriteVisible = + arith::SelectOp::create(fb, pred, writeVisible, vTrue); + triton::ReturnOp::create(fb, predicatedWriteVisible); + }; + }; + if (auxData.hasNonTrivialAliasing[(int)memType]) { + Value aliasMatrixVal = + auxData.aliasMatrices[(int)memType].at(insertPoint).value; + aliasMatrixTypeBase = + auxData.aliasMatrices[(int)memType].at(insertPoint).type; + auto aliasMatrixType = cast(aliasMatrixTypeBase); + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, writeVisibilityVal, + aliasMatrixVal}; + createCallToCachedFunction( + b, "verify_write_visibility", args, assertInfo, + {buffersType, writeVisibilityType, aliasMatrixType, (int)memType}, + buildVerifyWriteBody(/*useAlias=*/true)); + } else { + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, writeVisibilityVal}; + createCallToCachedFunction(b, "verify_write_visibility_noalias", args, + assertInfo, + {buffersType, writeVisibilityType, (int)memType}, + buildVerifyWriteBody(/*useAlias=*/false)); + } +} + +void FunctionBuilder::createVerifyReadVisibilityCall( + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, + StringRef operandName, Value pred, MemType memType, + Operation *insertPoint) { + if (auxData.buffers[(int)memType].empty() || + auxData.readVisibility[(int)memType].empty() || + (auxData.hasNonTrivialAliasing[(int)memType] && + auxData.aliasMatrices[(int)memType].empty())) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value buffersVal = auxData.buffers[(int)memType].at(insertPoint).value; + auto buffersType = cast( + auxData.buffers[(int)memType].at(insertPoint).type); + Value readVisibilityVal = + auxData.readVisibility[(int)memType].at(insertPoint).value; + auto readVisibilityType = cast( + auxData.readVisibility[(int)memType].at(insertPoint).type); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + std::string message = "Buffer being accessed has outstanding reads"; + if (!operandName.empty()) + message += ". Operand: " + operandName.str(); + AssertInfo assertInfo{message, + buffersType.cloneWith(std::nullopt, b.getI1Type())}; + Type aliasMatrixTypeBase; + auto buildVerifyReadBody = [&](bool useAlias) { + return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + Value aliasMatrix = useAlias ? entryBlock->getArgument(6) : Value(); + + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + if (useAlias) { + buffersEqBuf = + expandAliases(fb, buffersEqBuf, aliasMatrix, + cast(aliasMatrixTypeBase)); + } + buffersEqBuf = + convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, readVisibilityType); + Value readVisibilityZero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); + Value bufVisibility = arith::SelectOp::create( + fb, buffersEqBuf, readVisibility, readVisibilityZero); + Value totalVisibility = + createBitwiseOrReduce(fb, bufVisibility, /*axis=*/1); + Value threadColumnMask = + createColumnMask(fb, threadVal, readVisibilityType); + Value bufThreadVisibility = arith::SelectOp::create( + fb, threadColumnMask, bufVisibility, readVisibilityZero); + bufThreadVisibility = + createBitwiseOrReduce(fb, bufThreadVisibility, /*axis=*/1); + Value threadAndTotalVisibility = + arith::AndIOp::create(fb, bufThreadVisibility, totalVisibility); + Value hasVisibility = + arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + threadAndTotalVisibility, totalVisibility); + Value vTrue = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(hasVisibility.getType())); + Value predicatedHasVisibility = + arith::SelectOp::create(fb, pred, hasVisibility, vTrue); + predicatedHasVisibility = createConvertLayout(fb, predicatedHasVisibility, + buffersType.getEncoding()); + triton::ReturnOp::create(fb, predicatedHasVisibility); + }; + }; + if (auxData.hasNonTrivialAliasing[(int)memType]) { + Value aliasMatrixVal = + auxData.aliasMatrices[(int)memType].at(insertPoint).value; + aliasMatrixTypeBase = + auxData.aliasMatrices[(int)memType].at(insertPoint).type; + auto aliasMatrixType = cast(aliasMatrixTypeBase); + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, readVisibilityVal, + aliasMatrixVal}; + createCallToCachedFunction( + b, "verify_read_visibility", args, assertInfo, + {buffersType, readVisibilityType, aliasMatrixType, (int)memType}, + buildVerifyReadBody(/*useAlias=*/true)); + } else { + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, readVisibilityVal}; + createCallToCachedFunction(b, "verify_read_visibility_noalias", args, + assertInfo, + {buffersType, readVisibilityType, (int)memType}, + buildVerifyReadBody(/*useAlias=*/false)); + } +} + +void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, + int sourceThread, + uint64_t destMask, + Value pred, MemType memType, + Operation *insertPoint) { + + if (auxData.writeVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + auto writeVis = auxData.writeVisibility[(int)memType].at(insertPoint); + auto writeVisibilityType = cast(writeVis.type); + Value sourceThreadVal = arith::ConstantIntOp::create(b, sourceThread, 32); + Value destMaskVal = arith::ConstantIntOp::create(b, destMask, 64); + SmallVector args = {sourceThreadVal, destMaskVal, pred, + writeVis.value}; + createCallToCachedFunction( + b, "copy_write_visibility", args, + /*assertInfo=*/std::nullopt, {writeVisibilityType, (int)memType}, + [writeVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value sourceThread = entryBlock->getArgument(0); + Value destMaskVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value writeVisibilityPtr = entryBlock->getArgument(3); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + auto elemType = cast(writeVisibilityType.getElementType()); + Value zeroTensor = + tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); + + constexpr uint64_t fullMask = + tti::THREADS_BITMASK_SIZE >= 64 + ? std::numeric_limits::max() + : ((1ull << tti::THREADS_BITMASK_SIZE) - 1); + Value fullMaskVal = arith::ConstantIntOp::create(fb, fullMask, 64); + Value destMaskElem = adjustIntegerWidth(fb, destMaskVal, elemType); + Value fullMaskElem = adjustIntegerWidth(fb, fullMaskVal, elemType); + Value clearMaskElem = + arith::XOrIOp::create(fb, destMaskElem, fullMaskElem); + Value destMaskTensor = + triton::SplatOp::create(fb, writeVisibilityType, destMaskElem); + Value clearMaskTensor = + triton::SplatOp::create(fb, writeVisibilityType, clearMaskElem); + Value cleared = + arith::AndIOp::create(fb, writeVisibility, clearMaskTensor); + + Value sourceThreadElem = adjustIntegerWidth(fb, sourceThread, elemType); + Value oneScalar = arith::ConstantOp::create( + fb, elemType, fb.getIntegerAttr(elemType, 1)); + Value sourceMaskElem = + arith::ShLIOp::create(fb, oneScalar, sourceThreadElem); + Value sourceMaskTensor = + triton::SplatOp::create(fb, writeVisibilityType, sourceMaskElem); + Value sourceBits = + arith::AndIOp::create(fb, writeVisibility, sourceMaskTensor); + Value sourceIsSet = arith::CmpIOp::create(fb, arith::CmpIPredicate::ne, + sourceBits, zeroTensor); + Value replicated = arith::SelectOp::create(fb, sourceIsSet, + destMaskTensor, zeroTensor); + + Value updated = arith::OrIOp::create(fb, cleared, replicated); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + updated, writeVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, + int sourceThread, + uint64_t destMask, + Value pred, MemType memType, + Operation *insertPoint) { + + if (auxData.readVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + auto readVis = auxData.readVisibility[(int)memType].at(insertPoint); + auto readVisibilityType = cast(readVis.type); + Value sourceThreadVal = arith::ConstantIntOp::create(b, sourceThread, 32); + SmallVector args = {sourceThreadVal, + arith::ConstantIntOp::create(b, destMask, 64), + pred, readVis.value}; + createCallToCachedFunction( + b, "copy_read_visibility", args, + /*assertInfo=*/std::nullopt, {readVisibilityType, (int)memType}, + [readVisibilityType, destMask](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value sourceThread = entryBlock->getArgument(0); + /*Value destMaskVal = entryBlock->getArgument(1);*/ + Value pred = entryBlock->getArgument(2); + Value readVisibilityPtr = entryBlock->getArgument(3); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + Value zeroTensor = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); + Value destMaskTensor = + createMultiColumnMask(fb, destMask, readVisibilityType); + Value cleared = arith::SelectOp::create(fb, destMaskTensor, zeroTensor, + readVisibility); + + Value sourceColumnMask = + createColumnMask(fb, sourceThread, readVisibilityType); + Value sourceColumn = arith::SelectOp::create( + fb, sourceColumnMask, readVisibility, zeroTensor); + Value sourceVector = + createBitwiseOrReduce(fb, sourceColumn, /*axis=*/1); + Value broadcastRow = convertAndBroadcast(fb, sourceVector, /*dim=*/1, + readVisibilityType); + Value replicated = arith::SelectOp::create(fb, destMaskTensor, + broadcastRow, zeroTensor); + + Value updated = arith::OrIOp::create(fb, cleared, replicated); + tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + updated, readVisibilityType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createStageAccessForCommitCall( + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, Value pred, + MemType memType, CommitKind::Kind commitKind, Operation *insertPoint) { + if (auxData.buffers[(int)memType].empty() || + auxData.commits[commitKind].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + ValueType buffers = auxData.buffers[(int)memType].at(insertPoint); + ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); + auto buffersType = cast(buffers.type); + auto commitsType = cast(outstandingCommits.type); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + SmallVector args = {bufOffset, lengthVal, + pred, threadVal, + buffers.value, outstandingCommits.value}; + createCallToCachedFunction( + b, "stage_access_for_commit", args, + /*assertInfo=*/std::nullopt, {buffersType, commitsType}, + [buffersType, commitsType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value outstandingCommitsPtr = entryBlock->getArgument(5); + + (void)threadVal; + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value commits = tti::createLoadScratchMemory( + fb, fb.getLoc(), outstandingCommitsPtr, commitsType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + buffersEqBuf = + convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, commitsType); + Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + Value bufAndThread = + arith::AndIOp::create(fb, buffersEqBuf, threadColumnMask); + Value minusOne = + tti::createConstIntTensor(fb, fb.getLoc(), -1, commitsType, true); + Value updated = + arith::SelectOp::create(fb, bufAndThread, minusOne, commits); + tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + updated, commitsType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createCommitAccessesCall(ImplicitLocOpBuilder &b, + int thread, Value pred, + CommitKind::Kind commitKind, + Operation *insertPoint) { + if (auxData.commits[commitKind].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); + auto commitsType = cast(outstandingCommits.type); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + SmallVector args = {threadVal, pred, outstandingCommits.value}; + createCallToCachedFunction( + b, "commit_accesses", args, + /*assertInfo=*/std::nullopt, {commitsType}, + [commitsType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value threadVal = entryBlock->getArgument(0); + Value pred = entryBlock->getArgument(1); + Value outstandingCommitsPtr = entryBlock->getArgument(2); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value commits = tti::createLoadScratchMemory( + fb, fb.getLoc(), outstandingCommitsPtr, commitsType); + Type elementType = commitsType.getElementType(); + Value zero = arith::ConstantOp::create( + fb, elementType, fb.getIntegerAttr(elementType, 0)); + Value minusOne = arith::ConstantOp::create( + fb, elementType, fb.getIntegerAttr(elementType, -1)); + Value ones = tti::createConstIntTensor(fb, fb.getLoc(), 1, commitsType); + + Value threadMask = createColumnMask(fb, threadVal, commitsType); + auto commitsGtZero = createCmpIntTensorScalar( + fb, commits, zero, arith::CmpIPredicate::sgt); + commitsGtZero = arith::AndIOp::create(fb, commitsGtZero, threadMask); + Value commitsPlusOne = arith::AddIOp::create(fb, commits, ones); + commits = + arith::SelectOp::create(fb, commitsGtZero, commitsPlusOne, commits); + + auto commitsEqMinusOne = createCmpIntTensorScalar( + fb, commits, minusOne, arith::CmpIPredicate::eq); + commitsEqMinusOne = + arith::AndIOp::create(fb, commitsEqMinusOne, threadMask); + commits = arith::SelectOp::create(fb, commitsEqMinusOne, ones, commits); + + tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + commits, commitsType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createClearOutstandingCommitsTransferWritesCall( + ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask, + int outstandingNum, Value pred, CommitKind::Kind commitKind, + MemType memType, Operation *insertPoint) { + if (auxData.commits[commitKind].empty() || + auxData.writeVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); + ValueType writeVisibility = + auxData.writeVisibility[(int)memType].at(insertPoint); + auto commitsType = cast(outstandingCommits.type); + auto writeVisibilityType = cast(writeVisibility.type); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value transferMaskVal = + arith::ConstantIntOp::create(b, transferThreadMask, 64); + Value outstandingNumVal = arith::ConstantIntOp::create(b, outstandingNum, 32); + SmallVector args = { + threadVal, transferMaskVal, outstandingNumVal, + pred, outstandingCommits.value, writeVisibility.value}; + createCallToCachedFunction( + b, "clear_outstanding_commits_transfer_writes", args, + /*assertInfo=*/std::nullopt, {commitsType, writeVisibilityType}, + [commitsType, writeVisibilityType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value threadVal = entryBlock->getArgument(0); + Value transferMaskVal = entryBlock->getArgument(1); + Value outstandingNumVal = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); + Value outstandingCommitsPtr = entryBlock->getArgument(4); + Value writeVisibilityPtr = entryBlock->getArgument(5); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value outstandingCommits = tti::createLoadScratchMemory( + fb, fb.getLoc(), outstandingCommitsPtr, commitsType); + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + + auto elemIntType = cast(commitsType.getElementType()); + Value outstandingNumElem = + adjustIntegerWidth(fb, outstandingNumVal, elemIntType); + Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + auto outstandingCommitsGtOutstandingNum = + createCmpIntTensorScalar(fb, outstandingCommits, outstandingNumElem, + arith::CmpIPredicate::sgt); + outstandingCommitsGtOutstandingNum = arith::AndIOp::create( + fb, outstandingCommitsGtOutstandingNum, threadColumnMask); + + Value rowMask = + createBitwiseOrReduce(fb, outstandingCommitsGtOutstandingNum, + /*axis=*/1); + rowMask = + createConvertLayout(fb, rowMask, writeVisibilityType.getEncoding()); + Value transferMaskElem = adjustIntegerWidth( + fb, transferMaskVal, + cast(writeVisibilityType.getElementType())); + Value transferMaskTensor = + triton::SplatOp::create(fb, writeVisibilityType, transferMaskElem); + Value writeVisibilityOrThreadBit = + arith::OrIOp::create(fb, writeVisibility, transferMaskTensor); + Value writeVisibilityUpdated = arith::SelectOp::create( + fb, rowMask, writeVisibilityOrThreadBit, writeVisibility); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + writeVisibilityUpdated, + writeVisibilityType); + + Value outstandingCommitsZero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, commitsType); + outstandingCommits = + arith::SelectOp::create(fb, outstandingCommitsGtOutstandingNum, + outstandingCommitsZero, outstandingCommits); + tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + outstandingCommits, commitsType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createClearOutstandingCommitsTransferReadsCall( + ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask, + int outstandingNum, Value pred, CommitKind::Kind commitKind, + MemType memType, Operation *insertPoint) { + if (auxData.commits[commitKind].empty() || + auxData.readVisibility[(int)memType].empty()) { + return; + } + ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); + ValueType readVisibility = + auxData.readVisibility[(int)memType].at(insertPoint); + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + auto commitsType = cast(outstandingCommits.type); + auto readVisibilityType = cast(readVisibility.type); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value transferMaskVal = + arith::ConstantIntOp::create(b, transferThreadMask, 64); + Value outstandingNumVal = arith::ConstantIntOp::create(b, outstandingNum, 32); + SmallVector args = { + threadVal, transferMaskVal, outstandingNumVal, + pred, outstandingCommits.value, readVisibility.value}; + createCallToCachedFunction( + b, "clear_outstanding_commits_transfer_reads", args, + /*assertInfo=*/std::nullopt, {commitsType, readVisibilityType}, + [commitsType, readVisibilityType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { + Value threadVal = entryBlock->getArgument(0); + Value transferMaskVal = entryBlock->getArgument(1); + Value outstandingNumVal = entryBlock->getArgument(2); + Value pred = entryBlock->getArgument(3); + Value outstandingCommitsPtr = entryBlock->getArgument(4); + Value readVisibilityPtr = entryBlock->getArgument(5); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value outstandingCommits = tti::createLoadScratchMemory( + fb, fb.getLoc(), outstandingCommitsPtr, commitsType); + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + + auto elemIntType = cast(commitsType.getElementType()); + Value outstandingNumElem = + adjustIntegerWidth(fb, outstandingNumVal, elemIntType); + Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + auto outstandingCommitsGtOutstandingNum = + createCmpIntTensorScalar(fb, outstandingCommits, outstandingNumElem, + arith::CmpIPredicate::sgt); + outstandingCommitsGtOutstandingNum = arith::AndIOp::create( + fb, outstandingCommitsGtOutstandingNum, threadColumnMask); + + Value rowMask = + createBitwiseOrReduce(fb, outstandingCommitsGtOutstandingNum, + /*axis=*/1); + rowMask = + convertAndBroadcast(fb, rowMask, /*dim=*/1, readVisibilityType); + Value transferMaskElem = adjustIntegerWidth( + fb, transferMaskVal, + cast(readVisibilityType.getElementType())); + Value transferMaskTensor = + triton::SplatOp::create(fb, readVisibilityType, transferMaskElem); + Value readVisibilityOrThreadBit = + arith::OrIOp::create(fb, readVisibility, transferMaskTensor); + Value readVisibilityUpdated = arith::SelectOp::create( + fb, rowMask, readVisibilityOrThreadBit, readVisibility); + tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + readVisibilityUpdated, + readVisibilityType); + + Value outstandingCommitsZero = + tti::createConstIntTensor(fb, fb.getLoc(), 0, commitsType); + outstandingCommits = + arith::SelectOp::create(fb, outstandingCommitsGtOutstandingNum, + outstandingCommitsZero, outstandingCommits); + tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + outstandingCommits, commitsType); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createCheckOutstandingCommitsCall( + ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, + StringRef pendingAccessType, Value pred, MemType memType, + CommitKind::Kind commitKind, Operation *insertPoint) { + if (auxData.buffers[(int)memType].empty() || + auxData.commits[commitKind].empty() || + (auxData.hasNonTrivialAliasing[(int)memType] && + auxData.aliasMatrices[(int)memType].empty())) { + return; + } + ValueType buffers = auxData.buffers[(int)memType].at(insertPoint); + ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); + assert(thread < NUM_THREADS && + "Commit-count tracking must operate on base threads"); + Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + auto buffersType = cast(buffers.type); + auto commitsType = cast(outstandingCommits.type); + Value threadVal = arith::ConstantIntOp::create(b, thread, 32); + Value lengthVal = arith::ConstantIntOp::create(b, length, 32); + std::string message = + "Accessing buffer with pending access. Pending access type: " + + pendingAccessType.str(); + AssertInfo assertInfo{message, + commitsType.cloneWith(std::nullopt, b.getI1Type())}; + Type aliasMatrixTypeBase; + auto buildCheckOutstandingCommitsBody = [&](bool useAlias) { + return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value bufOffset = entryBlock->getArgument(0); + Value lengthVal = entryBlock->getArgument(1); + Value pred = entryBlock->getArgument(2); + Value threadVal = entryBlock->getArgument(3); + Value buffers = entryBlock->getArgument(4); + Value outstandingCommitsPtr = entryBlock->getArgument(5); + Value aliasMatrix = useAlias ? entryBlock->getArgument(6) : Value(); + + Value outstandingCommits = tti::createLoadScratchMemory( + fb, fb.getLoc(), outstandingCommitsPtr, commitsType); + Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); + Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); + if (useAlias) { + buffersEqBuf = + expandAliases(fb, buffersEqBuf, aliasMatrix, + cast(aliasMatrixTypeBase)); + } + buffersEqBuf = + convertAndBroadcast(fb, buffersEqBuf, /*dim=*/1, commitsType); + Value zeroTensor = + tti::createConstIntTensor(fb, fb.getLoc(), 0, commitsType); + Value selectedRows = arith::SelectOp::create( + fb, buffersEqBuf, outstandingCommits, zeroTensor); + Value selectedEqZero = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + selectedRows, zeroTensor); + Value vTrue = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(selectedEqZero.getType())); + Value predicatedSelectedEqZero = + arith::SelectOp::create(fb, pred, selectedEqZero, vTrue); + + triton::ReturnOp::create(fb, predicatedSelectedEqZero); + }; + }; + if (auxData.hasNonTrivialAliasing[(int)memType]) { + ValueType aliasMatrix = auxData.aliasMatrices[(int)memType].at(insertPoint); + aliasMatrixTypeBase = aliasMatrix.type; + auto aliasMatrixType = cast(aliasMatrixTypeBase); + SmallVector args = { + bufOffset, lengthVal, pred, + threadVal, buffers.value, outstandingCommits.value, + aliasMatrix.value}; + createCallToCachedFunction( + b, "check_outstanding_commits", args, assertInfo, + {buffersType, commitsType, aliasMatrixType, (int)thread}, + buildCheckOutstandingCommitsBody(/*useAlias=*/true)); + } else { + SmallVector args = {bufOffset, lengthVal, + pred, threadVal, + buffers.value, outstandingCommits.value}; + createCallToCachedFunction( + b, "check_outstanding_commits_noalias", args, assertInfo, + {buffersType, commitsType, (int)thread}, + buildCheckOutstandingCommitsBody(/*useAlias=*/false)); + } +} + +} // namespace mlir::triton::instrument diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Ops.cpp new file mode 100644 index 0000000000..823cc8649b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Ops.cpp @@ -0,0 +1,8 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Utility.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonInstrument/IR/Ops.cpp.inc" + +#include "triton/Dialect/TritonInstrument/IR/OpsEnums.cpp.inc" diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Utility.cpp b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Utility.cpp new file mode 100644 index 0000000000..9ca13dce6e --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -0,0 +1,581 @@ +#include "triton/Dialect/TritonInstrument/IR/Utility.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Analysis/BufferRegion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" + +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using namespace mlir::triton::nvidia_gpu; +using namespace mlir::triton::instrument; +using mlir::triton::BufferRegion; + +namespace { + +BlockedEncodingAttr getThreadLocalBlockedEncoding(MLIRContext *ctx, + unsigned int size, + unsigned int warps, + unsigned int numCTAs) { + auto cgaLayout = CGAEncodingAttr::get1DLayout(ctx, numCTAs); + return BlockedEncodingAttr::get(ctx, + /*sizePerThread=*/{size}, + /*threadsPerWarp=*/{32}, + /*warpsPerCTA=*/{warps}, + /*order=*/{0}, cgaLayout); +} + +BlockedEncodingAttr getThreadLocalBlockedEncoding(MLIRContext *ctx, + unsigned int buffers, + unsigned int barriers, + unsigned int warps, + unsigned int numCTAs) { + auto kBlocks = StringAttr::get(ctx, "block"); + auto dims = standardOutDimNames(ctx, 2); + auto ll = LinearLayout::identity1D(1, kBlocks, dims[0]) * + LinearLayout::identity1D(numCTAs, kBlocks, dims[1]); + auto cgaLayout = CGAEncodingAttr::get(ctx, std::move(ll)); + return BlockedEncodingAttr::get(ctx, + /*sizePerThread=*/{buffers, barriers}, + /*threadsPerWarp=*/{1, 32}, + /*warpsPerCTA=*/{1, warps}, + /*order=*/{0, 1}, std::move(cgaLayout)); +} + +RankedTensorType getIntTensorType(Region *region, ArrayRef shape, + unsigned bitWidth) { + MLIRContext *ctx = region->getContext(); + unsigned int warps = lookupNumWarps(region); + unsigned int numCTAs = lookupNumCTAs(region->getParentOp()); + BlockedEncodingAttr encoding; + if (shape.size() == 1) { + encoding = getThreadLocalBlockedEncoding( + ctx, static_cast(shape[0]), warps, numCTAs); + } else { + assert(shape.size() == 2 && "Only 1D and 2D shapes are supported"); + encoding = getThreadLocalBlockedEncoding( + ctx, static_cast(shape[0]), static_cast(shape[1]), + warps, numCTAs); + } + Type elType = IntegerType::get(ctx, bitWidth); + return RankedTensorType::get(shape, elType, encoding); +} + +std::pair +createBufferDescriptorsTensor(ImplicitLocOpBuilder &builder, MemType memType, + ArrayRef regions) { + int64_t size = regions.size(); + assert(llvm::isPowerOf2_64(size) && "Expected power of 2"); + auto tensorType = + getIntTensorType(builder.getInsertionBlock()->getParent(), {size}, 64); + SmallVector offsets; + SmallVector lengths; + offsets.reserve(size); + lengths.reserve(size); + for (const auto ®ion : regions) { + offsets.push_back(static_cast(region.baseOffset)); + lengths.push_back(static_cast(region.length)); + } + return {ExperimentalBufferDescriptorsOp::create(builder, tensorType, offsets, + lengths, memType), + tensorType}; +} + +SmallVector> +createAliasingMatrix(ArrayRef regions) { + SmallVector> matrix; + size_t numRegions = regions.size(); + matrix.resize(numRegions); + for (size_t i = 0; i < numRegions; ++i) + matrix[i].assign(numRegions, /*Value=*/0); + + for (size_t i = 0; i < numRegions; ++i) { + uint64_t startI = regions[i].baseOffset; + uint64_t endI = startI + regions[i].length; + if (regions[i].length == 0) + continue; + // Include self-aliasing + for (size_t j = i; j < numRegions; ++j) { + uint64_t startJ = regions[j].baseOffset; + uint64_t endJ = startJ + regions[j].length; + if (regions[j].length == 0) + continue; + bool alias = (startI < endJ) && (startJ < endI); + if (alias) { + matrix[i][j] = 1; + matrix[j][i] = 1; + } + } + } + return matrix; +} + +bool hasCrossBufferAliasing(ArrayRef regions) { + size_t numRegions = regions.size(); + for (size_t i = 0; i < numRegions; ++i) { + if (regions[i].length == 0) + continue; + uint64_t startI = regions[i].baseOffset; + uint64_t endI = startI + regions[i].length; + for (size_t j = i + 1; j < numRegions; ++j) { + if (regions[j].length == 0) + continue; + uint64_t startJ = regions[j].baseOffset; + uint64_t endJ = startJ + regions[j].length; + if ((startI < endJ) && (startJ < endI)) { + return true; + } + } + } + return false; +} + +Value createInitializedScratchMemory(ImplicitLocOpBuilder &b, + TypedValue tensor) { + Type elType = tensor.getType().getElementType(); + int elSize = elType.getIntOrFloatBitWidth() / 8; + int numEls = product(tensor.getType().getShape()); + int64_t sizeInBytes = numEls * elSize; + Type ptrType = triton::getPointerType(elType); + auto alloc = GlobalScratchAllocOp::create(b, ptrType, sizeInBytes, elSize); + createStoreScratchMemory(b, b.getLoc(), alloc, tensor, tensor.getType()); + return alloc; +} + +Value createZeroInitStateTensor(ImplicitLocOpBuilder &b, int m, int n, + int bitWidth) { + SmallVector shape = {m}; + if (n > 0) { + shape.push_back(n); + } + auto type = + getIntTensorType(b.getInsertionBlock()->getParent(), shape, bitWidth); + TypedValue tensor = + createConstIntTensor(b, b.getLoc(), 0, type); + return createInitializedScratchMemory(b, tensor); +} + +TypedValue +createAliasMatrixTensor(ImplicitLocOpBuilder &b, + ArrayRef> matrix, Region *region) { + size_t rows = matrix.size(); + if (rows == 0) + return {}; + size_t cols = matrix.front().size(); + for (const auto &row : matrix) + assert(row.size() == cols && "Expected square alias matrix"); + + auto type = getIntTensorType( + region, {static_cast(rows), static_cast(cols)}, + /*bitWidth=*/1); + SmallVector values; + values.reserve(rows * cols); + for (const auto &row : matrix) + for (uint8_t v : row) + values.emplace_back(/*numBits=*/1, v); + + auto denseAttr = DenseElementsAttr::get(type, values); + Value constValue = arith::ConstantOp::create(b, b.getLoc(), type, denseAttr); + return cast>(constValue); +} + +bool hasCpAsync(ModuleOp module) { + bool hasCpAsync = false; + module.walk([&](Operation *op) { + if (isa(op)) { + hasCpAsync = true; + } + }); + return hasCpAsync; +} + +bool hasWGMMA(ModuleOp module) { + bool hasWGMMA = false; + module.walk([&](Operation *op) { + if (isa(op)) { + hasWGMMA = true; + } + }); + return hasWGMMA; +} + +bool hasTMAStore(ModuleOp module) { + bool hasTMAStore = false; + module.walk([&](Operation *op) { + if (isa(op)) { + hasTMAStore = true; + } + }); + return hasTMAStore; +} + +Value createLockVariable(ImplicitLocOpBuilder &b) { + Type ptrType = triton::getPointerType(b.getI32Type()); + auto alloc = GlobalScratchAllocOp::create(b, ptrType, 4, 4); + Value zero = arith::ConstantOp::create(b, b.getLoc(), b.getI32Type(), + b.getI32IntegerAttr(0)); + triton::AtomicRMWOp::create(b, b.getI32Type(), RMWOp::XCHG, alloc, zero, + nullptr, MemSemantic::ACQUIRE_RELEASE, + MemSyncScope::GPU); + return alloc; +} + +} // namespace + +namespace mlir::triton::instrument { + +TypedValue createConstIntTensor(OpBuilder &builder, + Location loc, int64_t val, + RankedTensorType tensorType, + bool isSigned /*= false*/) { + int bitWidth = tensorType.getElementType().getIntOrFloatBitWidth(); + auto denseAttr = + DenseElementsAttr::get(tensorType, APInt(bitWidth, val, isSigned)); + return cast>( + arith::ConstantOp::create(builder, loc, tensorType, denseAttr) + .getResult()); +} + +DistributedEncodingTrait getSingleDimSliceEncoding(BlockedEncodingAttr encoding, + int dim) { + int rank = encoding.getOrder().size(); + MLIRContext *ctx = encoding.getContext(); + assert(dim < rank && "Expected dim to be less than rank"); + DistributedEncodingTrait sliceEncoding = encoding; + for (int i = 0; i < rank; ++i) { + if (i != dim) { + sliceEncoding = SliceEncodingAttr::get(ctx, i, sliceEncoding); + } + } + return sliceEncoding; +} + +Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor) { + auto type = cast(tensor.getType()); + auto sliceEncoding = dyn_cast(type.getEncoding()); + if (sliceEncoding) { + int dim = sliceEncoding.getDim(); + auto shape = type.getShape(); + auto newShape = SmallVector(shape); + newShape.insert(newShape.begin() + dim, 1); + auto newType = RankedTensorType::get(newShape, type.getElementType(), + sliceEncoding.getParent()); + tensor = ExpandDimsOp::create(b, loc, newType, tensor, dim); + } + return tensor; +} + +static Value expandAllSlicedDims(OpBuilder &b, Location loc, Value tensor) { + auto type = cast(tensor.getType()); + auto sliceEncoding = dyn_cast(type.getEncoding()); + while (sliceEncoding) { + tensor = expandOuterSlicedDim(b, loc, tensor); + type = cast(tensor.getType()); + sliceEncoding = dyn_cast(type.getEncoding()); + } + return tensor; +} + +static Value createPointerTensor(OpBuilder &b, Location loc, Value base, + RankedTensorType tensorType) { + auto encoding = cast(tensorType.getEncoding()); + Value ptrTensor = SplatOp::create( + b, loc, + RankedTensorType::get(tensorType.getShape(), base.getType(), encoding), + base); + auto offsetsType = + RankedTensorType::get(tensorType.getShape(), b.getI32Type(), encoding); + SmallVector strides(tensorType.getRank()); + strides[0] = 1; + for (int i = 1; i < tensorType.getRank(); ++i) { + strides[i] = strides[i - 1] * tensorType.getShape()[i - 1]; + } + for (int i = 0; i < tensorType.getRank(); ++i) { + auto partialEncoding = getSingleDimSliceEncoding(encoding, i); + auto arangeType = RankedTensorType::get({tensorType.getShape()[i]}, + b.getI32Type(), partialEncoding); + auto arange = + MakeRangeOp::create(b, loc, arangeType, 0, arangeType.getShape()[0]); + auto cstStride = createConstIntTensor(b, loc, strides[i], arangeType); + auto arangeTimesStride = + arith::MulIOp::create(b, loc, arangeType, arange, cstStride); + auto expandDims = expandAllSlicedDims(b, loc, arangeTimesStride); + if (cast(expandDims.getType()).getShape() != + tensorType.getShape()) { + expandDims = BroadcastOp::create(b, loc, offsetsType, expandDims); + } + ptrTensor = + AddPtrOp::create(b, loc, ptrTensor.getType(), ptrTensor, expandDims); + } + return ptrTensor; +} + +Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc, + Value tensor, RankedTensorType tensorType) { + auto ptrTensor = createPointerTensor(b, loc, alloc, tensorType); + return StoreOp::create(b, loc, ptrTensor, tensor, CacheModifier::NONE, + EvictionPolicy::NORMAL); +} + +Value createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc, + RankedTensorType tensorType) { + auto ptrTensor = createPointerTensor(b, loc, alloc, tensorType); + return LoadOp::create(b, loc, ptrTensor, CacheModifier::NONE, + EvictionPolicy::NORMAL, false); +} + +FuncOp getEntryPoint(ModuleOp module) { + SmallVector publicFuncs = llvm::to_vector(llvm::make_filter_range( + module.getOps(), [](FuncOp func) { return func.isPublic(); })); + assert(publicFuncs.size() == 1 && "Expected exactly one public function"); + return publicFuncs.front(); +} + +Region *AuxDataMap::RegionToValueMap::getEnclosingParitionOrFunctionRegion( + Operation *op) { + Region *region = op->getParentRegion(); + while (region) { + if (auto wsOp = dyn_cast(region->getParentOp())) { + if (region == &wsOp.getDefaultRegion()) { + return getEnclosingParitionOrFunctionRegion(wsOp); + } + return region; + } + if (auto wsOp = + dyn_cast(region->getParentOp())) { + return region; + } + if (isa(region->getParentOp())) { + ModuleOp module = op->getParentOfType(); + assert(getEntryPoint(module) == region->getParentOp() && + "Concurrency sanitizer supports only one instrumented " + "function in the module"); + return region; + } + region = region->getParentRegion(); + } + llvm_unreachable("Expected to find enclosing partition or function region"); + return nullptr; +} + +void AuxDataMap::populateAndPassToWarpSpecialize(ModuleOp module) { + SmallVector, numMemTypes> bufRegions(numMemTypes); + SmallVector barrierRegions; + getBuffersAndBarriers(module, bufRegions, barrierRegions); + + FuncOp entryPoint = getEntryPoint(module); + assert(entryPoint); + Region *entryRegion = &entryPoint.getBody(); + + ImplicitLocOpBuilder b(entryPoint.getLoc(), entryPoint); + b.setInsertionPointToStart(&entryPoint.getBody().front()); + + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + int iMemType = (int)memType; + if (bufRegions[iMemType].empty()) { + continue; + } + + buffers[iMemType].insert( + entryRegion, + {createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])}); + // Buffer descriptors are rematerialized in the warp specialize region, + // not passed as an argument. + createInWarpSpecialize( + entryPoint, buffers[iMemType], [&](ImplicitLocOpBuilder &b) { + return ValueType{ + createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])}; + }); + int numBufs = bufRegions[iMemType].size(); + + hasNonTrivialAliasing[iMemType] = + hasCrossBufferAliasing(bufRegions[iMemType]); + if (hasNonTrivialAliasing[iMemType]) { + auto aliasMatrixData = createAliasingMatrix(bufRegions[iMemType]); + if (!aliasMatrixData.empty()) { + auto aliasTensor = + createAliasMatrixTensor(b, aliasMatrixData, entryRegion); + aliasMatrices[iMemType].insert(entryRegion, + {aliasTensor, aliasTensor.getType()}); + createInWarpSpecialize( + entryPoint, aliasMatrices[iMemType], + [aliasMatrixData](ImplicitLocOpBuilder &nestedBuilder) { + Region *region = nestedBuilder.getInsertionBlock()->getParent(); + auto tensor = createAliasMatrixTensor(nestedBuilder, + aliasMatrixData, region); + return ValueType{tensor, tensor.getType()}; + }); + } + } + + writeVisibility[iMemType].insert( + entryRegion, {createZeroInitStateTensor(b, numBufs, 0, 64), + getIntTensorType(entryRegion, {numBufs}, 64)}); + passToWarpSpecialize(entryPoint, writeVisibility[iMemType].at(entryRegion), + writeVisibility[iMemType]); + readVisibility[iMemType].insert( + entryRegion, + {createZeroInitStateTensor(b, numBufs, THREADS_BITMASK_SIZE, 64), + getIntTensorType(entryRegion, {numBufs, THREADS_BITMASK_SIZE}, 64)}); + passToWarpSpecialize(entryPoint, readVisibility[iMemType].at(entryRegion), + readVisibility[iMemType]); + } + + if (!barrierRegions.empty()) { + // Barriers allocations are in shared memory + barriers.insert(entryRegion, {createBufferDescriptorsTensor( + b, MemType::SHARED_MEM, barrierRegions)}); + // Barriers allocations are rematerialized in the warp specialize region, + // not passed as an argument. + createInWarpSpecialize(entryPoint, barriers, [&](ImplicitLocOpBuilder &b) { + return ValueType{createBufferDescriptorsTensor(b, MemType::SHARED_MEM, + barrierRegions)}; + }); + + int numBarriers = barrierRegions.size(); + barrierStates.insert(entryRegion, + {createZeroInitStateTensor(b, numBarriers, 0, 32), + getIntTensorType(entryRegion, {numBarriers}, 32)}); + passToWarpSpecialize(entryPoint, barrierStates.at(entryRegion), + barrierStates); + + // Deadlock detection aux data: waiting (i32[K]) storing waiting flag and + // phase bits per thread (two bits per thread). + waiting.insert(entryRegion, + {createZeroInitStateTensor(b, numBarriers, 0, 32), + getIntTensorType(entryRegion, {numBarriers}, 32)}); + passToWarpSpecialize(entryPoint, waiting.at(entryRegion), waiting); + + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + int iMemType = (int)memType; + // Create state tensors: + int numBufs = bufRegions[iMemType].size(); + int numBarriers = barrierRegions.size(); + if (numBufs > 0) { + writeTracking[iMemType].insert( + entryRegion, + {createZeroInitStateTensor(b, numBufs, numBarriers, 8), + getIntTensorType(entryRegion, {numBufs, numBarriers}, 8)}); + passToWarpSpecialize(entryPoint, + writeTracking[iMemType].at(entryRegion), + writeTracking[iMemType]); + readTracking[iMemType].insert( + entryRegion, + {createZeroInitStateTensor(b, numBufs, numBarriers, 64), + getIntTensorType(entryRegion, {numBufs, numBarriers}, 64)}); + passToWarpSpecialize(entryPoint, readTracking[iMemType].at(entryRegion), + readTracking[iMemType]); + } + } + } + + // Create lock variable allocation + Value lockVal = createLockVariable(b); + lock.insert(entryRegion, {lockVal, lockVal.getType()}); + passToWarpSpecialize(entryPoint, lock.at(entryRegion), lock); + + auto createCommitTensor = [&](CommitKind::Kind commitKind) { + int numBufs = bufRegions[(int)MemType::SHARED_MEM].size(); + if (numBufs == 0) + return; + // NUM_THREADS instead of THREADS_BITMASK_SIZE as commit-count tracking + // operates on base threads. + commits[commitKind].insert( + entryRegion, + {createZeroInitStateTensor(b, numBufs, NUM_THREADS, 8), + getIntTensorType(entryRegion, {numBufs, NUM_THREADS}, 8)}); + passToWarpSpecialize(entryPoint, commits[commitKind].at(entryRegion), + commits[commitKind]); + }; + + // Create write commits tensor for cp-async + if (hasCpAsync(module)) { + createCommitTensor(CommitKind::AsyncCp); + } + + // Create reads commits tensor for wgmma + if (hasWGMMA(module)) { + createCommitTensor(CommitKind::Wgmma); + } + + if (hasTMAStore(module)) { + createCommitTensor(CommitKind::TmaStore); + } +} + +void AuxDataMap::getBuffersAndBarriers( + ModuleOp module, SmallVector, 2> &bufRegions, + SmallVector &barrierRegions) { + // Collect shared memory buffers allocated in the module + std::unique_ptr solver = createDataFlowSolver(); + triton::BufferRegionAnalysis *analysis = + solver->load(); + if (failed(solver->initializeAndRun(module))) + return; + + analysis->calculateUsedBufferRegions(module); + bufRegions[(int)MemType::SHARED_MEM] = analysis->getAllUsedBufferRegions( + BufferRegionAnalysis::RegionType::SHARED_MEMORY); + bufRegions[(int)MemType::TENSOR_MEM] = analysis->getAllUsedBufferRegions( + BufferRegionAnalysis::RegionType::TENSOR_MEMORY); + barrierRegions = analysis->getAllUsedBufferRegions( + BufferRegionAnalysis::RegionType::BARRIER); + + if (!barrierRegions.empty()) { + barrierRegions.resize(llvm::NextPowerOf2(barrierRegions.size() - 1), + BufferRegion{0, 0}); + } + + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + int iMemType = (int)memType; + if (bufRegions[iMemType].empty()) { + continue; + } + bufRegions[iMemType].resize( + llvm::NextPowerOf2(bufRegions[iMemType].size() - 1), + BufferRegion{0, 0}); + } +} + +void AuxDataMap::passToWarpSpecialize(FuncOp func, ValueType valueType, + RegionToValueMap &map) { + func.walk([&](WarpSpecializePartitionsOp op) { + op->insertOperands(op.getNumOperands(), {valueType.value}); + for (Region ®ion : op.getPartitionRegions()) { + // Pass the value as a pointer type (instead of the type of underlying + // memory) + region.addArgument(valueType.value.getType(), op.getLoc()); + Type newType = valueType.type; + if (auto tensorType = dyn_cast(newType)) { + // If this is a tensor, make sure the layout matches the region's warp + // count + newType = getIntTensorType( + ®ion, tensorType.getShape(), + tensorType.getElementType().getIntOrFloatBitWidth()); + } + map.insert( + ®ion, + ValueType{region.getArgument(region.getNumArguments() - 1), newType}); + } + }); +} + +void AuxDataMap::createInWarpSpecialize( + FuncOp func, RegionToValueMap &map, + std::function createFn) { + func.walk([&](WarpSpecializeOp op) { + for (Region *region : op.getPartitionRegions()) { + ImplicitLocOpBuilder b(region->getLoc(), region); + b.setInsertionPointToStart(®ion->getBlocks().front()); + map.insert(region, createFn(b)); + } + }); +} + +} // namespace mlir::triton::instrument diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonInstrument/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..62116e5927 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TritonInstrumentTransforms + ConcurrencySanitizer.cpp + + DEPENDS + TritonInstrumentTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonIR + TritonGPUIR + TritonNvidiaGPUIR + TritonToTritonGPU + TritonInstrumentIR + MLIRTransformUtils +) diff --git a/third_party/mthreads/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/third_party/mthreads/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp new file mode 100644 index 0000000000..a29d9cdcf9 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -0,0 +1,581 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/FunctionBuilder.h" +#include "triton/Dialect/TritonInstrument/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +// clang-format off +// Concurrency Sanitizer data structures: +// ConSan keeps auxilary data requied for tracking memory accesses in tensors. +// These tensors are stored as a distributed tensor or in global scratch memory. +// +// Name | Storage | Rank/Type | Description +// ------------------|---------|-----------------|------------ +// buffers | tensor | | Base pointers of all (sub)buffers +// barriers | tensor | | Pointers to all individual mbarriers +// barrierStates | scratch | | Packed barrier phase (bit 0) and arrival counts (bits[1..8] init, [9..16] current) +// waiting | scratch | | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1) +// writeVisibility | scratch | | Per-buffer thread-visibility bitmask (bit i => thread i visible) +// readVisibility | scratch | | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks) +// writeTracking | scratch | | Map buffers -> barriers that track writes +// readTracking | scratch | | Map buffers -> barriers that track reads +// outstandingCommits +// (async/wgmma) | scratch | | Number of outstanding commits per buffer/thread (2D replaces prior 1D) +// clang-format on + +namespace mlir { +namespace triton { +namespace instrument { + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; +namespace tti = mlir::triton::instrument; + +#define GEN_PASS_DEF_TRITONINSTRUMENTCONCURRENCYSANITIZER +#include "triton/Dialect/TritonInstrument/Transforms/Passes.h.inc" + +namespace { + +// OpBuilder listener tracking operations added to the builder to be wrapped +// with a lock acquire/release pair. +class CriticalSectionListener : public ImplicitLocOpBuilder::Listener { +public: + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint /*previous*/) override { + if (firstOp == nullptr) { + firstOp = op; + } + lastOp = op; + } + void maybeWrapWithCriticalSection(ImplicitLocOpBuilder &b, + AuxDataMap &auxData, Value pred) { + Operation *_firstOp = firstOp; + Operation *_lastOp = lastOp; + if (firstOp != nullptr && lastOp != nullptr) { + assert(firstOp->getParentRegion() == lastOp->getParentRegion()); + b.setInsertionPoint(_firstOp); + tti::ExperimentalLockAcquireOp::create(b, auxData.lock.at(_firstOp).value, + pred); + b.setInsertionPointAfter(_lastOp); + tti::ExperimentalLockReleaseOp::create(b, auxData.lock.at(_firstOp).value, + pred); + } + } + +private: + Operation *firstOp = nullptr; + Operation *lastOp = nullptr; +}; + +bool isTMAOp(Operation *op) { + return isa(op); +} + +bool isTensorCoreOp(Operation *op) { + return isa( + op); +} + +std::optional maybeGetPartitionIdx(Operation *op) { + if (auto wsOp = op->getParentOfType()) { + return op->getParentRegion()->getRegionNumber(); + } + if (Operation *parent = op->getParentOp()) { + return maybeGetPartitionIdx(parent); + } + return std::nullopt; +} + +int getCurrentThread(Operation *op) { + // Default partition is 0, other partitions are idx + 1 + int thread = maybeGetPartitionIdx(op).value_or(-1) + 1; + if (isTMAOp(op)) { + thread += TMA_THREAD_OFFSET; + return thread; + } + if (isTensorCoreOp(op)) { + thread += TC_THREAD_OFFSET; + return thread; + } + return thread; +} + +int getBaseThread(int thread) { return thread % NUM_THREADS; } + +// Peer threads are the equivalent threads in the TMA, TC and normal +// thread classes. +// If a thread is a base thread, return the mask with the peers, otherwise +// return the mask with the thread itself. +uint64_t getThreadPeersMask(int thread) { + uint64_t mask = 1ULL << thread; + if (thread < NUM_THREADS) { + mask |= 1ULL << (thread + TMA_THREAD_OFFSET); + mask |= 1ULL << (thread + TC_THREAD_OFFSET); + } + return mask; +} + +int getActiveMask(Operation *op) { + int numParts = 1; + + if (auto wsOp = op->getParentOfType()) { + numParts = wsOp.getPartitionRegions().size() + 1; + } + if (auto wsOp = op->getParentOfType()) { + numParts = wsOp.getPartitionRegions().size() + 1; + } + int activeMask = 0; + for (int i = 0; i < numParts; ++i) + activeMask |= (1 << i); + return activeMask; +} + +uint32_t getMemDescLength(Value buf) { + auto memDescType = cast(buf.getType()); + if (isa(memDescType.getEncoding())) { + unsigned elSize = memDescType.getElementType().getIntOrFloatBitWidth() / 8; + return static_cast(product(memDescType.getShape()) * elSize); + } + if (isa(memDescType.getMemorySpace())) { + return ttng::getTmemAllocSizes(memDescType).numCols; + } + llvm_unreachable("Unsupported memory space for memdesc"); +} + +} // namespace + +class ConcurrencySanitizerPass + : public impl::TritonInstrumentConcurrencySanitizerBase< + ConcurrencySanitizerPass> { +public: + void runOnOperation() override { + module = getOperation(); + + auxData.populateAndPassToWarpSpecialize(module); + + tt::FuncOp entryPoint = tti::getEntryPoint(module); + + ImplicitLocOpBuilder b(entryPoint.getLoc(), entryPoint); + b.setInsertionPointToStart(&entryPoint.getBody().front()); + instrumentMemoryOperations(b); + } + +private: + void instrumentMemoryOperations(ImplicitLocOpBuilder &b) { + tti::FunctionBuilder funcBuilder(module, auxData); + module.walk([&](Operation *op) { + CriticalSectionListener listener; + b.setListener(&listener); + + int thread = getCurrentThread(op); + int baseThread = getBaseThread(thread); + b.setLoc(op->getLoc()); + b.setInsertionPoint(op); + if (isa(op)) { + // Place insert point after specific ops: + // allocs - we want to + // check if it is not overwriting any earlier allocation, but the + // memref value can be referenced only after it is created. + // wait barriers - we can update aux data only after the wait is + // completed + b.setInsertionPointAfter(op); + } + + instrumentMemEffects(b, op, thread, funcBuilder); + b.setLoc(op->getLoc()); + if (auto wsOp = dyn_cast(op)) { + auto partitionRegions = wsOp.getPartitionRegions(); + if (!partitionRegions.empty()) { + uint64_t destMask = 0; + for (size_t idx = 0, e = partitionRegions.size(); idx < e; ++idx) + destMask |= getThreadPeersMask(idx + 1); + if (destMask) { + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + funcBuilder.createCopyWriteVisibilityCall(b, thread, destMask, + nullptr, memType, op); + funcBuilder.createCopyReadVisibilityCall(b, thread, destMask, + nullptr, memType, op); + } + } + } + } + if (auto initOp = dyn_cast(op)) { + funcBuilder.createInitBarrierStateCall(b, initOp.getAlloc(), + initOp.getCount(), initOp); + } + if (auto waitOp = dyn_cast(op)) { + // Pre-wait: mark waiting threads and check for deadlock. + { + CriticalSectionListener preListener; + b.setListener(&preListener); + b.setInsertionPoint(waitOp); + auto pred = waitOp.getPred(); + auto barrier = waitOp.getAlloc(); + funcBuilder.createSetWaitingCall(b, barrier, baseThread, + waitOp.getPhase(), pred, waitOp); + funcBuilder.createCheckAllActiveWaitingCall(b, getActiveMask(op), + pred, waitOp); + + preListener.maybeWrapWithCriticalSection(b, auxData, pred); + b.setListener(&listener); + b.setInsertionPointAfter(waitOp); + } + // Post-wait: transfer visible writes and reads to all peer threads, + // and clear waiting for this barrier + auto _barriers = auxData.barriers.at(op).value; + assert(!auxData.barriers.empty()); + auto pred = waitOp.getPred(); + auto barrier = waitOp.getAlloc(); + + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + // Transfer visible writes and reads to all peer threads + funcBuilder.createTransferVisibleWritesCall( + b, barrier, getThreadPeersMask(thread), pred, memType, op); + funcBuilder.createTransferVisibleReadsCall( + b, barrier, getThreadPeersMask(thread), pred, memType, op); + } + funcBuilder.createClearWaitingCall(b, barrier, baseThread, pred, + waitOp); + } + if (auto asyncCommitGroupOp = dyn_cast(op)) { + if (!auxData.commits[CommitKind::AsyncCp].empty()) + funcBuilder.createCommitAccessesCall(b, thread, nullptr, + CommitKind::AsyncCp, op); + } + if (auto asyncWaitOp = dyn_cast(op)) { + funcBuilder.createClearOutstandingCommitsTransferWritesCall( + b, baseThread, getThreadPeersMask(thread), asyncWaitOp.getNum(), + nullptr, CommitKind::AsyncCp, MemType::SHARED_MEM, op); + } + if (auto wgmmaWaitOp = dyn_cast(op)) { + funcBuilder.createClearOutstandingCommitsTransferReadsCall( + b, baseThread, getThreadPeersMask(thread), + wgmmaWaitOp.getPendings(), nullptr, CommitKind::Wgmma, + MemType::SHARED_MEM, op); + } + if (auto tmaStoreWaitOp = dyn_cast(op)) { + funcBuilder.createClearOutstandingCommitsTransferReadsCall( + b, baseThread, getThreadPeersMask(thread), + tmaStoreWaitOp.getPendings(), nullptr, CommitKind::TmaStore, + MemType::SHARED_MEM, op); + } + listener.maybeWrapWithCriticalSection(b, auxData, nullptr); + b.setListener(nullptr); + }); + } + + struct MemEffectsOpInfo { + struct Effects { + enum RW { Read, Write } rw; + Value buf; + std::string operandName = ""; + uint32_t length = 0; + + Effects(RW rw, Value buf, std::string operandName = "") + : rw(rw), buf(buf), operandName(operandName), + length(getMemDescLength(buf)) {} + }; + struct BarrierInfo { + Value barrier; + Value pred; + int count; + }; + enum class TrackingKind { + None, + Barrier, + wgmmaCommit, + CommitCount + } trackingKind = TrackingKind::None; + + CommitKind::Kind commitKind = CommitKind::None; + + SmallVector barriers; + Value pred; + SmallVector operandEffects; + bool implicitCommit = false; + }; + + void instrumentMemEffects(ImplicitLocOpBuilder &b, Operation *op, int thread, + tti::FunctionBuilder &funcBuilder) { + int baseThread = getBaseThread(thread); + std::optional opInfo = getMemEffectsOpInfo(op); + if (!opInfo) { + return; + } + Value pred = opInfo->pred; + auto combinePredicates = [&](Value barrierPred) -> Value { + if (barrierPred && pred) { + return arith::AndIOp::create(b, b.getLoc(), barrierPred, pred); + } + return barrierPred ? barrierPred : pred; + }; + for (auto effect : opInfo->operandEffects) { + Value buf = effect.buf; + auto bufType = cast(buf.getType()); + MemType memType = MemType::TENSOR_MEM; + if (isa(bufType.getEncoding())) { + memType = MemType::SHARED_MEM; + } + if (effect.rw == MemEffectsOpInfo::Effects::Read) { + // For op that is reading, we only need to check if anything else + // is writing to the same buffer. + addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, + thread, effect.operandName); + if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { + funcBuilder.createSetReadVisibilityCall(b, buf, effect.length, + getThreadPeersMask(thread), + pred, memType, op); + } + if (opInfo->trackingKind == + MemEffectsOpInfo::TrackingKind::CommitCount) { + assert(memType == MemType::SHARED_MEM); + funcBuilder.createStageAccessForCommitCall(b, buf, effect.length, + baseThread, pred, memType, + opInfo->commitKind, op); + } + } + if (effect.rw == MemEffectsOpInfo::Effects::Write) { + // Op is writing to the buffer, we need to check if anything else + // is reading or writing to the same buffer. + addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, + thread, effect.operandName); + addReadChecks(b, funcBuilder, op, buf, effect.length, pred, memType, + thread, effect.operandName); + if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { + funcBuilder.createSetWriteVisibilityCall(b, buf, effect.length, + getThreadPeersMask(thread), + pred, memType, op); + funcBuilder.createClearWriteTrackingCall(b, buf, effect.length, pred, + memType, op); + funcBuilder.createClearReadVisibilityCall(b, buf, effect.length, pred, + memType, op); + funcBuilder.createClearReadTrackingCall(b, buf, effect.length, pred, + memType, op); + } + if (opInfo->trackingKind == + MemEffectsOpInfo::TrackingKind::CommitCount) { + assert(memType == MemType::SHARED_MEM); + funcBuilder.createStageAccessForCommitCall(b, buf, effect.length, + baseThread, pred, memType, + opInfo->commitKind, op); + } + } + } + for (const auto &barrierInfo : opInfo->barriers) { + Value barrier = barrierInfo.barrier; + Value combinedPred = combinePredicates(barrierInfo.pred); + // If the op has barriers, we treat it as a commit emitted for each + // barrier. + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + funcBuilder.createTrackVisibleWritesCall(b, barrier, thread, + combinedPred, memType, op); + funcBuilder.createTrackVisibleReadsCall(b, barrier, thread, + combinedPred, memType, op); + } + if (barrierInfo.count > 0) { + funcBuilder.createVerifyBarrierArriveCall(b, barrier, barrierInfo.count, + combinedPred, op); + funcBuilder.createUpdateBarrierStateCall(b, barrier, barrierInfo.count, + combinedPred, op); + } + } + if (opInfo->implicitCommit) { + assert(opInfo->trackingKind == + MemEffectsOpInfo::TrackingKind::CommitCount); + funcBuilder.createCommitAccessesCall(b, baseThread, pred, + opInfo->commitKind, op); + } + } + + void addWriteChecks(ImplicitLocOpBuilder &b, + tti::FunctionBuilder &funcBuilder, Operation *op, + Value buf, uint32_t length, Value pred, MemType memType, + int thread, const std::string &operandName) { + funcBuilder.createVerifyWriteVisibilityCall(b, buf, length, thread, + operandName, pred, memType, op); + // commit-num-based synchronization is only supported for shared memory + if (memType == MemType::SHARED_MEM) { + funcBuilder.createCheckOutstandingCommitsCall( + b, buf, length, getBaseThread(thread), "async_copy_global_to_shared", + pred, memType, CommitKind::AsyncCp, op); + } + } + + void addReadChecks(ImplicitLocOpBuilder &b, tti::FunctionBuilder &funcBuilder, + Operation *op, Value buf, uint32_t length, Value pred, + MemType memType, int thread, + const std::string &operandName) { + funcBuilder.createVerifyReadVisibilityCall(b, buf, length, thread, + operandName, pred, memType, op); + // commit-num-based synchronization is only supported for shared memory + if (memType == MemType::SHARED_MEM) { + funcBuilder.createCheckOutstandingCommitsCall( + b, buf, length, getBaseThread(thread), "warpgroup_mma operand read", + pred, memType, CommitKind::Wgmma, op); + funcBuilder.createCheckOutstandingCommitsCall( + b, buf, length, getBaseThread(thread), "async_copy_shared_to_global", + pred, memType, CommitKind::TmaStore, op); + } + } + + std::optional getMemEffectsOpInfo(Operation *op) { + std::optional info; + if (auto expectOp = dyn_cast(op)) { + // TODO: For async TMA barriers, the barrier "arrive" corresponding to the + // completion mechanism is modeled by barrier_expect. Individual + // async_tma_copy ops should not decrement the barrier state, otherwise + // multiple copies using the same barrier would incorrectly advance the + // phase multiple times. This should be improved bu tracking the barrier + // expected byte count, and "arriving" the barrier when the expected byte + // count is reached. + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->pred = expectOp.getPred(); + info->barriers.push_back({expectOp.getAlloc(), nullptr, /*count=*/1}); + } + if (auto copyOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->pred = copyOp.getPred(); + // Only track visible accesses against the barrier; do not update the + // barrier state here (see BarrierExpectOp handling above). + info->barriers.push_back({copyOp.getBarrier(), nullptr, /*count=*/0}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + copyOp.getResult()); + } + if (auto storeOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; + info->commitKind = CommitKind::TmaStore; + info->implicitCommit = true; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + storeOp.getSrc()); + } + if (auto gatherOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->pred = gatherOp.getPred(); + // Only track visible accesses against the barrier; do not update the + // barrier state here (see BarrierExpectOp handling above). + info->barriers.push_back({gatherOp.getBarrier(), nullptr, /*count=*/0}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + gatherOp.getResult()); + } + if (auto scatterOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::None; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + scatterOp.getSrc()); + } + if (auto copyOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; + info->commitKind = CommitKind::AsyncCp; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + copyOp.getResult()); + } + if (auto loadOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + loadOp.getSrc()); + } + if (auto storeOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + storeOp.getDst()); + } + if (auto allocOp = dyn_cast(op)) { + if (allocOp.getSrc()) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + allocOp.getResult()); + } + } + if (auto loadOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + loadOp.getSrc()); + } + if (auto storeOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + storeOp.getDst()); + } + if (auto allocOp = dyn_cast(op)) { + if (allocOp.getSrc()) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + allocOp.getResult()); + } + } + if (auto mmav5Op = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->pred = mmav5Op.getPredicate(); + for (auto [barrier, barrierPred] : + llvm::zip(mmav5Op.getCompletionBarriers(), + mmav5Op.getCompletionBarrierPreds())) { + info->barriers.push_back({barrier, barrierPred, 1}); + } + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + mmav5Op.getA(), "A"); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + mmav5Op.getB(), "B"); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + mmav5Op.getAccumulator(), "Acc"); + } + if (auto commitOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->pred = commitOp.getPred(); + info->barriers.push_back({commitOp.getBarrier(), nullptr, 1}); + } + if (auto arriveOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->pred = arriveOp.getPred(); + info->barriers.push_back( + {arriveOp.getAlloc(), nullptr, (int)arriveOp.getCount()}); + } + if (auto wgmmaOp = dyn_cast(op)) { + if (wgmmaOp.getIsAsync() == true) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; + info->commitKind = CommitKind::Wgmma; + info->implicitCommit = true; + info->barriers = {}; + if (isa( + wgmmaOp.getA().getType().getEncoding())) { + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + wgmmaOp.getA(), "A"); + } + if (isa( + wgmmaOp.getB().getType().getEncoding())) { + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + wgmmaOp.getB(), "B"); + } + } + } + return info; + } + + ModuleOp module; + AuxDataMap auxData; +}; + +} // namespace instrument +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..c7a6bfa557 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonNvidiaGPUIR + Dialect.cpp + TensorMemoryUtils.cpp + Ops.cpp + + DEPENDS + TritonNvidiaGPUTableGen + TritonNvidiaGPUAttrDefsIncGen + TritonNvidiaGPUOpInterfacesIncGen + TritonNvidiaGPUTypesIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp new file mode 100644 index 0000000000..7664c043a5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -0,0 +1,573 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton::gpu; +using namespace mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +static constexpr int numTmemRows = 128; + +TMemAllocation getTmemAllocSizes(MemDescType memDescType) { + auto *ctx = memDescType.getContext(); + auto S = [&](StringRef str) { return StringAttr::get(ctx, str); }; + auto kRow = S("row"); + auto kCol = S("col"); + // Remove multibuffering if present + auto shape = memDescType.getShape().take_back(2); + auto ll = toLinearLayout(shape, memDescType.getEncoding()); + auto bitwidth = memDescType.getElementTypeBitWidth(); + int nRow = ll.getInDimSize(kRow); + int nCol = ll.getInDimSize(kCol) / (32 / bitwidth); + // If we have just one 16xcol block per warp, we don't allocate 128 rows + // we use 64 rows instead. + // We could generalise this to when we have more zeros in the layout, but + // the allocator does not support this yet + if (ll.getBasis(kRow, llvm::Log2_32(16)) == ArrayRef{0, 0}) { + nRow /= 2; + } + + // Hack: We should represent this in the LL. Remove the block dimension + if (auto tmemEnc = + dyn_cast(memDescType.getEncoding())) { + nCol /= tmemEnc.getCTASplitM() * tmemEnc.getCTASplitN(); + } else if (auto tmemScaleEnc = dyn_cast( + memDescType.getEncoding())) { + nCol /= tmemScaleEnc.getCTASplitM() * tmemScaleEnc.getCTASplitN(); + } + // If multibuffering is present, we need to allocate more cols + if (memDescType.getRank() > 2) { + assert(memDescType.getRank() == 3); + nCol *= memDescType.getDimSize(0); + } + return {nRow, nCol}; +} + +LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, bool unpacked, + bool withWarp) { + auto str_attr = [&](StringRef str) { return StringAttr::get(ctx, str); }; + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kRow = str_attr("row"); + auto kCol = str_attr("col"); + // Set the output order to be kRow, kCol and the input order to be kReg first + LinearLayout tile = LinearLayout({{kReg, {}}, {kLane, {}}}, {kRow, kCol}); + // Each register moves 32/bitwidth (= 2) columns when unpacked + if (unpacked) { + tile *= LinearLayout::zeros1D(1, kReg, kCol, 2); + } + if (atom == TMemAccessAtom::I32x32b) { + tile *= LinearLayout::identity1D(32, kLane, kRow); + } else if (atom == TMemAccessAtom::I16x32bx2) { + tile *= LinearLayout::identity1D(16, kLane, kRow); + } else if (atom == TMemAccessAtom::I16x64b) { + LinearLayout::BasesT bases; + bases[kLane] = std::vector>{ + {8, 0}, {0, 1}, {1, 0}, {2, 0}, {4, 0}}; + tile *= LinearLayout(std::move(bases), {kRow, kCol}); + } else if (atom == TMemAccessAtom::I16x128b) { + tile *= LinearLayout::identity1D(4, kLane, kCol) * + LinearLayout::identity1D(8, kLane, kRow) * + LinearLayout::identity1D(2, kReg, kRow); + } else if (atom == TMemAccessAtom::I16x256b) { + tile *= LinearLayout::identity1D(2, kReg, kCol) * + LinearLayout::identity1D(4, kLane, kCol) * + LinearLayout::identity1D(8, kLane, kRow) * + LinearLayout::identity1D(2, kReg, kRow); + } else { + llvm_unreachable("Unsupported TMEM access atom"); + } + if (withWarp) { + auto nCol = tile.getOutDimSize(kCol); + auto bases = tile.getBases(); + bases[kWarp].push_back({32, 0}); + bases[kWarp].push_back({64, 0}); + tile = LinearLayout(std::move(bases), {{kRow, 128}, {kCol, nCol}}, false); + } + return tile; +} + +static std::optional getDistributedLayoutForTmemLdSt( + const LinearLayout &ll, TMemAccessAtom atom, unsigned numWarps, + int bitwidth, + std::optional cgaLayout = std::nullopt) { + auto dims = to_vector(ll.getOutDimNames()); + assert(dims.size() == 2); + auto rowColDims = to_vector(ll.getInDimNames()); + auto *ctx = dims[0].getContext(); + // Add block dimension + if (cgaLayout) { + // Get CGALayout without broadcasting to divide the ll + // as the TMEM layout does not reflect CTA broadcasting + auto cgaShape = to_vector(cgaLayout->getLinearLayout().getOutDimSizes()); + auto kBlock = StringAttr::get(ctx, "block"); + // The cta order in TMEM is always [0, 1] + auto ctaCol = + LinearLayout::identity1D(cgaShape[0], rowColDims[1], dims[0]) * + LinearLayout::identity1D(cgaShape[1], rowColDims[1], dims[1]); + auto quot = divideRight(ll, ctaCol); + bool isM64TwoCTA = !quot.has_value(); + if (isM64TwoCTA) { + auto bases = ll.getBases(); + auto logNCols = ll.getInDimSizeLog2(rowColDims[1]); + auto numCTAs = ctaCol.getTotalOutDimSize(); + auto basisCTA1 = logNCols - 1 - llvm::Log2_32(numCTAs * numWarps / 4); + // Swap the (soon to be) warp=2 and block=1 bases + std::swap(bases[rowColDims[0]].back(), bases[rowColDims[1]][basisCTA1]); + auto transposedLL = + LinearLayout(std::move(bases), ll.getOutDims(), ll.isSurjective()); + auto ctaCol = + LinearLayout::identity1D(cgaShape[0] / 2, rowColDims[1], dims[0]) * + LinearLayout::identity1D(cgaShape[1], rowColDims[1], dims[1]); + quot = divideRight(transposedLL, ctaCol); + assert(quot.has_value()); + } + auto maybeRet = + getDistributedLayoutForTmemLdSt(*quot, atom, numWarps, bitwidth); + if (!maybeRet) + return maybeRet; + // Add the full block layout (with broadcasting) + if (isM64TwoCTA) { + auto bases = maybeRet->getBases(); + // Last reg has block[0] basis + // This is correct as we don't currently support emitting + // more than 1 tcgen05.mma instruction per N dimension + auto kReg = StringAttr::get(ctx, "register"); + bases[kBlock].push_back(bases[kReg].back()); + bases[kReg].pop_back(); + auto kWarp = StringAttr::get(ctx, "warp"); + std::swap(bases[kWarp][1], bases[kBlock][0]); + auto ret = LinearLayout(std::move(bases), maybeRet->getOutDims(), + maybeRet->isSurjective()); + // Remove first block basis as it's already in the layout + auto cta1 = LinearLayout::identity1D(2, kBlock, dims[0]); + auto smallCgaLayout = divideLeft(cgaLayout->getLinearLayout(), cta1); + assert(smallCgaLayout.has_value()); + ret *= smallCgaLayout.value(); + return ret; + } else { + return *maybeRet * cgaLayout->getLinearLayout(); + } + } + // This code is dual to the one in lowerTMemLdSt + if (bitwidth != 32) { + // TODO move this to a helper function + auto kReg = StringAttr::get(ctx, "register"); + LinearLayout quot; + int bestContig = 1; + for (int contig = 1; bitwidth * contig <= 32; contig *= 2) { + auto maybeQuot = divideLeft( + ll, LinearLayout::identity1D(contig, rowColDims[1], dims[1])); + if (!maybeQuot) + break; + quot = *maybeQuot; + bestContig = contig; + } + + // Pack contiguous elements + // This works to pack b8 or b16 into b32 but also b8 into b16 and recurse + if (bestContig > 1) { + auto ret = getDistributedLayoutForTmemLdSt(quot, atom, numWarps, + bitwidth * bestContig); + if (!ret) + return ret; + auto castbbitwidth = LinearLayout::identity1D(bestContig, kReg, dims[1]); + return castbbitwidth * ret.value(); + } + if (auto maybeQuot = divideLeft( + ll, LinearLayout::zeros1D(32 / bitwidth, rowColDims[1], dims[1]) * + LinearLayout::identity1D(2, rowColDims[1], dims[1])); + bitwidth == 16 && maybeQuot) { + // Unpacked case + auto ret = + getDistributedLayoutForTmemLdSt(*maybeQuot, atom, numWarps, 32); + if (!ret) + return ret; + auto castbbitwidth = LinearLayout::identity1D(2, kReg, dims[1]); + return castbbitwidth * ret.value(); + } else if (auto maybeQuot = + divideLeft(ll, LinearLayout::zeros1D( + 32 / bitwidth, rowColDims[1], dims[1]))) { + // Software padding + assert(maybeQuot); + return getDistributedLayoutForTmemLdSt(*maybeQuot, atom, numWarps, 32); + } else if (ll.getInDimSize(rowColDims[1]) == 1) { + // Software padding with just one column + return getDistributedLayoutForTmemLdSt(ll, atom, numWarps, 32); + } else { + assert(false && "Should not happen"); + } + } + // getTileLayout returns the layout for a bitwidth of 32 + assert(bitwidth == 32); + auto tile = getTileLayout(ctx, atom, false, /*withWarp=*/false); + // Plan: + // tile: register, lane -> row, cols + // ll: row, cols -> dim0, dim1 + // We extend the tile to have the right vectorisation + warps and + // the result is given by + // ll o tile : register, lane, warp -> dim0, dim1 + + auto nColsTile = tile.getOutDimSize(rowColDims[1]); + auto nColsLL = ll.getInDimSize(rowColDims[1]); + auto nColsMissing = nColsLL / nColsTile; + if (nColsMissing == 0) { + return std::nullopt; + } + auto kReg = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + auto kWarp = StringAttr::get(ctx, "warp"); + bool instr32Rows = atom == TMemAccessAtom::I32x32b; + bool layout16Rows = + ll.getBasis(rowColDims[0], llvm::Log2_32(16)) == ArrayRef{0, 0}; + + // We are choosing the distributed layout (ll o tile). In the lowering + // we will do ll^{-1} o (ll o tile) and we expect to get tile back. + // For this to be possible, ll should accept a left-inverse, that is, it + // should be injective + // In less fancy words, we look for the `comp` layout not to have any zero + // basis as that would disallow the resulting layout to be left-divisible by + // the tile + auto comp = + tile.compose(ll).sublayout({kReg, kLane}, to_vector(ll.getOutDimNames())); + if (instr32Rows) { + // We will use 16x32bx2 instruction for lane=16 so we remove the last lane + // basis + comp = comp.resizeInDim(kLane, comp.getInDimSize(kLane) / 2); + } + if (!comp.isInjective()) + return std::nullopt; + + // Fit the warp bases either tiling on the RHS or in row=16 + StringAttr row16; + // If we need to fit something (the instruction does not cover it + // and the layout has 32 rows) we first try to fit a warp, and if we + // can't we fit a register + if (!instr32Rows && !layout16Rows) { + if (numWarps > 4) { + row16 = kWarp; + } else { + row16 = kReg; + } + } + + // We reserve enough columns to fit in the warps + int warpsToTile = numWarps / ((row16 == kWarp) ? 8 : 4); + // Cap warps to tile above by nColsMissing. The rest go to broadcasting + int warpBroadcast = warpsToTile / std::min(nColsMissing, warpsToTile); + warpsToTile /= warpBroadcast; + nColsMissing /= warpsToTile; + + if (nColsMissing > 1) { + if (instr32Rows && layout16Rows) { + // If the lane 16 would load repeated data, instead we make it load half + // of the data via the 16x32bx2 instruction + tile = divideLeft(tile, LinearLayout::identity1D(2, kLane, rowColDims[0])) + .value(); + tile *= LinearLayout::identity1D(nColsMissing / 2, kReg, rowColDims[1]) * + LinearLayout::identity1D(2, kLane, rowColDims[1]); + + } else { + tile *= LinearLayout::identity1D(nColsMissing, kReg, rowColDims[1]); + } + } + + // add the warp bases. The M=64 + 2CTA case has already been handled + auto bases = tile.getBases(); + auto &warpBases = bases[kWarp]; + warpBases.push_back({32, 0}); + warpBases.push_back({64, 0}); + + if (row16) { + bases[row16].push_back({16, 0}); + } + tile = LinearLayout(std::move(bases), + {{rowColDims[0], 128}, + {rowColDims[1], tile.getOutDimSize(rowColDims[1])}}, + false); + tile *= LinearLayout::identity1D(warpsToTile, kWarp, rowColDims[1]); + tile *= LinearLayout::zeros1D(warpBroadcast, kWarp, rowColDims[1]); + assert(tile.getOutDimSize(rowColDims[1]) == ll.getInDimSize(rowColDims[1])); + + auto ret = tile.compose(ll); + return ret; +} + +std::optional +getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, + unsigned numWarps, + gpu::CGAEncodingAttr cgaLayout) { + assert(memType.getMemorySpace() == + TensorMemorySpaceAttr::get(memType.getContext())); + assert(numWarps >= 4 && llvm::isPowerOf2_32(numWarps) && + "numWarps must be a power of 2 and >= 4"); + assert(atom != TMemAccessAtom::I16x32bx2 && + "This layout is inferred sometimes for the 32x32b atom"); + auto ll = toLinearLayout(memType.getShape(), memType.getEncoding()); + auto bitwidth = memType.getElementTypeBitWidth(); + return getDistributedLayoutForTmemLdSt(ll, atom, numWarps, bitwidth, + cgaLayout); +} + +DistributedEncodingTrait +getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps, + gpu::CGAEncodingAttr cgaLayout) { + auto *ctx = memType.getContext(); + bool prefer16x256 = + triton::tools::getBoolEnv("TRITON_PREFER_TMEM_16x256_LAYOUT"); + if (prefer16x256) { + auto layout = getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I16x256b, numWarps, cgaLayout); + if (layout) { + return LinearEncodingAttr::get(ctx, std::move(*layout)); + } + } + auto layout = getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I32x32b, numWarps, cgaLayout); + assert(layout); + return LinearEncodingAttr::get(ctx, std::move(*layout)); +} + +std::optional +getTmemLoadLayoutSplitLongM(RankedTensorType tensorType, MemDescType memType, + int numWarps) { + if (numWarps != 8) + return std::nullopt; + + auto cgaLayout = getCGALayout(tensorType.getEncoding()); + std::optional layout = getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I32x32b, numWarps, cgaLayout); + if (!layout) + return std::nullopt; + auto ret = std::move(*layout); + + // Optimisation for reductions: + // We can map lane=16 to any dimension, and it will be lowered to 32x16bx2. + // As such, if we have 8 warps and the basis warp=4 is mapped to a different + // dimension than warp=1, warp=2, and lane=16 is mapped to the same dimension + // as the first two warp bases, we can swap warp=4 and lane=16. + // Generally, we don't want warp=4 to have data on a different dimension to + // dim=1 and dim=2 + auto *ctx = tensorType.getContext(); + auto kLane = StringAttr::get(ctx, "lane"); + auto kWarp = StringAttr::get(ctx, "warp"); + auto dims = to_vector(ret.getOutDimNames()); + + // In most cases this is going to be dim=0, but the optimization + // also applies for scales where we may be able to have the layout + // replicated across warps + for (int dim : {0, 1}) { + auto w1dim = ret.getBasis(kWarp, 0, dims[dim]) == 0; + auto w2dim = ret.getBasis(kWarp, 1, dims[dim]) == 0; + auto w4dim = ret.getBasis(kWarp, 2, dims[dim]) == 0; + auto l16dim = ret.getBasis(kLane, 4, dims[dim]) == 0; + if (l16dim != w4dim && w1dim == w2dim && w1dim == l16dim) { + auto bases = ret.getBases(); + std::swap(bases[kWarp][2], bases[kLane][4]); + return LinearEncodingAttr::get( + tensorType.getContext(), + LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective())); + } + } + return std::nullopt; +} + +SmallVector +getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, + MemDescType memType) { + int numWarps = lookupNumWarps(op); + assert(numWarps % 4 == 0); + auto cgaLayout = getCGALayout(tensorType.getEncoding()); + SmallVector layouts; + for (auto atom : {TMemAccessAtom::I32x32b, TMemAccessAtom::I16x256b, + TMemAccessAtom::I16x128b, TMemAccessAtom::I16x64b}) { + auto ll = + getDistributedLayoutForTmemLdSt(memType, atom, numWarps, cgaLayout); + if (ll) { + layouts.push_back(LinearEncodingAttr::get(tensorType.getContext(), + std::move(ll.value()))); + } + } + // Small hack until we generalise isDistributedLayoutTMemCompatible + auto ll = getTmemLoadLayoutSplitLongM(tensorType, memType, numWarps); + if (ll) { + layouts.push_back(ll.value()); + } + return layouts; +} + +// Verify if the distributed layout can be mapped onto tensor memory. +bool isDistributedLayoutTMemCompatible(Operation *op, + RankedTensorType tensorType, + gpu::MemDescType memType) { + auto maxnreg = getContextualMaxNReg(op); + return succeeded(computeTMemLdStEncodingInfo(tensorType, memType, maxnreg)); +} + +LogicalResult +TensorMemoryEncodingAttr::verify(function_ref emitError, + unsigned blockM, unsigned blockN, + unsigned colStride, unsigned CTASplitM, + unsigned CTASplitN, bool) { + if (!(CTASplitM >= 1 && CTASplitN >= 1 && llvm::isPowerOf2_32(CTASplitM) && + llvm::isPowerOf2_32(CTASplitN))) { + return emitError() + << "CTASplitM and CTASplitN must be greater than 0 and a power of 2"; + } + if (blockM != 64 && blockM != 128) { + return emitError() << "blockM must be 64 or 128 but got " << blockM; + } + if (!llvm::isPowerOf2_32(blockN)) { + return emitError() << "blockN must be a power of 2 but got " << blockN; + } + if (blockN > 512) { + return emitError() << "blockN must be less than or equal to 512 but got " + << blockN; + } + if (!(colStride == 1 || colStride == 2 || colStride == 4)) { + return emitError() << "colStride must be 1, 2, or 4 but got " + << "but got " << colStride; + } + return success(); +} + +LogicalResult impl::verifyMMAv5Op(Operation *op) { + auto isInterleaved = [](MemDescType memdesc) { + auto enc = dyn_cast(memdesc.getEncoding()); + return enc && getTmemAllocSizes(memdesc).numRows != 64 && + enc.getBlockM() == 64; + }; + + auto itf = cast(op); + if (isInterleaved(itf.getA().getType()) && + isInterleaved(itf.getAccumulator().getType())) { + return op->emitOpError( + "does not support blockM=64 with interleaved blocks in TMEM layout"); + } + return success(); +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + +//===----------------------------------------------------------------------===// +// Type methods +//===----------------------------------------------------------------------===// +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// TensorDescIm2ColType Verifier +//===----------------------------------------------------------------------===// +LogicalResult +TensorDescIm2ColType::verify(function_ref emitError, + RankedTensorType blockType) { + // blockType must be rank 2 for im2col mode + if (blockType.getRank() != 2) { + return emitError() + << "TensorDescIm2ColType requires rank-2 blockType, got rank " + << blockType.getRank(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// +namespace { +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "tmem"; + return AliasResult::FinalAlias; + } + if (mlir::isa(attr)) { + os << "tmem_scales"; + return AliasResult::FinalAlias; + } + return OpAsmDialectInterface::getAlias(attr, os); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// + +void TritonNvidiaGPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); +} + +// verify TritonNvidiaGPU ops +LogicalResult +TritonNvidiaGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp new file mode 100644 index 0000000000..1e7fcea3bc --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -0,0 +1,1160 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +// -- WarpGroupDotOp -- +LogicalResult WarpGroupDotOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult WarpGroupDotOp::verify() { + auto resTy = getD().getType(); + auto nvmmaEnc = dyn_cast(resTy.getEncoding()); + if (!nvmmaEnc || !nvmmaEnc.isHopper()) + return emitOpError("WGMMA result layout must be Hopper NVMMA"); + + if (!isa(getA().getType().getEncoding())) + return emitOpError("WGMMA A operand must have NVMMA shared or dot layout"); + if (!isa( + getB().getType().getEncoding())) + return emitOpError("WGMMA B operand must have NVMMA shared layout"); + + auto numWarps = gpu::lookupNumWarps(getOperation()); + if (numWarps % 4) + return emitOpError("WGMMA requires num_warps to be divisible by 4"); + + auto retShapePerCTA = getShapePerCTA(resTy); + int rank = retShapePerCTA.size(); + if (rank != 2) + return emitOpError("WGMMA result shape must be 2D"); + if (retShapePerCTA[0] % 64 != 0) + return emitOpError("WGMMA result M dimension must be divisible by 64"); + if (retShapePerCTA[1] % 8 != 0) + return emitOpError("WGMMA result N dimension must be divisible by 8"); + + // Verify MMA version is supported for operands. + int mmaVersion = nvmmaEnc.getVersionMajor(); + if (!supportMMA(getA(), mmaVersion) || !supportMMA(getB(), mmaVersion)) + return emitOpError("unsupported MMA version for the given operands"); + + auto aElemTy = getA().getType().getElementType(); + if (getMaxNumImpreciseAcc() < 32 && + (llvm::isa(aElemTy)) && + resTy.getElementType().isF32()) { + return emitOpError("Cannot use F32 as the accumulator element type when " + "the max_num_imprecise_acc is less than 32"); + } + + if (auto aTensorTy = dyn_cast(getA().getType())) { + auto aDotOpEnc = cast(aTensorTy.getEncoding()); + unsigned kWidth = 32 / aTensorTy.getElementTypeBitWidth(); + if (aDotOpEnc.getKWidth() != kWidth) { + return emitOpError("in-register LHS operand must have a kWidth of ") + << kWidth << " but got " << aDotOpEnc.getKWidth(); + } + } + + return success(); +} + +void WarpGroupDotOp::getEffects( + SmallVectorImpl> + &effects) { + auto &a = getAMutable(); + auto &b = getBMutable(); + if (isa(a.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &a, SharedMemory::get()); + if (isa(b.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &b, SharedMemory::get()); +} + +bool WarpGroupDotOp::needsPartialAccumulator() { + const auto &a = getA(); + const auto &d = getD(); + auto aTensorTy = cast(a.getType()); + auto aElTy = cast(a.getType()).getElementType(); + bool isFP8 = llvm::isa(aElTy); + bool accFP32 = + cast(d.getType()).getElementType().isF32(); + uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); + return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1]; +} + +bool WarpGroupDotOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +// -- WarpGroupDotWaitOp -- +LogicalResult WarpGroupDotWaitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (Value operand : operands) + inferredReturnTypes.push_back(operand.getType()); + return success(); +} + +LogicalResult WarpGroupDotWaitOp::verify() { + if (getOperands().empty()) + return emitOpError("expected to be waiting on at least one dependency"); + return success(); +} + +// -- InitBarrierOp -- +LogicalResult InitBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +// -- InvalBarrierOp -- +LogicalResult InvalBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +// -- BarrierExpectOp -- +LogicalResult BarrierExpectOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +// -- WaitBarrierOp -- +LogicalResult WaitBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +// -- ArriveBarrierOp -- +LogicalResult ArriveBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + if (getCount() < 1) + return emitOpError("count must be greater than or equal to 1"); + return success(); +} + +// -- FenceMBarrierInitReleaseClusterOp -- +LogicalResult FenceMBarrierInitReleaseClusterOp::verify() { + int numCTAs = triton::gpu::lookupNumCTAs(getOperation()); + if (numCTAs <= 1) + return emitOpError("requires ttg.num-ctas > 1"); + return success(); +} + +// -- ClusterArriveOp -- +LogicalResult ClusterArriveOp::verify() { + int numCTAs = triton::gpu::lookupNumCTAs(getOperation()); + if (numCTAs <= 1) + return emitOpError("requires ttg.num-ctas > 1"); + return success(); +} + +// -- ClusterWaitOp -- +LogicalResult ClusterWaitOp::verify() { + int numCTAs = triton::gpu::lookupNumCTAs(getOperation()); + if (numCTAs <= 1) + return emitOpError("requires ttg.num-ctas > 1"); + return success(); +} + +// -- TMA operation verifiers -- +static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc, + Attribute enc) { + auto nvmma = dyn_cast(enc); + if (!nvmma) + return op->emitOpError("TMA descriptor must have NVMMA shared layout"); + auto descEnc = dyn_cast_if_present( + desc.getBlockType().getEncoding()); + // NOTE: Cannot do descEnc != enc as the encodings may differ in rank for + // rank-reducing loads + if (!descEnc || descEnc.getTransposed() != nvmma.getTransposed() || + descEnc.getSwizzlingByteWidth() != nvmma.getSwizzlingByteWidth() || + descEnc.getElementBitWidth() != nvmma.getElementBitWidth() || + descEnc.getFp4Padded() != nvmma.getFp4Padded()) { + return op->emitOpError("TMA descriptor layout must match shared layout, " + "but got descriptor layout ") + << descEnc << " and shared memory layout " << nvmma; + } + if (nvmma.getTransposed()) + return op->emitOpError("TMA descriptor layout must not be transposed"); + return success(); +} + +static LogicalResult verifyAsyncTMALoadOp(Operation *op, + TensorDescInterface desc, + TypedValue barrier, + MemDescType resultType) { + if (failed(verifyBarrierType(op, barrier.getType()))) + return failure(); + if (!resultType.getMutableMemory()) + return op->emitOpError("cannot store into immutable memory"); + if (failed(verifyTMAEncoding(op, desc, resultType.getEncoding()))) + return failure(); + return success(); +} + +static LogicalResult verifyAsyncTMAStoreOp(Operation *op, + TypedValue desc, + MemDescType srcType) { + Attribute srcEnc = srcType.getEncoding(); + // `cp.async.bulk.tensor` to global memory and `cp.reduce.async.bulk.tensor` + // do not support fp4_padded operands. + if (isFp4Padded(srcEnc)) + return op->emitOpError("does not support fp4_padded operands"); + return verifyTMAEncoding(op, desc.getType(), srcEnc); +} + +// Helper to determine if the descriptor type is for im2col mode +static bool isIm2ColDescriptor(Type descType) { + return isa(descType); +} + +static LogicalResult verifyAsyncTMACoords(Operation *op, ValueRange coords, + TensorDescInterface desc, + bool isIm2Col) { + unsigned blockRank = desc.getBlockType().getRank(); + + if (isIm2Col) { + // For IM2COL mode, coordinates are for the full tensor (3D-5D) + // not the 2D block shape + if (coords.size() < 3) + return op->emitOpError( + "IM2COL mode requires at least 3D coordinates, but got ") + << coords.size() << "D"; + if (coords.size() > 5) + return op->emitOpError( + "IM2COL mode supports at most 5D coordinates, but got ") + << coords.size() << "D"; + } else { + // For TILED mode, coordinates must match the block rank + if (coords.size() != blockRank) { + return op->emitOpError("expected ") + << blockRank << " coordinates, but got " << coords.size(); + } + if (coords.size() < 1 || coords.size() > 5) + return op->emitOpError("must have between 1 and 5 coordinates"); + } + return success(); +} + +static LogicalResult verifyTMAMode(Operation *op, bool isIm2Col, + ValueRange coords, ValueRange offsets) { + if (isIm2Col) { + if (offsets.empty()) + return op->emitOpError("IM2COL mode requires offsets to be provided"); + + // For IM2COL mode, the number of offsets should be coord.size() - 2 + // 4D tensors (4 coords) need 2 offsets, 5D tensors (5 coords) need 3 + // offsets + size_t expectedOffsets = coords.size() - 2; + if (offsets.size() != expectedOffsets) { + return op->emitOpError("IM2COL mode with ") + << coords.size() << "D coordinates requires " << expectedOffsets + << " offsets, but got " << offsets.size(); + } + } else { + // TILED mode should not have offsets + if (!offsets.empty()) + return op->emitOpError("TILED mode does not support offsets"); + } + return success(); +} + +// -- AsyncTMACopyGlobalToLocalOp -- +LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { + auto descType = getDesc().getType(); + bool isIm2Col = isIm2ColDescriptor(descType); + auto descInterface = cast(descType); + + if (failed(verifyAsyncTMACoords(*this, getCoord(), descInterface, isIm2Col))) + return failure(); + auto resultType = getResult().getType(); + if (failed(verifyDescriptorLoadStoreOp(*this, descType, resultType))) + return failure(); + if (failed(verifyAsyncTMALoadOp(*this, descInterface, getBarrier(), + getResult().getType()))) + return failure(); + if (failed(verifyTMAMode(*this, isIm2Col, getCoord(), getOffsets()))) + return failure(); + return success(); +} + +// -- AsyncTMACopyLocalToGlobalOp -- +LogicalResult AsyncTMACopyLocalToGlobalOp::verify() { + // Store ops only support TILED mode + if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc().getType(), + /*isIm2Col=*/false))) + return failure(); + MemDescType srcType = getSrc().getType(); + if (failed(verifyDescriptorLoadStoreOp(*this, getDesc().getType(), srcType))) + return failure(); + return verifyAsyncTMAStoreOp(*this, getDesc(), srcType); +} + +// -- AsyncTMAReduceOp -- +LogicalResult AsyncTMAReduceOp::verify() { + // Reduce ops only support TILED mode + if (failed(verifyAsyncTMACoords(*this, getCoord(), getDesc().getType(), + /*isIm2Col=*/false))) + return failure(); + MemDescType srcType = getSrc().getType(); + if (failed(verifyDescriptorLoadStoreOp(*this, getDesc().getType(), srcType))) + return failure(); + return verifyAsyncTMAStoreOp(*this, getDesc(), srcType); +} + +// -- AsyncTMAGatherOp -- +LogicalResult AsyncTMAGatherOp::verify() { + auto resultType = getResult().getType(); + if (failed(verifyAsyncTMALoadOp(*this, getDesc().getType(), getBarrier(), + resultType))) + return failure(); + // `tile::gather4` does not support fp4_padded operands. + if (isFp4Padded(getResult().getType().getEncoding())) + return emitOpError("does not support fp4_padded operands"); + return verifyGatherScatterOp(*this, + getDesc().getType().getSignlessBlockType(), + resultType, getXOffsets().getType()); +} + +// -- AsyncTMAScatter -- +LogicalResult AsyncTMAScatterOp::verify() { + auto srcType = getSrc().getType(); + if (failed(verifyAsyncTMAStoreOp(*this, getDesc(), srcType))) + return failure(); + return verifyGatherScatterOp(*this, + getDesc().getType().getSignlessBlockType(), + srcType, getXOffsets().getType()); +} + +// -- TCGen5MMAOp -- + +// barrier-and-pred := `,` ssa-value `[` ssa-value `]` +// barriers-and-preds := (barrier-and-pred)* +static ParseResult +parseBarriersAndPreds(OpAsmParser &p, + SmallVectorImpl &barriers, + SmallVectorImpl &preds) { + while (succeeded(p.parseOptionalComma())) { + if (p.parseOperand(barriers.emplace_back()) || p.parseLSquare() || + p.parseOperand(preds.emplace_back()) || p.parseRSquare()) + return failure(); + } + return success(); +} +static void printBarriersAndPreds(OpAsmPrinter &p, Operation *op, + OperandRange barriers, OperandRange preds) { + assert(barriers.size() == preds.size()); + for (auto [barrier, pred] : llvm::zip(barriers, preds)) { + p << ", " << barrier << '[' << pred << ']'; + } +} + +// token := `[` (ssa-value (`,` ssa-value)*)? `]` +// dep-operand := token? +static ParseResult +parseToken(OpAsmParser &p, std::optional &dep, + Type &token) { + if (failed(p.parseOptionalLSquare())) + return success(); + token = p.getBuilder().getType(); + if (succeeded(p.parseOptionalRSquare())) + return success(); + if (p.parseOperand(dep.emplace()) || p.parseRSquare()) + return failure(); + return success(); +} +static void printToken(OpAsmPrinter &p, Operation *op, Value dep, Type token) { + if (!token) + return; + p << '['; + if (dep) + p << dep; + p << ']'; +} + +namespace { +enum class MMADTypeKind { tf32, f16, f8f6f4, i8 }; +} // namespace + +static std::string strMMADTypeKind(MMADTypeKind kind) { + switch (kind) { + case MMADTypeKind::tf32: + return "tf32"; + case MMADTypeKind::f16: + return "f16"; + case MMADTypeKind::f8f6f4: + return "f8f6f4"; + case MMADTypeKind::i8: + return "i8"; + } + llvm_unreachable("unknown mma dtype kind"); +} + +static std::optional>> +getMMAv5DTypeKindAndAcc(Type t) { + MLIRContext *ctx = t.getContext(); + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-kind-shapes + if (t.isF32()) { + return {{MMADTypeKind::tf32, {Float32Type::get(ctx)}}}; + } + if (t.isF16()) { + return { + {MMADTypeKind::f16, {Float16Type::get(ctx), Float32Type::get(ctx)}}}; + } + if (t.isBF16()) { + return {{MMADTypeKind::f16, {Float32Type::get(ctx)}}}; + } + // TODO: float6 and explicit float4 types are not supported yet. + // TODO: tcgen05.mma supports ui8/si8 -> s32 MMA, but Triton does not. + // FIXME: i8 is used to represent float4 types. + if (isa(t) || t.isInteger(8)) { + return { + {MMADTypeKind::f8f6f4, {Float16Type::get(ctx), Float32Type::get(ctx)}}}; + } + return std::nullopt; +} + +static LogicalResult verifyMMADType(Operation *op, Type a, Type b, Type d) { + auto akind = getMMAv5DTypeKindAndAcc(a); + auto bkind = getMMAv5DTypeKindAndAcc(b); + if (!akind) + return op->emitOpError("unsupported LHS operand dtype: ") << a; + if (!bkind) + return op->emitOpError("unsupported RHS operand dtype: ") << b; + if (akind->first != bkind->first) { + return op->emitOpError( + "LHS and RHS operand dtypes kinds don't match: LHS kind is ") + << strMMADTypeKind(akind->first) << " but RHS kind is " + << strMMADTypeKind(bkind->first); + } + if (!llvm::is_contained(akind->second, d) || + !llvm::is_contained(bkind->second, d)) { + InFlightDiagnostic diag = + op->emitOpError("unsupported accumulator dtype for operand types ") + << a << " and " << b << ", accumulator dtype is " << d + << " but must be one of ["; + llvm::interleaveComma(akind->second, diag, [&](Type t) { diag << t; }); + diag << "]"; + return diag; + } + return success(); +} + +LogicalResult TCGen5MMAOp::verify() { + if (!getIsAsync() && !getBarriers().empty()) { + return emitOpError("The op is synchronous but a barrier is present."); + } + Type atype = getA().getType().getElementType(); + Type btype = getB().getType().getElementType(); + Type dtype = getD().getType().getElementType(); + if (failed(verifyMMADType(*this, atype, btype, dtype))) + return failure(); + + auto aEnc = getA().getType().getEncoding(); + if (!isa(aEnc)) + return emitOpError( + "LHS operand must have a NVMMAShared or TensorMemory encoding"); + auto bEnc = getB().getType().getEncoding(); + if (!isa(bEnc)) + return emitOpError("RHS operand must have a NVMMAShared encoding"); + auto retType = getD().getType(); + auto retEnc = dyn_cast(retType.getEncoding()); + if (!retEnc) + return emitOpError("Return operand must have a TensorMemory encoding"); + + // Check colStride of TMEM operands + if (auto tmem = dyn_cast(aEnc)) { + if (tmem.getColStride() != 1) + return emitOpError("The col stride of the LHS operand must be 1"); + } + if (retEnc.getColStride() != 32 / retType.getElementTypeBitWidth()) + return emitOpError("The col stride of the return operand must be 32 / ") + << retType.getElementTypeBitWidth() << " but got " + << retEnc.getColStride(); + // The maximum size of a MMA instruction is 128x256 + if (retEnc.getBlockN() > 256) + return emitOpError("The block size of the return operand must be less than " + "or equal to 256"); + + auto aSplit = getCTASplitNum(aEnc); + auto bSplit = getCTASplitNum(bEnc); + if (aSplit[1] != 1) { + return emitOpError("LHS CTASplit along K should be 1, but got ") + << aSplit[1]; + } + if (bSplit[0] != 1) { + return emitOpError("RHS CTASplit along K should be 1, but got ") + << bSplit[0]; + } + + if (getTwoCtas()) { + auto retSplit = getCTASplitNum(retEnc); + + auto nPerCTA = retType.getDimSize(1) / retSplit[1]; + + // [Note: numRepN > 1 and two_ctas] + // Consider, just as an example, num_ctas=16, and a huge tile of shape + // MNK = 512x64x2048 + // This is an example of layout with numRepN=2 and two_ctas=true: + // Layout RHS: + // #ttg.memdesc<64x2048xf16, + // #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, + // elementBitWidth = 16, + // CGALayout = [[0, 1], [0, 2], [0, 4], [0, 0]]}>> + // + // As a LinearLayout: + // offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1], [8, 2], + // [16, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [32, 0]] + // block = [[0, 256], [0, 512], [0, 1024], [0, 0]] + // + // The issue is that the data from the CTA1 should be next to that of the + // first part of the instruction. Now, the max instruction size is 128x256, + // so the layout we should use is + // offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1], [8, 2], + // [16, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 256], [32, 0]] + // block = [[0, 128], [0, 512], [0, 1024], [0, 0]] + // (note how we swapped the bases [0, 256] and [0, 128]) + // The issue with this layout is that it breaks the invariant that the + // CGALayout splits the CGA tile into contiguous CTA tiles, + // i.e. total_layout = cta_layout * cga_layout. + // This is used all over the place, to the point that for all legacy layouts + // we represent the CGALayout as the `cga_layout` we have to multiply on the + // right. + // We could allow with a bit of effort SharedLinearLayouts that did not + // divide on the right by a CGALayout, but for now we throw a lovely error. + if (nPerCTA > 256) + return emitOpError( + "We don't allow to emit more than one mma instruction along N. " + "Reduce the block or increase the number of warps or CTAs along N"); + + unsigned retM = retSplit[0]; + unsigned retN = retSplit[1]; + if (aSplit[0] != retM) { + return emitOpError("twoCTA mode expects the LHS split along M to match " + "the result split along M. Expected ") + << retM << " but got " << aSplit[0]; + } + if (bSplit[1] != 2 * retN) { + return emitOpError( + "twoCTA mode expects the RHS split along N to be twice the " + "result split along N. Expected ") + << 2 * retN << " but got " << bSplit[1]; + } + + if (!retEnc.getTwoCTAs()) + return emitOpError( + "The returned value's encoding must have twoCTA=true to be used " + "in a twoCTA matmul"); + if (auto tmemEnc = dyn_cast(aEnc)) { + if (!tmemEnc.getTwoCTAs()) + return emitOpError( + "The LHS operand's encoding must have twoCTA=true to be used " + "in a twoCTA matmul"); + } + } + + return success(); +} + +void TCGen5MMAOp::getEffects( + SmallVectorImpl> + &effects) { + // The op reads the accumulator if `useD` is not known to be false. + APInt useD; + if (!matchPattern(getUseD(), m_ConstantInt(&useD)) || !useD.isZero()) { + effects.emplace_back(MemoryEffects::Read::get(), &getDMutable(), + TensorMemory::get()); + } + effects.emplace_back(MemoryEffects::Write::get(), &getDMutable(), + TensorMemory::get()); + + if (isa(getA().getType().getMemorySpace())) { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + SharedMemory::get()); + + } else { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + TensorMemory::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), &getBMutable(), + SharedMemory::get()); +} + +bool TCGen5MMAOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +Value TCGen5MMAOp::useAccumulator() { return getUseD(); } + +void TCGen5MMAOp::setUseAccumulator(Value flag) { + getUseDMutable().assign(flag); +} + +ValueRange TCGen5MMAOp::getCompletionBarriers() { return getBarriers(); } +ValueRange TCGen5MMAOp::getCompletionBarrierPreds() { + return getBarrierPreds(); +} + +void TCGen5MMAOp::addCompletionBarrier(Value barrier, Value pred) { + getBarrierPredsMutable().append(pred); + getBarriersMutable().append(barrier); +} + +TypedValue TCGen5MMAOp::getAccumulator() { return getD(); } + +void TCGen5MMAOp::setAccumulator(Value accum) { getDMutable().assign(accum); } + +Value TCGen5MMAOp::getPredicate() { return getPred(); } + +void TCGen5MMAOp::setPredicate(Value pred) { getPredMutable().assign(pred); } + +void TCGen5MMAOp::build(OpBuilder &builder, OperationState &state, Type token, + Value a, Value b, Value d, Value accDep, Value useD, + Value pred, bool twoCtas, bool multicast, + ValueRange barriers, ValueRange barrierPreds, + bool isAsync) { + if (!barriers.empty()) { + isAsync = true; + } + build(builder, state, token, a, b, d, accDep, useD, pred, barriers, + barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr(), + twoCtas ? builder.getUnitAttr() : UnitAttr(), + multicast ? builder.getUnitAttr() : UnitAttr()); +} + +bool TCGen5MMAOp::isAsync() { return getIsAsync(); } + +// -- TCGen5CommitOp -- +LogicalResult TCGen5CommitOp::verify() { + auto numDescs = getDescs().size(); + if (numDescs > 2) + return emitOpError("expected 0, 1, or 2 descriptors, got ") << numDescs; + return success(); +} + +// -- TCGen5MMAScaledOp -- +LogicalResult TCGen5MMAScaledOp::verify() { + Type atype = getA().getType().getElementType(); + Type btype = getB().getType().getElementType(); + Type dtype = getD().getType().getElementType(); + if (failed(verifyMMADType(*this, atype, btype, dtype))) + return failure(); + auto enc = dyn_cast(getD().getType().getEncoding()); + if (!enc) { + return emitOpError( + "expected accumulator layout to be a TensorMemoryLayout"); + } + if (enc.getBlockM() != 128) + return emitOpError("only supports instruction shape blockM=128"); + return success(); +} + +void TCGen5MMAScaledOp::getEffects( + SmallVectorImpl> + &effects) { + // The op reads the accumulator if `useD` is not known to be false. + APInt useD; + if (!matchPattern(getUseD(), m_ConstantInt(&useD)) || !useD.isZero()) { + effects.emplace_back(MemoryEffects::Read::get(), &getDMutable(), + TensorMemory::get()); + } + effects.emplace_back(MemoryEffects::Write::get(), &getDMutable(), + TensorMemory::get()); + + if (isa(getA().getType().getMemorySpace())) { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + SharedMemory::get()); + + } else { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + TensorMemory::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), &getBMutable(), + SharedMemory::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getAScaleMutable(), + TensorMemory::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getBScaleMutable(), + TensorMemory::get()); +} + +bool TCGen5MMAScaledOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + bool transA = false; + if (auto aSharedLayout = dyn_cast( + getA().getType().getEncoding())) { + transA = aSharedLayout.getTransposed(); + } + bool transB = false; + if (auto bSharedLayout = dyn_cast( + getB().getType().getEncoding())) { + transB = !bSharedLayout.getTransposed(); + } + auto aKdim = aShape[aShape.size() - 1]; + auto bKdim = bShape[aShape.size() - 2]; + if (this->getAType() == ScaleDotElemType::E2M1 && !transA) + aKdim *= 2; + if (this->getBType() == ScaleDotElemType::E2M1 && !transB) + bKdim *= 2; + + return aKdim == bKdim; +} + +bool TCGen5MMAScaledOp::verifyOutputDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + auto cShape = this->getD().getType().getShape(); + auto oMdim = cShape[cShape.size() - 2]; + auto oNdim = cShape[cShape.size() - 1]; + + int aMdim = aShape[aShape.size() - 2]; + int bNdim = bShape[bShape.size() - 1]; + bool transA = false; + if (auto aSharedLayout = dyn_cast( + getA().getType().getEncoding())) { + transA = aSharedLayout.getTransposed(); + } + bool transB = false; + if (auto bSharedLayout = dyn_cast( + getB().getType().getEncoding())) { + transB = !bSharedLayout.getTransposed(); + } + if (this->getAType() == ScaleDotElemType::E2M1 && transA) + aMdim *= 2; + if (this->getBType() == ScaleDotElemType::E2M1 && transB) + bNdim *= 2; + + if (aMdim != oMdim || bNdim != oNdim) + return false; + return true; +} + +Value TCGen5MMAScaledOp::useAccumulator() { return getUseD(); } + +void TCGen5MMAScaledOp::setUseAccumulator(Value flag) { + getUseDMutable().assign(flag); +} + +ValueRange TCGen5MMAScaledOp::getCompletionBarriers() { return getBarriers(); } +ValueRange TCGen5MMAScaledOp::getCompletionBarrierPreds() { + return getBarrierPreds(); +} + +void TCGen5MMAScaledOp::addCompletionBarrier(Value barrier, Value pred) { + getBarrierPredsMutable().append(pred); + getBarriersMutable().append(barrier); +} + +TypedValue TCGen5MMAScaledOp::getAccumulator() { return getD(); } + +void TCGen5MMAScaledOp::setAccumulator(Value accum) { + getDMutable().assign(accum); +} + +Value TCGen5MMAScaledOp::getPredicate() { return getPred(); } + +void TCGen5MMAScaledOp::setPredicate(Value pred) { + getPredMutable().assign(pred); +} + +int64_t TCGen5MMAScaledOp::getBlockM() { + ArrayRef shape = getA().getType().getShape(); + int64_t blockM = shape[shape.size() - 2]; + bool transA = false; + if (auto aSharedLayout = dyn_cast( + getA().getType().getEncoding())) { + transA = aSharedLayout.getTransposed(); + } + if (this->getAType() == ScaleDotElemType::E2M1 && transA) + blockM *= 2; + return blockM; +} + +int64_t TCGen5MMAScaledOp::getBlockN() { + ArrayRef shape = getB().getType().getShape(); + int64_t blockN = shape[shape.size() - 1]; + bool transB = false; + if (auto bSharedLayout = dyn_cast( + getB().getType().getEncoding())) { + transB = !bSharedLayout.getTransposed(); + } + if (this->getBType() == ScaleDotElemType::E2M1 && transB) + blockN *= 2; + return blockN; +} + +int64_t TCGen5MMAScaledOp::getBlockK() { + ArrayRef shape = getA().getType().getShape(); + int64_t blockK = shape[shape.size() - 1]; + bool transA = false; + if (auto aSharedLayout = dyn_cast( + getA().getType().getEncoding())) { + transA = aSharedLayout.getTransposed(); + } + if (this->getAType() == ScaleDotElemType::E2M1 && !transA) + blockK *= 2; + return blockK; +} + +void TCGen5MMAScaledOp::build(OpBuilder &builder, OperationState &state, + Type token, Value a, Value b, Value d, + Value accDep, Value aScale, Value bScale, + ScaleDotElemType aType, ScaleDotElemType bType, + Value useD, Value pred, ValueRange barriers, + ValueRange barrierPreds, bool isAsync) { + MLIRContext *ctx = builder.getContext(); + if (!barriers.empty()) { + isAsync = true; + } + build(builder, state, token, a, b, d, accDep, aScale, bScale, + ScaleDotElemTypeAttr::get(ctx, aType), + ScaleDotElemTypeAttr::get(ctx, bType), useD, pred, barriers, + barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr()); +} + +bool TCGen5MMAScaledOp::isAsync() { return getIsAsync(); } + +// -- TMEMStoreOp -- +static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type, + MemDescType memdesc, StringRef regName) { + if (type.getRank() != 2) + return op->emitOpError(regName) << " must be a 2D tensor"; + if (!type.getEncoding()) + return success(); + + auto maxnreg = getContextualMaxNReg(op); + if (isDistributedLayoutTMemCompatible(op, type, memdesc)) + return success(); + + // If it failed, give the user a hint + SmallVector layouts = + getTmemCompatibleLayouts(op, type, memdesc); + + InFlightDiagnostic diag = op->emitOpError(regName); + diag.attachNote() << "Got: " << type.getEncoding(); + for (Attribute layout : layouts) + diag.attachNote() << "potential TMEM layout: " << layout; + return diag; +} + +LogicalResult TMEMStoreOp::verify() { + if (!isa(getDst().getType().getEncoding())) + return emitOpError("should use tensor memory encoding."); + if (!getDst().getType().getMutableMemory()) { + return emitOpError("Cannot store into an immutable alloc"); + } + if (failed(verifyTMEMOperand(*this, getSrc().getType(), getDst().getType(), + "source"))) + return failure(); + return triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), + getDst().getType()); +} + +// -- TMEMLoadOp -- +LogicalResult TMEMLoadOp::verify() { + if (!isa( + getSrc().getType().getMemorySpace())) + return emitOpError("source must be a tensor memory buffer."); + if (!isa( + getSrc().getType().getEncoding())) + return emitOpError("should use tensor memory encoding."); + if (failed(verifyTMEMOperand(*this, getType(), getSrc().getType(), "result"))) + return failure(); + + // Validate reduction-related attributes + auto redOp = getRedOp(); + bool hasRed = getRed() != nullptr; + bool useAbs = getAbs().value_or(false); + bool useNaN = getNaN().value_or(false); + + // redOp and red result must be consistent + if (redOp && !hasRed) + return emitOpError("redOp is set but 'red' result is not present"); + if (hasRed && !redOp) + return emitOpError("'red' result is present but redOp is not set"); + + // abs and NaN require redOp + if (useAbs && !redOp) + return emitOpError("'abs' requires 'redOp' to be set"); + if (useNaN && !redOp) + return emitOpError("'NaN' requires 'redOp' to be set"); + + // abs and NaN require floating-point element type + Type elemTy = getSrc().getType().getElementType(); + if (useAbs && !elemTy.isF32()) + return emitOpError("'abs' requires floating-point element type (f32)"); + if (useNaN && !elemTy.isF32()) + return emitOpError("'NaN' requires floating-point element type (f32)"); + + // Validate reduction conditions + if (redOp) { + auto regTy = getType(); + auto memTy = getSrc().getType(); + auto maxnreg = getContextualMaxNReg(*this); + auto encodingInfoOr = computeTMemLdStEncodingInfo(regTy, memTy, maxnreg); + if (failed(encodingInfoOr)) + return emitOpError("failed to compute TMEM encoding info"); + + if (encodingInfoOr->unpacked) + return emitOpError( + "tmem_load reduction requires packed format (unpacked=false)"); + + // Verify that N dimension is in registers entirely, and is not sharded + // across threads. This could be relaxed in the future to only reduce the + // kReg bases along N then cross-warp/block reduction becomes needed. + auto kReg = StringAttr::get(regTy.getContext(), "register"); + int dimM = 0, dimN = 1; + auto regDims = toLinearEncoding(regTy).basesPerDim(kReg); + if (regDims[dimN] != toLinearLayout(regTy).getOutDimSizes().begin()[dimN] || + regDims[dimM] != 1) { + return emitOpError("tmem_load reduction with N dimension sharded across " + "threads is not supported."); + } + } + + return triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), getType()); +} + +// -- TMEMAllocOp -- +LogicalResult TMEMAllocOp::verify() { + if (!isa( + getType().getEncoding())) + return emitOpError("should use tensor memory encoding"); + if (getSrc() && + failed(verifyTMEMOperand(*this, getSrc().getType(), getType(), "source"))) + return failure(); + return triton::gpu::verifyAllocOp(*this, getSrc(), getType()); +} + +void TMEMAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("tensor_memory_col_offset")) + return; + OpResult alloc = getOperation()->getOpResult(0); + effects.emplace_back(MemoryEffects::Allocate::get(), alloc, + TensorMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), alloc, + TensorMemory::get()); +} + +// -- TMEMCopyOp -- +LogicalResult TMEMCopyOp::verify() { + if (!isa( + getSrc().getType().getMemorySpace())) + return emitOpError("The source must be a shared memory buffer"); + + auto srcTy = cast(getSrc().getType()); + auto dstTy = cast(getDst().getType()); + if (srcTy.getShape() != dstTy.getShape()) + return emitOpError("source shape ") + << srcTy.getShape() << " must match destination shape " + << dstTy.getShape(); + + if (getBarrier() && !isa( + getBarrier().getType().getMemorySpace())) { + return emitOpError("The optional barrier should be a shared memory buffer"); + } + if (!getDst().getType().getMutableMemory()) { + return emitOpError("Cannot copy into an immutable alloc"); + } + auto sharedEnc = + dyn_cast(srcTy.getEncoding()); + if (sharedEnc.getAlignment() < 16) { + return emitOpError("Source must have at least 16-byte alignment to be " + "representable in a matrix descriptor."); + } + + auto mod = getOperation()->getParentOfType(); + unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + if (numCTAs != 1) + return emitOpError("NYI: Only one CTA is supported for now."); + + // Fp4 we could lift if we needed + auto nvmmaEnc = + dyn_cast(srcTy.getEncoding()); + if (nvmmaEnc && (nvmmaEnc.getTransposed() || nvmmaEnc.getFp4Padded())) { + return emitOpError("The source should not be transposed or padded"); + } + if (isa(getDst().getType().getEncoding())) { + if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() != 0) { + return emitOpError("The source should not be swizzled for now"); + } + } else { + if (getSrc().getType().getShape() != getDst().getType().getShape()) { + return emitOpError( + "The source and destination must have the same shape."); + } + auto tmemEnc = dyn_cast( + getDst().getType().getEncoding()); + if (!tmemEnc) { + return emitOpError("Incorrect tmem layout."); + } + if (tmemEnc.getBlockM() != 128) { + return emitOpError("Tmem layout must have blockM=128."); + } + if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() == 0) { + return emitOpError("Source layout should be swizzled."); + } + // When we lift this, we should make sure we handle unpacked cleanly + if (srcTy.getElementType().getIntOrFloatBitWidth() != 32) { + return emitOpError("Source element type should be 32-bit."); + } + } + // Given that we want to support flexible input SMEM shapes, kinds of shape + // checking we can do here are limited. For simplicity, shape checking is + // omitted. + return success(); +} + +// -- TMEMSubSliceOp -- +LogicalResult TMEMSubSliceOp::verify() { + auto srcTy = cast(getSrc().getType()); + auto encoding = dyn_cast( + srcTy.getEncoding()); + if (!encoding) + return emitOpError("The source must be a tensor memory buffer."); + if (!llvm::is_contained({64, 128}, encoding.getBlockM())) { + return emitOpError("The source tensor memory descriptor must have a 128xN " + "or 64xN layout, got block_m=") + << encoding.getBlockM(); + } + auto dstTy = cast(getResult().getType()); + auto dstEncoding = dyn_cast( + dstTy.getEncoding()); + if (!dstEncoding) + return emitOpError("The destination must be a tensor memory buffer."); + if (dstEncoding.getBlockM() != encoding.getBlockM() || + dstEncoding.getCTASplitM() != encoding.getCTASplitM() || + dstEncoding.getCTASplitN() != encoding.getCTASplitN() || + dstEncoding.getColStride() != encoding.getColStride()) + return emitOpError("The destination must have the same block size and " + "CTASplit size as the source."); + return mlir::success(); +} + +void TMEMSubSliceOp::build(OpBuilder &builder, OperationState &state, + Value alloc, int offset, int size) { + auto allocTy = cast(alloc.getType()); + SmallVector shape(allocTy.getShape()); + shape.back() = size; + auto encoding = + cast(allocTy.getEncoding()); + unsigned newBlockN = std::min(encoding.getBlockN(), size); + auto newEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + builder.getContext(), encoding.getBlockM(), newBlockN, + encoding.getColStride(), encoding.getCTASplitM(), encoding.getCTASplitN(), + encoding.getTwoCTAs()); + auto subsliceType = gpu::MemDescType::get( + shape, allocTy.getElementType(), newEncoding, allocTy.getMemorySpace(), + allocTy.getMutableMemory(), allocTy.getAllocShape()); + build(builder, state, subsliceType, alloc, offset); +} + +// -- TensormapCreateOp -- +LogicalResult TensormapCreateOp::verify() { + auto rank = getBoxDim().size(); + if (getGlobalDim().size() != rank) { + return emitError("Rank mismatch for global dim. Got ") + << getGlobalDim().size() << " but expected " << rank; + } + if (getGlobalStride().size() + 1 != rank) { + return emitError("Rank mismatch for global stride. Got ") + << getGlobalStride().size() << " but expected " << rank - 1; + } + if (getElementStride().size() != rank) { + return emitError("Rank mismatch for element stride. Got ") + << getElementStride().size() << " but expected " << rank; + } + return success(); +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp new file mode 100644 index 0000000000..09121115a5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.cpp @@ -0,0 +1,309 @@ +#include "triton/Dialect/TritonNvidiaGPU/IR/TensorMemoryUtils.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" + +#include +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace mlir::triton::nvidia_gpu { + +namespace { + +constexpr int maxRegisters = 256; +constexpr int largestTmemLoadStore = 128; + +// Similar to largestVectorisation in TritonGPUToLLVM/Utility.cpp +std::optional> +getVec(const LinearLayout &cvt, const LinearLayout &tile, int maxnreg) { + auto *ctx = cvt.getInDimNames().begin()->getContext(); + auto kReg = StringAttr::get(ctx, "register"); + auto kCol = StringAttr::get(ctx, "col"); + LinearLayout reps, vec; + ColumnAction perm; + // Heuristic: + // Do not use more than half the registers as otherwise it's prone to spilling + assert(maxnreg / 2 <= largestTmemLoadStore); + auto maxReg = maxnreg / 2; + // Heuristic: + // If maxnreg is 256 and we need more than one message, we don't use max + // vectorisation as ptxas' scheduler breaks... + if (maxnreg == 256 && cvt.getInDimSize(kReg) > maxReg) { + maxReg /= 2; + } + auto maxVec = maxReg / tile.getInDimSize(kReg); + int i = 1; + for (; i <= maxVec; i *= 2) { + vec = LinearLayout::identity1D(i, kReg, kCol); + auto vecTile = tile * vec; + auto maybePerm = regPermForDivide(cvt, vecTile, /*left=*/true); + if (!maybePerm) { + break; + } + // nb. We could remove this part once we are confident the algo works + perm = *maybePerm; + auto newCvt = maybePerm->apply(cvt); + auto maybeReps = getReps(newCvt, vecTile); + if (!maybeReps.has_value()) { + break; + } + reps = *maybeReps; + } + if (i == 1) { + // Couldn't lower the tile + return std::nullopt; + } + // i is the smallest power of 2 that *cannot* be used to lower the tile + // so we return i / 2. + assert(i > 1); + return std::make_tuple(std::move(reps), std::move(perm), + (i / 2) * tile.getInDimSize(kReg)); +} +} // namespace + +// Get the maximum number of registers per thread based on the context. This is +// by default 256, but it can be overridden by `ttg.maxnreg` set on the module +// or a contextual register limit set by the compiler on partitions. +int getContextualMaxNReg(Operation *op) { + // Check the immediate parent op to see if it places a register constraint. + auto getFromParent = [](Operation *op) -> std::optional { + Operation *parent = op->getParentOp(); + if (auto mod = dyn_cast(parent)) { + if (auto attr = mod->getAttrOfType(AttrMaxRegistersName)) + return attr.getInt(); + return {}; + } + + if (auto partitions = dyn_cast(parent)) { + // Check if the partition has reduced registers. + unsigned idx = op->getParentRegion()->getRegionNumber(); + if (auto actRegisters = partitions.getParentOp().getActualRegisters()) + return (*actRegisters)[1 + idx]; + return {}; + } + + if (auto wsOp = dyn_cast(op->getParentOp())) { + // Check the register usage of the default warpgroup. + if (auto actRegisters = wsOp.getActualRegisters()) + return actRegisters->front(); + return {}; + } + + return {}; + }; + + // PTXAS validates the register usage of `tcgen05.ld` and `tcgen05.st` + // instructions based on the static number of registers set on the module, not + // the dynamic allocation. This just means the register limit used for the + // purpose of subtiling TMEM messages cannot be higher than the module's. + auto mod = op->getParentOfType(); + int maxnreg = maxRegisters; + + for (; op != mod; op = op->getParentOp()) { + if (std::optional limit = getFromParent(op)) { + maxnreg = std::min(maxnreg, *limit); + break; + } + } + + if (auto maxnregAttr = mod->getAttrOfType(AttrMaxRegistersName)) + maxnreg = std::min(maxnreg, maxnregAttr.getInt()); + + return maxnreg; +} + +FailureOr +lowerTMemLdSt(const LinearLayout &cvt, int maxnreg, int bitwidth, bool isScales, + std::function emitError, + bool unpacked = false) { + // We will fill in the returned value recursively (if it exists) + + // Remove broadcasting in the registers + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtCvt = removeBroadcastSrc.apply(cvt); + auto info = lowerTMemLdSt(prmtCvt, maxnreg, bitwidth, isScales, emitError, + unpacked); + if (failed(info)) + return failure(); + info->broadcast = std::move(removeBroadcastSrc); + return info; + } + auto *ctx = cvt.getInDimNames().begin()->getContext(); + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + auto kReg = S("register"); + auto kLane = S("lane"); + auto kRow = S("row"); + auto kCol = S("col"); + if (bitwidth < 32) { + LinearLayout quot; + int bestContig = 1; + for (int contig = 1; bitwidth * contig <= 32; contig *= 2) { + auto maybeQuot = + divideLeft(cvt, LinearLayout::identity1D(contig, kReg, kCol)); + if (!maybeQuot) + break; + quot = *maybeQuot; + bestContig = contig; + } + bool padding = false; + int newBitwidth = bitwidth; + if (bestContig > 1) { + // There are contiguous elements along kCol, so we can pack them into a + // larger dtype + unpacked = false; + newBitwidth = bitwidth * bestContig; + } else if (auto maybeQuot = divideLeft( + cvt, LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth) * + LinearLayout::identity1D(2, kReg, kCol)); + bitwidth == 16 && maybeQuot) { + // Unpacked just supported for bitwidth 16 + unpacked = true; + quot = *maybeQuot; + newBitwidth = 32; + } else if (auto maybeQuot = divideLeft( + cvt, LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth))) { + // We software-pad the elements when we either do not have enough elements + // to fill a full 32b register, e.g., colN = 1 and colStride != 1 or when + // bitwidth == 8 (this happens with scales with K=1). + // These two cases are mostly supported for testing purposes. + unpacked = bitwidth == 16; + quot = *maybeQuot; + padding = true; + newBitwidth = 32; + } else { + if (emitError) { + emitError() << "Failed to lower TMEM load/store: TMEM layout is not " + "packed or unpacked"; + } + return failure(); + } + // When unpacked each register moves 32/bitwidth (= 2) columns + if (unpacked) { + quot = LinearLayout::zeros1D(1, kReg, kCol, 32 / bitwidth) * quot; + } + auto info = lowerTMemLdSt(quot, maxnreg, newBitwidth, isScales, emitError, + unpacked); + if (failed(info)) + return failure(); + if (bestContig > 1) { + info->vec = bestContig; + } + if (unpacked) { + info->unpacked = true; + } + if (padding) { + info->padding = true; + } + return info; + } + + assert(bitwidth == 32); + + // The algorithm goes as: + // - Try to match the tile with one of the standard messages + // - If it doesn't match, we use the 16x32bx2 message + // Note that it can match one and only one of the layouts, even after register + // reordering, as the layouts yield predetermined positions for the lanes + // We store the instruction, the resulting reps layout, the permutation and + // the number of registers per message + std::optional msgInfo; + for (auto atom : {TMemAccessAtom::I32x32b, TMemAccessAtom::I16x256b, + TMemAccessAtom::I16x64b, TMemAccessAtom::I16x128b}) { + auto tile = getTileLayout(ctx, atom, unpacked, /*withWarp=*/true); + auto maybeReps = getVec(cvt, tile, maxnreg); + if (maybeReps) { + // Cannot match more than one + msgInfo = {atom, std::get<0>(*maybeReps), std::get<1>(*maybeReps), + std::get<2>(*maybeReps)}; + break; + } + } + std::optional secondHalfOffset = std::nullopt; + if (!msgInfo) { + // Quotient by the smaller tile and then, if possible, we set the + // secondHalfOffset to the last kLane basis + auto tile = getTileLayout(ctx, TMemAccessAtom::I16x32bx2, unpacked, + /*withWarp=*/true); + auto maybeReps = getVec(cvt, tile, maxnreg); + if (maybeReps) { + auto [reps, perm, numRegsPerMessage] = std::move(*maybeReps); + // Find the last kLane basis and use it as secondHalfOffset + auto row = reps.getBasis(kLane, 4, kRow); + auto col = reps.getBasis(kLane, 4, kCol); + secondHalfOffset = (row << 16) | col; + if (*secondHalfOffset == 0) { + // Workaround for ptxas bug, we cannot use secondHalfOffset = 0 to write + // only 16 elements. We use secondHalfOffset = 1 instead and we pad the + // allocation. + if (!isScales) { + if (emitError) { + emitError() + << "Only supported for scales as we pad the allocation."; + } + return failure(); + } + secondHalfOffset = 1; + } + // We "quotient it out", meaning we remove the last basis from reps + auto basis = reps.getBases(); + basis[kLane][4] = {0, 0}; + reps = LinearLayout(std::move(basis), reps.getOutDims(), + /*isSurjective=*/false); + msgInfo = {TMemAccessAtom::I16x32bx2, reps, perm, numRegsPerMessage}; + } + } + + if (!msgInfo) { + if (emitError) { + emitError() + << "Failed to lower TMEM load/store: unsupported dst layout\n" + + cvt.toString(); + } + return failure(); + } + auto info = std::move(*msgInfo); + info.secondHalfOffset = secondHalfOffset; + return info; +} + +FailureOr +computeTMemLdStEncodingInfo(RankedTensorType regTy, MemDescType memTy, + int maxnreg, + std::function emitError) { + auto memLayout = toLinearLayout(memTy); + auto regLayout = toLinearLayout(regTy); + auto cvt = regLayout.invertAndCompose(memLayout); + auto *ctx = regTy.getContext(); + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + auto kWarp = S("warp"); + auto kRow = S("row"); + // Warps 0-3 must map to row=32 and row=64 whether with broadcasting or not + if (!(regLayout.getBasis(kWarp, 0) == memLayout.getBasis(kRow, 5) && + regLayout.getBasis(kWarp, 1) == memLayout.getBasis(kRow, 6))) { + if (emitError) { + emitError() << "warps=1,2 must map to rows=32,64. Got:\n" + << regLayout.toString() << "\n" + << memLayout.toString(); + } + return failure(); + } + // Map warp bases to row=32 and row=64 in the cvt. This would be done + // automatically in `invertAndCompose` if we had a different dimension name + // for these rows. We can do this in the future if needed. + auto bases = cvt.getBases(); + bases[kWarp][0] = {32, 0}; + bases[kWarp][1] = {64, 0}; + cvt = LinearLayout(std::move(bases), cvt.getOutDims(), + /*isSurjective=*/cvt.isSurjective()); + + bool isScales = isa(memTy.getEncoding()); + int bitwidth = memTy.getElementTypeBitWidth(); + return lowerTMemLdSt(cvt, maxnreg, bitwidth, isScales, emitError); +} + +} // namespace mlir::triton::nvidia_gpu diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..7715cc9861 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,25 @@ +add_triton_library(TritonNvidiaGPUTransforms + CheckMatmulTwoCTAs.cpp + FenceInsertion.cpp + InterleaveTMem.cpp + MMALowering.cpp + OptimizeDescriptorEncoding.cpp + OptimizeTMemLayouts.cpp + PlanCTA.cpp + PromoteLHSToTMem.cpp + ProxyFenceInsertion.cpp + RemoveTMEMTokens.cpp + TensorMemoryAllocation.cpp + TMALowering.cpp + TMAUtilities.cpp + + DEPENDS + TritonNvidiaGPUTransformsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/CheckMatmulTwoCTAs.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/CheckMatmulTwoCTAs.cpp new file mode 100644 index 0000000000..c5b1ddf37a --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/CheckMatmulTwoCTAs.cpp @@ -0,0 +1,63 @@ +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Visitors.h" + +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir::triton::nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUCHECKMATMULTWOCTAPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +class TritonNvidiaGPUCheckMatmulTwoCTAPass + : public impl::TritonNvidiaGPUCheckMatmulTwoCTAPassBase< + TritonNvidiaGPUCheckMatmulTwoCTAPass> { +public: + using impl::TritonNvidiaGPUCheckMatmulTwoCTAPassBase< + TritonNvidiaGPUCheckMatmulTwoCTAPass>:: + TritonNvidiaGPUCheckMatmulTwoCTAPassBase; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + Operation *firstMatmul = nullptr; + bool firstTwoCTA = false; + + WalkResult result = mod.walk([&](ttng::TCGen5MMAOp op) { + bool currentTwoCTA = op.getTwoCtas(); + if (!firstMatmul) { + firstMatmul = op; + firstTwoCTA = currentTwoCTA; + return WalkResult::advance(); + } + if (currentTwoCTA != firstTwoCTA) { + auto diag = op.emitError() + << "inconsistent two_ctas setting across matmuls; " + "expected all matmuls to " + << (firstTwoCTA ? "enable" : "disable") << " two_ctas."; + diag.attachNote(firstMatmul->getLoc()) + << "first matmul here has two_ctas=" + << (firstTwoCTA ? "true" : "false") << "."; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + return; + } + + bool twoCTAValue = firstMatmul ? firstTwoCTA : false; + mod->setAttr(AttrTwoCTAsName, BoolAttr::get(mod.getContext(), twoCTAValue)); + } +}; + +} // namespace + +} // namespace mlir::triton::nvidia_gpu diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp new file mode 100644 index 0000000000..f831ed3a9b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -0,0 +1,151 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// +// This pass works after all other passes, inserting fences to ensure that +// memory operations are properly ordered across generic and async proxy. +// +//===----------------------------------------------------------------------===// + +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONGPUFENCEINSERTION +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +struct FenceInsertionPass + : public impl::TritonGPUFenceInsertionBase { + +public: + using impl::TritonGPUFenceInsertionBase< + FenceInsertionPass>::TritonGPUFenceInsertionBase; + // TODO: support more general patterns to insert fences. eg. any op(generic) + // to shared in use-def chain which refers by async proxy. We have generic( + // convertlayout with sts/stmatix) + fence + async(wgmma) up to now + void runOnOperation() override { + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; + ModuleOp mod = getOperation(); + mod.walk([&](DotOpInterface dotOp) { + Value a = dotOp.getA(); + Value b = dotOp.getB(); + SmallVector copyRegToSharedOpsA = findCopyRegToSharedOps(a); + SmallVector copyRegToSharedOpsB = findCopyRegToSharedOps(b); + if (copyRegToSharedOpsA.empty() && copyRegToSharedOpsB.empty()) + return WalkResult::advance(); + + OpBuilder builder(dotOp); + auto fence = FenceAsyncSharedOp::create(builder, dotOp.getLoc(), + /*bCluster=*/false); + // If there is all the dependencies are outside of the loop try to hoist + // the fence. + while (auto loopOp = fence->getParentOfType()) { + if (!copyRegToSharedOpsA.empty() && + llvm::any_of(copyRegToSharedOpsA, + [&](Operation *op) { return loopOp->isAncestor(op); })) + break; + if (!copyRegToSharedOpsB.empty() && + llvm::any_of(copyRegToSharedOpsB, + [&](Operation *op) { return loopOp->isAncestor(op); })) + break; + loopOp.moveOutOfLoop(fence); + } + + // If the previous op is already a fence, this one isn't needed. + if (auto lastFence = + dyn_cast_or_null(fence->getPrevNode())) { + if (lastFence.getBCluster() == fence.getBCluster()) + fence.erase(); + } + + return WalkResult::advance(); + }); + } + +private: + // Return true if the operand depends on a copy from register to shared. + SmallVector findCopyRegToSharedOps(Value operand) { + DenseSet visited; + llvm::SetVector result; + findCopyRegToSharedOps(operand, visited, result); + return result.takeVector(); + } + + void findCopyRegToSharedOps(Value operand, DenseSet &visited, + llvm::SetVector &result) { + // If the value has already been visited we can safely return false as we + // would early return when true. + if (visited.count(operand)) + return; + visited.insert(operand); + if (!isa(operand.getType())) + return; + + auto op = operand.getDefiningOp(); + if (op) { + // reach an alloc copying from register, we need a fence. + if (auto localAlloc = dyn_cast(op)) { + if (localAlloc.getSrc()) { + result.insert(op); + } + // Check if there are local_store ops that write to that buffer. + for (auto user : localAlloc.getResult().getUsers()) { + while (user->hasOneUse() && + user->hasTrait()) { + user = *user->getUsers().begin(); + } + if (isa(user)) { + result.insert(user); + return; + } + } + } + // if it is not an alloc, iterate over the operands. + for (auto v : op->getOperands()) { + findCopyRegToSharedOps(v, visited, result); + } + return; + } + + // reach BlockArgument + BlockArgument arg = cast(operand); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + // look through ForOp iter argument + if (auto forOp = dyn_cast(argOwner)) { + assert(argNum != 0 && "induction var cannot be memdesc type"); + --argNum; + // prologue + findCopyRegToSharedOps(forOp.getInitArgs()[argNum], visited, result); + // yield + auto yieldOp = forOp.getBody()->getTerminator(); + Value v = yieldOp->getOperand(argNum); + findCopyRegToSharedOps(v, visited, result); + return; + } + + // look through `ttg.warp_specialize`. + if (auto wsOp = dyn_cast(argOwner)) { + findCopyRegToSharedOps(wsOp.getExplicitCaptures()[argNum], visited, + result); + return; + } + + // Conservatively return true for other ops + result.insert(argOwner); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp new file mode 100644 index 0000000000..29c1955eac --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp @@ -0,0 +1,283 @@ +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/ADT/AddressRanges.h" + +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUINTERLEAVETMEMPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +// If we don't know the effects of the op, we add all possible effects. +void addAllValuelessEffects( + SmallVectorImpl &effects) { + effects.emplace_back(MemoryEffects::Effect::get()); + effects.emplace_back(MemoryEffects::Effect::get()); + effects.emplace_back(MemoryEffects::Effect::get()); + effects.emplace_back(MemoryEffects::Effect::get()); +} + +bool collectEffects(Operation *op, + SmallVectorImpl &effects) { + // Collect effect instances the operation. Note that the implementation of + // getEffects erases all effect instances that have the type other than the + // template parameter so we collect them first in a local buffer and then + // copy. + if (auto iface = dyn_cast(op)) { + SmallVector localEffects; + iface.getEffects(localEffects); + llvm::append_range(effects, localEffects); + return true; + } + if (op->hasTrait()) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto &innerOp : block) + if (!collectEffects(&innerOp, effects)) + return false; + } + } + return true; + } + + // We need to be conservative here in case the op doesn't have the interface + // and assume it can have any possible effect. + addAllValuelessEffects(effects); + return false; +} + +struct AccessRange { + SmallVector> ranges; + unsigned rankOffset = 0; +}; + +std::pair findBufferAccess(Value a); + +std::pair +findBufferAccessMemdescSubview(Operation *subview) { + OpBuilder builder(subview); + Location loc = subview->getLoc(); + TypedValue src; + SmallVector shape; + SmallVector offsets; + if (auto indexOp = dyn_cast(subview)) { + src = indexOp.getSrc(); + shape = to_vector(indexOp.getType().getShape()); + offsets = {indexOp.getIndex()}; + for (auto i : llvm::seq(std::max(0, shape.size() - 1))) + offsets.push_back(arith::ConstantIntOp::create(builder, loc, 0, 32)); + } else { + auto subsliceOp = cast(subview); + src = subsliceOp.getSrc(); + shape = to_vector(subsliceOp.getType().getShape()); + for (auto offset : subsliceOp.getOffsets()) + offsets.push_back(arith::ConstantIntOp::create(builder, loc, offset, 32)); + } + auto [alloc, parentAccess] = findBufferAccess(src); + if (!alloc) + return {}; + // Handle subview of a subview. The first `rankOffset` access sizes are + // the same as in the parent access. + AccessRange childAccess; + for (auto i : llvm::seq(parentAccess.rankOffset)) + childAccess.ranges.push_back(parentAccess.ranges[i]); + + // The subview may have a smaller rank, in which case its access size is + // just 1 for the higher dims. + childAccess.rankOffset = src.getType().getRank() - shape.size(); + for (auto [i, offset] : llvm::enumerate(offsets)) { + auto parentRange = parentAccess.ranges[i + parentAccess.rankOffset]; + if (!parentRange) { + childAccess.ranges.push_back({}); + continue; + } + + // If the offset is not known, then the entire dim may be accessed. + APInt value; + if (!matchPattern(offset, m_ConstantInt(&value))) { + childAccess.ranges.push_back({}); + continue; + } + + uint64_t accessStart = parentRange->start() + value.getSExtValue(); + uint64_t accessSize = 1; + if (i >= childAccess.rankOffset) + accessSize = shape[i - childAccess.rankOffset]; + childAccess.ranges.push_back({{accessStart, accessStart + accessSize}}); + } + return {alloc, std::move(childAccess)}; +} + +// Simple local alias analysis that looks for a single underlying allocation and +// an access subrange. +std::pair findBufferAccess(Value a) { + // Handle block arguments. + if (auto arg = dyn_cast(a)) { + Operation *parentOp = arg.getOwner()->getParentOp(); + + // Look through `ttg.warp_specialize` explicit captures. + if (auto wsOp = dyn_cast(parentOp)) { + return findBufferAccess(wsOp.getExplicitCaptures()[arg.getArgNumber()]); + } + + // Unknown block argument. + return {}; + } + + Operation *defOp = a.getDefiningOp(); + // Accessing the alloc accesses the whole buffer. + if (auto alloc = dyn_cast(defOp)) { + AccessRange access; + for (uint64_t dim : alloc.getType().getShape()) + access.ranges.push_back({{0, dim}}); + return {a, std::move(access)}; + } + + // Trans and Reshape views don't change the access size. + if (isa(defOp)) { + return findBufferAccess(defOp->getOperand(0)); + } + + // Subviews can reduce the access sizes. + if (isa(defOp)) { + return findBufferAccessMemdescSubview(defOp); + } + + // Subslice is a subview only on the N dimension. + if (auto subslice = dyn_cast(defOp)) { + auto [alloc, parentAccess] = findBufferAccess(subslice.getSrc()); + if (!alloc) + return {}; + if (!parentAccess.ranges[1]) + return {alloc, parentAccess}; + uint64_t mStart = parentAccess.ranges[1]->start() + subslice.getN(); + uint64_t mSize = subslice.getType().getShape()[1]; + AccessRange childAccess = parentAccess; + childAccess.ranges[1] = {{mStart, mStart + mSize}}; + return {alloc, std::move(childAccess)}; + } + + // Unknown defining op. + return {}; +} + +bool tmemMayAlias(Value a, Value b) { + auto [aAlloc, aRanges] = findBufferAccess(a); + auto [bAlloc, bRanges] = findBufferAccess(b); + // If the underlying buffer was not identified, assume mayalias. + if (!aAlloc || !bAlloc) + return true; + // If the buffers are different, they don't alias. + if (aAlloc != bAlloc) + return false; + // If the access ranges along any dimension are known to not overlap, then the + // accesses don't alias. + for (auto [aRange, bRange] : llvm::zip(aRanges.ranges, bRanges.ranges)) { + // If either access range at this dim is unknown, we can't determine if they + // don't overlap. + if (!aRange || !bRange) + continue; + // The access ranges are known and don't overlap. + if (!aRange->intersects(*bRange)) + return false; + } + return true; +} + +// Sink tmem_loads as close to their use as possible to reduce register +// pressure. +bool sinkOps(Value buffer, ArrayRef useChain) { + Operation *insertBefore = nullptr; + Operation *next = useChain.back()->getNextNode(); + while (next && !next->hasTrait()) { + insertBefore = next; + bool dep = false; + for (auto operand : getNestedOperands(next)) { + if (llvm::any_of(useChain, [&](Operation *op) { + return llvm::is_contained(op->getResults(), operand); + })) { + dep = true; + break; + } + } + // Don't sink past barrier signals, since they may guard the liverange + // of the buffer. + if (isa(next)) + break; + if (!isMemoryEffectFree(next)) { + SmallVector effects; + collectEffects(next, effects); + for (auto effect : effects) { + // Look for potentially aliasing write or free effects. + if (!isa(effect.getEffect())) + continue; + if (isa(effect.getResource())) { + dep = true; + break; + } + if (isa(effect.getResource()) && + (!effect.getValue() || tmemMayAlias(effect.getValue(), buffer))) { + dep = true; + break; + } + } + } + if (dep) + break; + next = next->getNextNode(); + } + if (insertBefore && insertBefore != useChain.back()->getNextNode()) { + for (Operation *op : useChain) + op->moveBefore(insertBefore); + return true; + } + return false; +} + +// Try to sink a load and a collection of its users. +bool trySinkOp(Operation *op, Value buffer) { + SmallVector useChain{op}; + while (useChain.back()->hasOneUse() && + isPure(*useChain.back()->user_begin()) && + useChain.back()->getNextNode() == *useChain.back()->user_begin()) { + useChain.push_back(*useChain.back()->user_begin()); + } + return sinkOps(buffer, useChain); +} + +} // anonymous namespace + +struct TritonNvidiaGPUInterleaveTMemPass + : public impl::TritonNvidiaGPUInterleaveTMemPassBase< + TritonNvidiaGPUInterleaveTMemPass> { + using impl::TritonNvidiaGPUInterleaveTMemPassBase< + TritonNvidiaGPUInterleaveTMemPass>::TritonNvidiaGPUInterleaveTMemPassBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + SmallVector> opsToSink; + m.walk([&](Operation *op) { + if (auto load = dyn_cast(op)) + opsToSink.emplace_back(load, load.getSrc()); + else if (auto alloc = dyn_cast(op)) + opsToSink.emplace_back(alloc, alloc.getResult()); + }); + for (auto [op, buffer] : opsToSink) { + while (trySinkOp(op, buffer)) { + // Keep trying to sink loads and their users. + } + } + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp new file mode 100644 index 0000000000..69ac7caebf --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp @@ -0,0 +1,222 @@ +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUMMALOWERINGPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +class SyncMMALowering : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(MMAv5OpInterface op, + PatternRewriter &rewriter) const override { + // If the op doesn't have synchronous semantic skip the pattern. + if (op.isAsync()) + return failure(); + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(ctx); + auto numCTAs = gpu::lookupNumCTAs(op); + auto barrierCGALayout = ttg::CGAEncodingAttr::get1DLayout(ctx, numCTAs); + auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get( + ctx, 1, 1, 1, {0}, barrierCGALayout); + ttg::MemDescType barrierMemDescType = + ttg::MemDescType::get({numCTAs}, rewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = + ttg::LocalAllocOp::create(rewriter, loc, barrierMemDescType, Value()); + InitBarrierOp::create(rewriter, loc, barrierAlloc, 1); + op.addCompletionBarrier(barrierAlloc, + arith::ConstantIntOp::create(rewriter, loc, 1, 1)); + op.setIsAsync(true); + + rewriter.setInsertionPointAfter(op); + Value phase = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + WaitBarrierOp::create(rewriter, loc, barrierAlloc, phase, + op.getPredicate()); + InvalBarrierOp::create(rewriter, loc, barrierAlloc); + return success(); + } +}; + +struct TCGen5MMAScaleSharedToTmemConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Create a tmem_copy of scales from shared memory to tmem. `rows` is the M or + // N of the MMA operation (for LHS or RHS respectively). + bool lowerScaleToTmem(OpOperand &operand, PatternRewriter &rewriter, + int rows) const { + Location loc = operand.getOwner()->getLoc(); + MLIRContext *context = operand.getOwner()->getContext(); + Attribute tensorMemorySpace = TensorMemorySpaceAttr::get(context); + auto oldType = cast(operand.get().getType()); + auto numElems = product(oldType.getShape()); + Type elType = oldType.getElementType(); + ttg::CGAEncodingAttr CGALayout = ttg::getCGALayout(oldType.getEncoding()); + auto CTASplitNum = CGALayout.getCTASplitNum(); + // Distribute the scales across the rows of the MMA operation. + SmallVector shape = {rows, numElems / rows}; + Attribute scaleEncoding = TensorMemoryScalesEncodingAttr::get( + context, CTASplitNum[0], CTASplitNum[1]); + Type scaleAType = + ttg::MemDescType::get(shape, elType, scaleEncoding, tensorMemorySpace, + /*mutableMemory=*/true); + auto tmemAlloc = TMEMAllocOp::create(rewriter, loc, scaleAType, Value()); + TMEMCopyOp::create(rewriter, loc, operand.get(), tmemAlloc, + /*barrier*/ Value()); + operand.set(tmemAlloc); + return true; + } + + LogicalResult matchAndRewrite(TCGen5MMAScaledOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op->getContext(); + auto aScaleType = op.getAScale().getType(); + auto bScaleType = op.getBScale().getType(); + int blockM = op.getBlockM(); + int blockN = op.getBlockN(); + bool anyChanged = false; + if (isa(aScaleType.getMemorySpace())) { + anyChanged = lowerScaleToTmem(op.getAScaleMutable(), rewriter, blockM); + } + if (isa(bScaleType.getMemorySpace())) { + anyChanged = lowerScaleToTmem(op.getBScaleMutable(), rewriter, blockN); + } + return LogicalResult::success(anyChanged); + } +}; + +std::pair, SmallVector> +collectCommitOpsAfter(MMAv5OpInterface mmaOp) { + auto isConstTrue = [](Value v) { + if (auto constOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constOp.getValueAttr())) { + return attr.getValue(); + } + } + return false; + }; + + SmallVector commitOps; + SmallVector commitPredicates; + auto mmaPred = mmaOp.getPredicate(); + Operation *nextOp = mmaOp->getNextNode(); + + while (nextOp) { + if (auto commit = dyn_cast(nextOp)) { + // If the mma predicate is true, or mma and commit ops use the same + // predicate, it is safe to merge them + if (isConstTrue(mmaPred) || mmaPred == commit.getPred()) { + commitOps.push_back(commit); + commitPredicates.push_back(commit.getPred()); + } + } else if (!isPure(nextOp)) { + // Only move commits across pure ops. We also bail here when encountering + // another MMAv5 op. + break; + } + nextOp = nextOp->getNextNode(); + } + + return {commitOps, commitPredicates}; +} + +// Return false if defining ops cannot be moved above the target op +bool moveDefiningOpsBefore(Value val, Operation *target) { + SetVector toMove; + + std::function collectOpsToMove = [&](Value val) { + if (auto defOp = val.getDefiningOp()) { + if (defOp->getBlock() == target->getBlock() && + target->isBeforeInBlock(defOp)) { + if (!isPure(defOp)) { + // This defOp needs to move above the target op, but it is unsafe due + // to impurity. + return false; + } + for (Value operand : defOp->getOperands()) { + if (!collectOpsToMove(operand)) { + return false; + } + } + toMove.insert(defOp); + } + } + return true; + }; + + if (!collectOpsToMove(val)) { + return false; + } + + for (Operation *op : toMove) { + op->moveBefore(target); + } + + return true; +} + +class MergeCommitIntoMMA : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(MMAv5OpInterface op, + PatternRewriter &rewriter) const override { + auto [commitOps, predicates] = collectCommitOpsAfter(op); + if (commitOps.size() == 0) { + return llvm::failure(); + } + for (auto [commit, pred] : llvm::zip(commitOps, predicates)) { + if (!pred) { + pred = arith::ConstantIntOp::create(rewriter, op.getLoc(), true, 1); + } + if (!moveDefiningOpsBefore(commit.getBarrier(), op) || + !moveDefiningOpsBefore(pred, op)) { + // Give up merging a commit if its defining ops cannot be moved above + // the mma op. + continue; + } + op.addCompletionBarrier(commit.getBarrier(), pred); + rewriter.eraseOp(commit); + } + return success(); + } +}; + +} // anonymous namespace + +class TritonNvidiaGPUMMALoweringPass + : public impl::TritonNvidiaGPUMMALoweringPassBase< + TritonNvidiaGPUMMALoweringPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp new file mode 100644 index 0000000000..de85c593a5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -0,0 +1,400 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/PriorityWorklist.h" +#include +#include + +namespace ttg = mlir::triton::gpu; + +namespace { + +struct UseInfo { + TypedValue descriptor; + Operation *use; + Attribute desiredSharedEncoding; + SmallVector shape; + ttg::CGAEncodingAttr cgaLayout; +}; + +static bool isTMACompatibleEncoding(Attribute enc) { + if (auto nvmma = dyn_cast(enc)) { + return !nvmma.getTransposed(); + } + return false; +} + +Attribute findLoadEncodingFromUsers(Operation *op) { + // Ignore multiple users and just pick the first compatible layout + for (auto use : op->getUsers()) { + if (auto alloc = dyn_cast(use)) { + auto enc = alloc.getType().getEncoding(); + if (isTMACompatibleEncoding(enc)) + return enc; + } else if (auto store = dyn_cast(use)) { + auto enc = store.getDst().getType().getEncoding(); + if (isTMACompatibleEncoding(enc)) + return enc; + } + } + return {}; +} + +SmallVector expandToRank(ArrayRef shape, int rank) { + SmallVector result(rank, 1); + assert(shape.size() <= rank); + auto rankDiff = rank - shape.size(); + std::copy(shape.begin(), shape.end(), result.begin() + rankDiff); + return result; +} + +std::optional getUseInfo(Operation *op) { + UseInfo info; + info.use = op; + if (auto load = dyn_cast(op)) { + info.descriptor = load.getDesc(); + info.desiredSharedEncoding = findLoadEncodingFromUsers(op); + auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding + : load.getType().getEncoding(); + info.cgaLayout = ttg::getCGALayout(encoding); + auto shape = load.getResult().getType().getShape(); + auto rank = load.getDesc().getType().getBlockType().getRank(); + info.shape = expandToRank(shape, rank); + return info; + } + if (auto gather = dyn_cast(op)) { + info.descriptor = gather.getDesc(); + info.desiredSharedEncoding = findLoadEncodingFromUsers(op); + auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding + : gather.getType().getEncoding(); + info.cgaLayout = ttg::getCGALayout(encoding); + auto shape = gather.getResult().getType().getShape(); + auto rank = gather.getDesc().getType().getBlockType().getRank(); + info.shape = expandToRank(shape, rank); + return info; + } + if (auto store = dyn_cast(op)) { + info.descriptor = store.getDesc(); + auto encoding = store.getSrc().getType().getEncoding(); + info.cgaLayout = ttg::getCGALayout(encoding); + auto shape = store.getSrc().getType().getShape(); + auto rank = store.getDesc().getType().getBlockType().getRank(); + info.shape = expandToRank(shape, rank); + return info; + } + return std::nullopt; +} + +struct EncodingInfo { + Attribute desiredEncoding; + ttg::CGAEncodingAttr cgaLayout; + // Shape may be different from the descriptor block shape for gather/scatter + // use case + SmallVector shape; + bool forcedToDefault = false; + + bool operator==(const EncodingInfo &other) const { + return desiredEncoding == other.desiredEncoding && + cgaLayout == other.cgaLayout && + forcedToDefault == other.forcedToDefault && shape == other.shape; + } +}; + +} // namespace + +template <> struct std::hash { + size_t operator()(const EncodingInfo &einfo) const { + return llvm::hash_combine(einfo.desiredEncoding, einfo.cgaLayout, + einfo.forcedToDefault, + ArrayRef(einfo.shape)); + } +}; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUOPTIMIZEDESCRIPTORENCODINGPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +const EncodingInfo *internEncoding(std::unordered_set &encodings, + EncodingInfo info) { + return &*encodings.insert(info).first; +} + +EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs, + unsigned rank) { + EncodingInfo result; + // Always propagate forcedToDefault + result.forcedToDefault = lhs.forcedToDefault || rhs.forcedToDefault; + + if (result.forcedToDefault) + return result; + + if (lhs.shape.empty() || lhs.shape == rhs.shape) + result.shape = rhs.shape; + else if (rhs.shape.empty()) + result.shape = lhs.shape; + else { + assert(lhs.shape.size() == rhs.shape.size()); + auto rank = lhs.shape.size(); + result.shape.reserve(rank); + for (int i = 0; i < rank; ++i) + result.shape.push_back(std::min(lhs.shape[i], rhs.shape[i])); + } + + SetVector cgaLayouts; + if (lhs.cgaLayout) + cgaLayouts.insert(lhs.cgaLayout); + if (rhs.cgaLayout) + cgaLayouts.insert(rhs.cgaLayout); + + auto getDefaultLayout = [&](ttg::CGAEncodingAttr encoding) { + // The default layout puts all the CTAs in the last dimension + // We do this as this function needs to be commutative for all encodings + // This heuristic could be improved if needed + auto ctx = encoding.getContext(); + auto kBlock = StringAttr::get(ctx, "block"); + auto dims = triton::standardOutDimNames(ctx, rank); + auto numCTAs = encoding.getLinearLayout().getInDimSize(kBlock); + LinearLayout llDefault; + for (int i = 0; i < rank - 1; ++i) { + llDefault *= LinearLayout::identity1D(1, kBlock, dims[i]); + } + llDefault *= LinearLayout::identity1D(numCTAs, kBlock, dims.back()); + return ttg::CGAEncodingAttr::get(ctx, llDefault); + }; + + switch (cgaLayouts.size()) { + case 2: + // if we find clashing CGALayouts, fallback to default + result.cgaLayout = getDefaultLayout(lhs.cgaLayout); + break; + case 1: + result.cgaLayout = cgaLayouts[0]; + break; + default: + break; + } + + SetVector desiredEncodings; + if (lhs.desiredEncoding) + desiredEncodings.insert(lhs.desiredEncoding); + if (rhs.desiredEncoding) + desiredEncodings.insert(rhs.desiredEncoding); + + switch (desiredEncodings.size()) { + case 2: + // if we find clashing encodings, fallback to default + result.forcedToDefault = true; + break; + case 1: + result.desiredEncoding = desiredEncodings[0]; + break; + default: + break; + } + return result; +} + +Attribute getFallbackSharedEncoding(RankedTensorType tensorType, + ttg::CGAEncodingAttr cgaLayout, + ArrayRef usageShape, + unsigned numCTAs) { + auto ctx = tensorType.getContext(); + SmallVector order; + for (int i = tensorType.getRank() - 1; i >= 0; --i) + order.push_back(i); + + ArrayRef shape = + usageShape.empty() ? tensorType.getShape() : usageShape; + if (!cgaLayout) { + // Arbitrarily distribute along the last dim + SmallVector ctasPerCGA(tensorType.getRank(), 1); + ctasPerCGA.back() = numCTAs; + cgaLayout = ttg::CGAEncodingAttr::fromSplitParams(ctx, ctasPerCGA, + ctasPerCGA, order); + } else if (cgaLayout.getRank() != tensorType.getRank()) + cgaLayout = updateCGALayoutForShape(cgaLayout, shape); + + return ttg::NVMMASharedEncodingAttr::get(ctx, shape, order, cgaLayout, + tensorType.getElementType(), + /*fp4Padded*/ false); +} + +TensorDescType getTensorDescTypeWithEncoding(Operation *op, + RankedTensorType existingTy, + Attribute encoding) { + auto sharedEnc = cast(encoding); + encoding = updateEncodingForShape(op, sharedEnc, existingTy); + auto blockTy = existingTy.cloneWithEncoding(encoding); + return TensorDescType::get(existingTy.getContext(), blockTy); +} + +void assignMemoryLayouts(FuncOp &func) { + std::unordered_set encodings; + llvm::MapVector, const EncodingInfo *> + valueToEncodingInfo; + llvm::PriorityWorklist> worklist; + + auto updateEncoding = [&](ArrayRef descValues, EncodingInfo info) { + for (auto value : descValues) { + auto typedVal = cast>(value); + auto itr = valueToEncodingInfo.find(typedVal); + if (itr != valueToEncodingInfo.end()) + info = combineEncodings(*itr->second, info, + typedVal.getType().getBlockType().getRank()); + } + + auto einfo = internEncoding(encodings, info); + for (auto value : descValues) { + auto typedVal = cast>(value); + auto res = valueToEncodingInfo.try_emplace(typedVal, einfo); + if (res.second) { + worklist.insert(typedVal); + } else if (res.first->second != einfo) { + res.first->second = einfo; + worklist.insert(typedVal); + } + } + }; + + // 1. Set seed values from either TMA ops, or device function boundaries for + // which we fallback to default encoding + auto isKernel = triton::isKernel(func); + for (auto blockArg : func.getBlocks().front().getArguments()) + if (auto desc = dyn_cast>(blockArg)) + updateEncoding({desc}, + EncodingInfo{{}, {}, {}, /*forcedToDefault=*/!isKernel}); + + func.walk([&](Operation *op) { + if (auto info = getUseInfo(op)) { + updateEncoding(info->descriptor, + EncodingInfo{info->desiredSharedEncoding, info->cgaLayout, + info->shape}); + } else { + bool forcedToDefault = isa(op); + auto einfo = + internEncoding(encodings, EncodingInfo{{}, {}, {}, forcedToDefault}); + + auto setEncoding = [&](Value v) { + auto typedVal = cast>(v); + valueToEncodingInfo.try_emplace(typedVal, einfo); + if (forcedToDefault) + worklist.insert(typedVal); + }; + for (auto result : op->getResults()) + if (auto desc = dyn_cast>(result)) + setEncoding(desc); + + for (auto arg : op->getOperands()) + if (auto desc = dyn_cast>(arg)) + setEncoding(desc); + } + }); + + // 2. Propagate encoding info through the graph until fixed point + while (!worklist.empty()) { + auto desc = worklist.pop_back_val(); + + // Propagate to users + for (OpOperand &use : desc.getUses()) { + auto op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + auto vals = getTiedArgs(op, use.getOperandNumber() - offset); + updateEncoding(vals, EncodingInfo{}); + } else if (isa(op)) { + auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber()); + updateEncoding(vals, EncodingInfo{}); + } + } + + // Propagate to defining ops + if (auto opResult = dyn_cast(desc)) { + auto definingOp = opResult.getOwner(); + if (isa(definingOp)) { + auto vals = getTiedArgs(definingOp, opResult.getResultNumber()); + updateEncoding(vals, EncodingInfo{}); + } + } else if (auto blockArg = dyn_cast(desc)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + auto vals = getTiedArgs(parentOp, blockArg.getArgNumber() - offset); + updateEncoding(vals, EncodingInfo{}); + } + } + } + + // 3. Transfer propagated encodings into the graph + auto ctx = func.getContext(); + auto numCTAs = gpu::lookupNumCTAs(func); + for (auto &[desc, einfo] : valueToEncodingInfo) { + auto existingTy = desc.getType().getBlockType(); + Attribute newEncoding; + if (einfo->desiredEncoding) { + newEncoding = einfo->desiredEncoding; + } else if (einfo->forcedToDefault) { + newEncoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs); + } else { + newEncoding = getFallbackSharedEncoding(existingTy, einfo->cgaLayout, + einfo->shape, numCTAs); + } + desc.setType(getTensorDescTypeWithEncoding(desc.getDefiningOp(), existingTy, + newEncoding)); + } + + SmallVector argTys(func.getBlocks().front().getArgumentTypes()); + SmallVector resultTys(func.getResultTypes()); + for (auto [i, resultTy] : llvm::enumerate(resultTys)) { + if (auto descTy = dyn_cast(resultTy)) { + auto encoding = + getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs); + resultTys[i] = getTensorDescTypeWithEncoding( + nullptr, descTy.getBlockType(), encoding); + } + } + func.setFunctionType(FunctionType::get(ctx, argTys, resultTys)); +} + +void assignMemoryLayouts(ModuleOp &mod) { + for (auto &op : *mod.getBody()) { + if (auto func = dyn_cast(&op)) { + assignMemoryLayouts(func); + } + } +} + +} // anonymous namespace + +class TritonNvidiaGPUOptimizeDescriptorEncodingPass + : public impl::TritonNvidiaGPUOptimizeDescriptorEncodingPassBase< + TritonNvidiaGPUOptimizeDescriptorEncodingPass> { +public: + using BaseT = TritonNvidiaGPUOptimizeDescriptorEncodingPassBase< + TritonNvidiaGPUOptimizeDescriptorEncodingPass>; + using BaseT::BaseT; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + assignMemoryLayouts(m); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp new file mode 100644 index 0000000000..082d3d4f7a --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp @@ -0,0 +1,448 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUOPTIMIZETMEMLAYOUTSPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +// clang-format off +// Converts: +// %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> +// -> tensor<128x256xf32, #blocked> +// %r = tt.reshape %l : tensor<128x256xf32, #blocked> +// -> tensor<128x2x128xf32, #blocked4> +// %t = tt.trans %r {order = array} +// -> tensor<128x128x2xf32, #blocked5> +// %lhs, %rhs = tt.split %t +// +// becomes +// %o0 = ttng.tmem_subslice %o { N = 0 } +// %lhs = ttng.tmem_load %o0 +// %o1 = ttng.tmem_subslice %o { N = 128 } +// %rhs = ttng.tmem_load %o1 +// +// and if %lhs / %rhs are split again through the same reshape->trans->split +// pattern, the transformation is can match again so that each further +// split is materialised as an independent `ttng.tmem_subslice` / `ttng.tmem_load` +// pair. Consequently, a chain such as +// +// acc0, acc1 = split(permute(reshape(acc , ...))) +// acc00, acc01 = split(permute(reshape(acc0, ...))) +// acc10, acc11 = split(permute(reshape(acc1, ...))) +// +// is lowered to four independent TMEM loads operating on four disjoint +// subslices. +// +// clang-format on +// Strip away all intermediate ttg.convert_layout ops to reach the true +// producer. +static Value stripConvertLayout(Value v) { + while (auto cvt = v.getDefiningOp()) + v = cvt.getSrc(); + return v; +} + +class TMemSplitLoadPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SplitOp splitOp, + PatternRewriter &rewriter) const override { + // ----------------------------------------------------------------------- + // Match the pattern: + // splitOp + // ^ | + // | +-- transOp(order = [0, 2, 1]) + // | ^ | + // | | +-- reshapeOp + // | | ^ | + // | | | +-- (maybe convert_layout) + // | | +-- tmemLoad + // ----------------------------------------------------------------------- + + // Starting from the split source, peel off convert_layouts if any. + Value src = stripConvertLayout(splitOp.getSrc()); + auto transOp = src.getDefiningOp(); + if (!transOp || transOp.getOrder() != ArrayRef({0, 2, 1})) + return failure(); + auto reshapeOp = transOp.getSrc().getDefiningOp(); + if (!reshapeOp) + return failure(); + + // Peel off convert_layouts *below* the reshape as well. This is required + // for the recursive case where the producer of the reshape is the result + // of an earlier optimisation pass (i.e. a convert_layout of a previous + // tmem_load). + Value reshapeSrc = stripConvertLayout(reshapeOp.getSrc()); + auto tmemLoad = reshapeSrc.getDefiningOp(); + if (!tmemLoad) + return failure(); + + auto shape = reshapeOp.getResult().getType().getShape(); + // Ensure M dimension is preserved by the reshape. + if (shape[0] != cast(reshapeSrc.getType()).getShape()[0]) + return failure(); + int mDim = getShapePerCTA(tmemLoad.getSrc().getType())[0]; + // TODO: enable other M cases. (the layout is a bit more complex). + if (mDim != 128) + return failure(); + int splitNSize = shape[2]; + if (splitNSize < 8) + return failure(); + + // Create the two TMEM subslices and their corresponding loads. + Value tmem = tmemLoad.getSrc(); // Could itself be a subslice. + int numWarps = ttg::lookupNumWarps(tmemLoad); + rewriter.setInsertionPoint(tmemLoad); + + auto createSliceLoad = + [&](int64_t nOffset) -> std::pair { + // Generate the subslice op. + Value subSlice = TMEMSubSliceOp::create(rewriter, tmemLoad.getLoc(), tmem, + nOffset, splitNSize); + + // Choose a layout compatible with the slice size. + gpu::MemDescType subSliceType = + cast(subSlice.getType()); + auto cgaLayout = + ttg::getCGALayout(splitOp.getOutLHS().getType().getEncoding()); + auto distLayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + subSliceType, numWarps, cgaLayout); + + RankedTensorType newLoadType = + splitOp.getOutLHS().getType().cloneWithEncoding(distLayout); + + // Generate the load and convert_layout back to the original layout. + auto load = TMEMLoadOp::create(rewriter, tmemLoad.getLoc(), newLoadType, + subSlice); + auto cvt = ttg::ConvertLayoutOp::create( + rewriter, tmemLoad.getLoc(), splitOp.getOutLHS().getType(), load); + + return {load, cvt}; + }; + + auto [load0, cvt0] = createSliceLoad(/*nOffset=*/0); + auto [load1, cvt1] = createSliceLoad(/*nOffset=*/splitNSize); + rewriter.replaceOp(splitOp, {cvt0, cvt1}); + return success(); + } +}; + +class TMemStoreJoinPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TMEMStoreOp storeOp, + PatternRewriter &b) const override { + // Look through layout conversions. + Value src = storeOp.getSrc(); + while (auto cvt = src.getDefiningOp()) { + src = cvt.getSrc(); + } + + // Only support joinin N dimension on the outer most. + auto reshapeOp = src.getDefiningOp(); + if (!reshapeOp) + return failure(); + auto shape = reshapeOp.getSrc().getType().getShape(); + if (reshapeOp.getType().getShape().front() != shape[0]) + return failure(); + auto transOp = reshapeOp.getSrc().getDefiningOp(); + if (!transOp || transOp.getOrder() != ArrayRef({0, 2, 1})) + return failure(); + auto joinOp = transOp.getSrc().getDefiningOp(); + if (!joinOp) + return failure(); + + // We found a tmem_store that is joined on the N dimension. We can split it + // into multiple tmem_stores. + int mDim = getShapePerCTA(storeOp.getDst().getType())[0]; + // TODO: enable other M cases. (the layout is a bit more complex). + if (mDim != 128) + return failure(); + int splitNSize = shape[2]; + if (splitNSize < 8) + return failure(); + + Location loc = storeOp.getLoc(); + Value tmem = storeOp.getDst(); + int numWarps = ttg::lookupNumWarps(storeOp); + Value truePred = arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + + auto cgaLayout = ttg::getCGALayout(joinOp.getLhs().getType().getEncoding()); + auto *ctx = joinOp.getContext(); + + auto createSlice = [&](TypedValue input, int offset) { + auto subSlice = TMEMSubSliceOp::create(b, loc, tmem, offset, splitNSize); + auto distLayout = nvidia_gpu::getDefaultLayoutForTmemLdSt( + subSlice.getType(), numWarps, cgaLayout); + auto newType = input.getType().cloneWithEncoding(distLayout); + auto cvt = ttg::ConvertLayoutOp::create(b, loc, newType, input); + auto store = + TMEMStoreOp::create(b, loc, subSlice, cvt.getResult(), truePred); + return store; + }; + + auto store0 = createSlice(joinOp.getLhs(), 0); + auto store1 = createSlice(joinOp.getRhs(), splitNSize); + b.eraseOp(storeOp); + return success(); + } +}; + +// Pick an optimized tmem load layout based on its users. When there are +// multiple warpgroups tmem_load results can be distirbuted along M or N across +// the warpgroups. By default distribute along N but when there is a reduction +// along N dimension we want to distribute along M instead to avoid having to +// reduce across warps. +class TMemLoadReducePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TMEMLoadOp tmemLoadOp, + PatternRewriter &rewriter) const override { + int numWarps = ttg::lookupNumWarps(tmemLoadOp); + // If there is only 1 warpgroup there is nothing to optimize as the layout + // is already reduction friendly. + if (numWarps != 8) + return failure(); + bool foundReductionAlongN = false; + auto filter = [&](Operation *op) { + if (isa(op) || op->hasTrait()) + return true; + if (auto reduce = dyn_cast(op)) { + foundReductionAlongN = reduce.getAxis() == 1; + } + return false; + }; + ForwardSliceOptions fwdOpt; + fwdOpt.filter = filter; + SetVector fwdSlices; + getForwardSlice(tmemLoadOp.getResult(), &fwdSlices, fwdOpt); + if (!foundReductionAlongN) + return failure(); + // Try to split along M dimension but follow the restrictions of TMEM: + // warp0 get M = 0, warp 1 gets M = 32, warp 2 gets M = 64, warp 3 gets + // M = 96 warp 4 gets M = 16, warp 5 gets M = 48, warp 6 gets M = 80, + // warp 7 gets M = 112 + RankedTensorType oldType = tmemLoadOp.getType(); + std::optional newLayout = + getTmemLoadLayoutSplitLongM(oldType, tmemLoadOp.getSrc().getType(), + numWarps); + if (!newLayout) + return failure(); + if (newLayout.value() == oldType.getEncoding()) + return failure(); + + auto newType = oldType.cloneWithEncoding(newLayout.value()); + tmemLoadOp.getResult().setType(newType); + OpBuilder builder(tmemLoadOp); + builder.setInsertionPointAfter(tmemLoadOp); + auto cvt = ttg::ConvertLayoutOp::create(builder, tmemLoadOp.getLoc(), + oldType, tmemLoadOp.getResult()); + tmemLoadOp.getResult().replaceAllUsesExcept(cvt.getResult(), cvt); + return success(); + } +}; + +// Optimize local_load -> tmem_store when the layout 16x256b allows better +// code generation for local_load lowering. +class TMemFromSharedMemPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TMEMStoreOp tmemStoreOp, + PatternRewriter &rewriter) const override { + auto tmemEnc = dyn_cast( + tmemStoreOp.getDst().getType().getEncoding()); + if (!tmemEnc) + return failure(); + int M = tmemEnc.getBlockM(); + int N = tmemEnc.getBlockN(); + int numWarps = ttg::lookupNumWarps(tmemStoreOp); + // Compute the alternative layout. + auto cgaLayout = + ttg::getCGALayout(tmemStoreOp.getSrc().getType().getEncoding()); + std::optional ll = + nvidia_gpu::getDistributedLayoutForTmemLdSt( + tmemStoreOp.getDst().getType(), TMemAccessAtom::I16x256b, numWarps, + cgaLayout); + if (!ll) + return failure(); + Attribute newEncoding = + gpu::LinearEncodingAttr::get(tmemStoreOp.getContext(), std::move(*ll)); + auto oldType = tmemStoreOp.getSrc().getType(); + auto newType = oldType.cloneWithEncoding(newEncoding); + if (newType == oldType) + return failure(); + + SetVector slice; + DenseMap layoutMap; + // Check how it may propagate up the SSA chain. + LogicalResult result = getConvertBackwardSlice( + tmemStoreOp.getSrcMutable(), slice, newEncoding, layoutMap); + if (result.failed()) + return failure(); + bool foundImprovedLoad = false; + for (Value v : slice) { + auto localLoad = v.getDefiningOp(); + if (!localLoad) + continue; + // 16x256b is optimized for 16bits load. + if (localLoad.getType().getElementType().getIntOrFloatBitWidth() != 16) + return failure(); + LinearLayout regLayout = gpu::toLinearLayout(localLoad.getType()); + LinearLayout smemLayout = + gpu::toLinearLayout(localLoad.getSrc().getType()); + int vecDim = + regLayout.invertAndCompose(smemLayout).getNumConsecutiveInOut(); + // If we find a 16bits load that cannot be vectorized use the alternative + // layout. + if (vecDim != 1) + return failure(); + foundImprovedLoad = true; + } + if (!foundImprovedLoad) + return failure(); + // Use the new layout and rely on RemoveLayoutConversions pass to propagate + // the convert_layout. + auto cvt = ttg::ConvertLayoutOp::create(rewriter, tmemStoreOp.getLoc(), + newType, tmemStoreOp.getSrc()); + rewriter.modifyOpInPlace(tmemStoreOp, [&]() { + tmemStoreOp.getSrcMutable().assign(cvt.getResult()); + }); + return success(); + } +}; + +// Optimize tmem_load -> local_store when the layout 16x256b allows better +// code generation for local_store lowering. +class TMemToSharedMemPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TMEMLoadOp tmemLoadOp, + PatternRewriter &rewriter) const override { + auto tmemEnc = dyn_cast( + tmemLoadOp.getSrc().getType().getEncoding()); + if (!tmemEnc) + return failure(); + int M = tmemEnc.getBlockM(); + int N = tmemEnc.getBlockN(); + int numWarps = ttg::lookupNumWarps(tmemLoadOp); + auto oldType = tmemLoadOp.getType(); + auto cgaLayout = ttg::getCGALayout(oldType.getEncoding()); + auto memType = cast(tmemLoadOp.getSrc().getType()); + // Compute the alternative layout. + auto ll = nvidia_gpu::getDistributedLayoutForTmemLdSt( + memType, TMemAccessAtom::I16x256b, numWarps, cgaLayout); + if (!ll) + return failure(); + Attribute newEncoding = + gpu::LinearEncodingAttr::get(tmemLoadOp.getContext(), std::move(*ll)); + auto newType = oldType.cloneWithEncoding(newEncoding); + if (newType == oldType) + return failure(); + + SetVector slice; + DenseMap layoutMap; + SmallVector> uses; + uses.push_back({tmemLoadOp.getResult(), newEncoding}); + bool foundImprovedStore = false; + llvm::DenseSet> visited; + while (!uses.empty()) { + auto [v, encoding] = uses.pop_back_val(); + if (!visited.insert({v, encoding}).second) + continue; + for (auto user : v.getUsers()) { + if (auto localStore = dyn_cast(user)) { + // Check if the store benefits from the new layout. + // 16x256b is optimized for 16bits load. + auto srcType = localStore.getSrc().getType(); + if (srcType.getElementType().getIntOrFloatBitWidth() >= 32) + continue; + LinearLayout regLayout = gpu::toLinearLayout(srcType); + LinearLayout smemLayout = + gpu::toLinearLayout(localStore.getDst().getType()); + int vecDim = + regLayout.invertAndCompose(smemLayout).getNumConsecutiveInOut(); + // If we find a 8 or 16bits store that cannot be vectorized use the + // alternative layout. + // TODO: we could refine the logic to make sure the new layout would + // help by allowing stmatrix if we can isolate good helpers. + if (vecDim != 1) + continue; + foundImprovedStore = true; + break; + } + // Don't iterate though control flow ops. + if (isa(user)) + continue; + Attribute userEncoding = inferDstEncoding(user, encoding); + if (!userEncoding) { + if (isa(user)) { + userEncoding = encoding; + } else { + continue; + } + } + for (auto result : user->getResults()) { + uses.push_back({result, userEncoding}); + } + } + } + if (!foundImprovedStore) + return failure(); + // Use the new layout and rely on RemoveLayoutConversions pass to propagate + // the convert_layout. + rewriter.modifyOpInPlace( + tmemLoadOp, [&]() { tmemLoadOp.getResult().setType(newType); }); + rewriter.setInsertionPointAfter(tmemLoadOp); + auto cvt = ttg::ConvertLayoutOp::create(rewriter, tmemLoadOp.getLoc(), + oldType, tmemLoadOp.getResult()); + rewriter.replaceAllUsesExcept(tmemLoadOp.getResult(), cvt, cvt); + return success(); + } +}; + +} // anonymous namespace + +class TritonNvidiaGPUOptimizeTMemLayoutsPass + : public impl::TritonNvidiaGPUOptimizeTMemLayoutsPassBase< + TritonNvidiaGPUOptimizeTMemLayoutsPass> { +public: + using BaseT = TritonNvidiaGPUOptimizeTMemLayoutsPassBase< + TritonNvidiaGPUOptimizeTMemLayoutsPass>; + using BaseT::BaseT; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns + .add(context); + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp new file mode 100644 index 0000000000..cae9d6227d --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -0,0 +1,1038 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONGPUPLANCTAPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +// TODO: use ConvertLayoutOp +using CastOp = ::mlir::UnrealizedConversionCastOp; + +unsigned getNumUsers(Value value) { + return std::distance(value.user_begin(), value.user_end()); +} + +Type replaceLayout(const Type &type, const Attribute &newLayout) { + Type curType = type; + auto ptrTy = dyn_cast(curType); + if (ptrTy) + curType = ptrTy.getPointeeType(); + if (auto tensorTy = dyn_cast(curType)) + curType = tensorTy.cloneWithEncoding(newLayout); + if (ptrTy) + curType = triton::PointerType::get(curType, ptrTy.getAddressSpace()); + return curType; +} + +ttg::DistributedEncodingTrait +replaceCGALayout(ttg::DistributedEncodingTrait layout, + llvm::ArrayRef shape, int numWarps, + ttg::CGAEncodingAttr newCGALayout) { + if (auto blockedLayout = mlir::dyn_cast(layout)) { + return ttg::BlockedEncodingAttr::get( + layout.getContext(), shape, blockedLayout.getSizePerThread(), + blockedLayout.getOrder(), numWarps, 32, newCGALayout); + } else if (auto sliceLayout = + mlir::dyn_cast(layout)) { + return ttg::SliceEncodingAttr::get( + layout.getContext(), sliceLayout.getDim(), + replaceCGALayout(sliceLayout.getParent(), shape, numWarps, + newCGALayout)); + } else { + // Other layouts are generated by passes after PlanCTAPass + llvm::report_fatal_error("replaceCGALayout not implemented"); + return layout; + } +} + +class CTAPlanner { +public: + CTAPlanner(); + + void run(triton::FuncOp &funcOp); + +private: + CastOp markBackward(CastOp cast) const; + CastOp markForward(CastOp cast) const; + bool isBackward(CastOp cast) const; + bool isForward(CastOp cast) const; + + bool processDot(triton::FuncOp &funcOp); + bool processReduce(triton::FuncOp &funcOp); + void processStoreLikeOps(triton::FuncOp &funcOp); + + bool propagate(CastOp cast); + bool propagateBackward(CastOp cast); + bool propagateForward(CastOp cast); + + void eraseCastOp(CastOp cast); + void eraseCastOpFromQueue(CastOp cast); + void eraseCastOpsFromQueue(llvm::ArrayRef casts); + + void insertCasts(Operation *op, llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts); + void eliminateAdjacentCasts(CastOp cast0, CastOp cast1); + + bool isLoadStoreOp(Operation *op) const; + bool processLoadStore(Operation *op, Attribute layout); + + bool isElementwiseOp(Operation *op) const; + bool processElementwise(Operation *op, Attribute layout); + + bool processConstant(arith::ConstantOp constant, Attribute layout); + bool processSplat(triton::SplatOp splat, Attribute layout); + bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout); + bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout); + + bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout); + bool processExpandDimsBackward(triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newResultLayout); + bool processExpandDimsForward(triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newSrcLayout); + + bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + + bool processIfOp(scf::IfOp ifOp, int index, const Type &newType); + bool processForOp(scf::ForOp forOp, int index, const Type &newType); + + bool processIfOpBackward(scf::IfOp ifOp, CastOp cast); + bool processForOpBackward(scf::ForOp forOp, CastOp cast); + bool processBlockArgBackward(BlockArgument arg, CastOp cast); + bool processForOpForward(scf::ForOp forOp, CastOp cast); + bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast); + + bool processOpFallback(Operation *op); + + bool processMultiUsersBackward(Value input, CastOp cast); + bool processMultiUsersForward(Value output, CastOp cast); + + void markTiled(); + + unsigned step; + unsigned stepUnchanged; + bool tiled; + std::queue queue; +}; + +CTAPlanner::CTAPlanner() : step(0), stepUnchanged(0), tiled(false) {} + +void CTAPlanner::run(triton::FuncOp &funcOp) { + static const unsigned maxSteps = 10000; + + auto nextStep = [&]() { + ++step; + assert(step < maxSteps && "Maximum number of steps exceeded"); + }; + + processDot(funcOp); + nextStep(); + + processReduce(funcOp); + nextStep(); + + if (!tiled) { + processStoreLikeOps(funcOp); + nextStep(); + } + + while (!queue.empty()) { + CastOp cast = queue.front(); + queue.pop(); + bool changed = propagate(cast); + if (changed) { + stepUnchanged = 0; + } else { + queue.push(cast); + ++stepUnchanged; + } + nextStep(); + } +} + +CastOp CTAPlanner::markBackward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "backward")); + return cast; +} + +CastOp CTAPlanner::markForward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "forward")); + return cast; +} + +bool CTAPlanner::isBackward(CastOp cast) const { + return cast->getAttrOfType("direction") == "backward"; +} + +bool CTAPlanner::isForward(CastOp cast) const { + return cast->getAttrOfType("direction") == "forward"; +} + +void CTAPlanner::markTiled() { + assert(!tiled && "CTA tiling is already determined"); + tiled = true; +} + +bool CTAPlanner::processDot(triton::FuncOp &funcOp) { + // TODO: This is a naive implementation and should be refactored + auto getCTATiling = [](int64_t M, int64_t N, int64_t K, + unsigned numCTAs) -> std::pair { + // prefer a larger chunk size, at most 128; first assign splitM. + unsigned chunk_m = 128; + auto isLegal = [](unsigned chunk) { return chunk >= 64; }; + unsigned splitM, splitN; + for (; isLegal(chunk_m); chunk_m /= 2) { + splitM = std::clamp(M / chunk_m, 1, numCTAs); + splitN = numCTAs / splitM; + if (isLegal(N / splitN)) // chunk_n; + break; + } + return {splitM, splitN}; + }; + + funcOp.walk([&](triton::DotOp dot) { + MLIRContext *ctx = dot.getContext(); + + auto aTy = cast(dot.getA().getType()); + auto bTy = cast(dot.getB().getType()); + auto dTy = cast(dot.getD().getType()); + + assert(isa(aTy.getEncoding()) && + isa(bTy.getEncoding()) && + isa(dTy.getEncoding()) && + "PlanCTAPass should follow immediately after CoalescePass"); + + auto aLayout = cast(aTy.getEncoding()); + auto bLayout = cast(bTy.getEncoding()); + auto dLayout = cast(dTy.getEncoding()); + + unsigned M = dTy.getShape()[0]; + unsigned N = dTy.getShape()[1]; + unsigned K = aTy.getShape()[1]; + + unsigned splitM, splitN; + std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout)); + // FIXME: Should consider IR with more than one DotOps + markTiled(); + + OpBuilder builder(dot); + auto numThreads = ttg::lookupThreadsPerWarp(builder); + auto numWarps = ttg::lookupNumWarps(dot); + + auto newCGALayout = ttg::CGAEncodingAttr::fromSplitParams( + ctx, {splitM, splitN}, {splitM, splitN}, {1, 0}); + auto newDLayout = ttg::BlockedEncodingAttr::get( + ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(), + numWarps, numThreads, newCGALayout); + auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(), + newDLayout, 0); + auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(), + newDLayout, 0); + + insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout}, + {newDLayout}); + }); + + return true; +} + +bool CTAPlanner::processReduce(triton::FuncOp &funcOp) { + ModuleOp mod = funcOp->getParentOfType(); + unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + funcOp.walk([&](triton::ReduceOp reduce) { + MLIRContext *context = reduce.getContext(); + Value src = reduce.getOperands()[0]; + unsigned axis = reduce.getAxis(); + + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto srcLayout = srcTy.getEncoding(); + + auto rank = srcShape.size(); + auto order = ttg::getOrder(srcTy); + auto sizePerThread = ttg::getContigPerThread(srcTy); + auto CTAOrder = ttg::getCTAOrder(srcLayout); + + llvm::SmallVector CTAsPerCGA(rank, 0); + unsigned remainingCTAs = numCTAs; + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim == axis) { + CTAsPerCGA[dim] = 1; + } else { + CTAsPerCGA[dim] = std::min(srcShape[dim] / sizePerThread[dim], + remainingCTAs); + remainingCTAs /= CTAsPerCGA[dim]; + } + } + + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim != axis) { + CTAsPerCGA[dim] *= remainingCTAs; + break; + } + } + + llvm::SmallVector CTASplitNum = CTAsPerCGA; + + // If numCTAs > 1 and the only dimension is the reduced dimension, after the + // above two for-loops, CTAsPerCGA = [0] and remainingCTAs = numCTAs. We set + // CTAsPerCGA[0] = numCTAs and keep CTASplitNum[0] = 1 to ensure that no + // cross-CTA reduction is required, although this will introduce duplicated + // calculation + if (remainingCTAs > 0) + CTAsPerCGA[order[rank - 1]] *= remainingCTAs; + + auto numWarps = ttg::lookupNumWarps(reduce); + auto CGALayout = ttg::CGAEncodingAttr::fromSplitParams( + context, CTAsPerCGA, CTASplitNum, CTAOrder); + if (!tiled) + markTiled(); + auto newSrcLayout = + replaceCGALayout(cast(srcLayout), + srcShape, numWarps, CGALayout); + auto newResultLayout = + ttg::SliceEncodingAttr::get(context, axis, newSrcLayout); + unsigned numOperands = reduce.getNumOperands(); + SmallVector newSrcLayoutVec(numOperands, newSrcLayout); + SmallVector newResultLayoutVec(numOperands, newResultLayout); + + insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec); + }); + return true; +} + +void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) { + assert(!tiled && "CTA tiling is already determined"); + + llvm::SmallVector stores; + funcOp.walk([&](Operation *op) { + if (llvm::isa(op)) + stores.push_back(op); + }); + assert(stores.size() > 0 && "Cannot find store-like ops"); + auto numWarps = ttg::lookupNumWarps(funcOp); + + ttg::CGAEncodingAttr CGALayout; + for (Operation *store : stores) { + auto val = [store]() -> Value { + if (auto descStore = + dyn_cast(store)) + return descStore.getSrc(); + return store->getOperand(0); + }(); + if (auto tensorTy = dyn_cast(val.getType())) { + if (!tiled) { + // Use CTA tiling of the first store-like op as global CTA tiling + CGALayout = ttg::getCGALayout(tensorTy.getEncoding()); + markTiled(); + } + auto newLayout = replaceCGALayout( + cast(tensorTy.getEncoding()), + tensorTy.getShape(), numWarps, CGALayout); + processElementwise(store, newLayout); + } + } + + if (!tiled) + markTiled(); +} + +bool CTAPlanner::propagate(CastOp cast) { + return isBackward(cast) ? propagateBackward(cast) : propagateForward(cast); +} + +bool CTAPlanner::propagateBackward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(input); + if (numUsers == 0) { + llvm::report_fatal_error("Unreachable branch"); + return false; + } else if (numUsers == 1) { + Type outTy = output.getType(); + if (auto ptrTy = dyn_cast(outTy)) + outTy = ptrTy.getPointeeType(); + auto layout = mlir::cast( + mlir::cast(outTy).getEncoding()); + Operation *op = input.getDefiningOp(); + if (op == nullptr) { + assert(isa(input) && + "Unexpected Value without defining op"); + processBlockArgBackward(llvm::cast(input), cast); + } else if (auto prevCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(prevCast, cast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (auto constant = llvm::dyn_cast(op)) { + processConstant(constant, layout); + } else if (auto splat = llvm::dyn_cast(op)) { + processSplat(splat, layout); + } else if (auto makeRange = llvm::dyn_cast(op)) { + processMakeRange(makeRange, layout); + } else if (auto makeTensorPtr = + llvm::dyn_cast(op)) { + processMakeTensorPtr(makeTensorPtr, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto broadcast = llvm::dyn_cast(op)) { + processBroadcast(broadcast, layout); + } else if (auto expandDims = llvm::dyn_cast(op)) { + processExpandDimsBackward(expandDims, layout); + } else if (auto ifOp = llvm::dyn_cast(op)) { + processIfOpBackward(ifOp, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpBackward(forOp, cast); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutBackward(convertLayout, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + return true; + } else { + return processMultiUsersBackward(input, cast); + } +} + +bool CTAPlanner::propagateForward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(output); + if (numUsers == 0) { + cast.erase(); + } else if (numUsers == 1) { + Type inTy = input.getType(); + if (auto ptrTy = dyn_cast(inTy)) + inTy = ptrTy.getPointeeType(); + Attribute layout = mlir::cast(inTy).getEncoding(); + Operation *op = *output.user_begin(); + if (auto nextCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(cast, nextCast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutForward(convertLayout, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpForward(forOp, cast); + } else if (auto yieldOp = llvm::dyn_cast(op)) { + processYieldOpForward(yieldOp, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + } else { + processMultiUsersForward(output, cast); + } + return true; +} + +void CTAPlanner::eraseCastOp(CastOp cast) { + Value output = cast.getResult(0); + assert(getNumUsers(output) == 0 && + "Cannot erase CastOp because it is still in use"); + cast.erase(); +} + +void CTAPlanner::eraseCastOpFromQueue(CastOp cast) { + eraseCastOpsFromQueue({cast}); +} + +void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef casts) { + llvm::DenseSet erased; + for (CastOp cast : casts) { + eraseCastOp(cast); + erased.insert(cast); + } + + decltype(queue) tempQueue; + std::swap(queue, tempQueue); + + // This is only a naive implementation. Should refactor with linked-list. + while (!tempQueue.empty()) { + auto cast = tempQueue.front(); + tempQueue.pop(); + if (!erased.contains(cast)) + queue.push(cast); + } +} + +void CTAPlanner::insertCasts(Operation *op, + llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts) { + assert(op->getNumOperands() == newOperandLayouts.size() && + "NumOperands mismatched"); + assert(op->getNumResults() == newResultLayouts.size() && + "NumResults mismatched"); + + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + operandTy = replaceLayout(operandTy, newOperandLayouts[i]); + auto cast = + markBackward(CastOp::create(builder, loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + resultTy = replaceLayout(resultTy, newResultLayouts[i]); + auto cast = + markForward(CastOp::create(builder, loc, result.getType(), result)); + result.setType(resultTy); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } +} + +void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) { + assert(cast0.getResult(0) == cast1.getOperand(0) && + "The two casts are not adjacent"); + assert(isForward(cast0) && isBackward(cast1) && + "Expected pattern of adjacent casts: forward + backward"); + + Value input = cast0.getOperand(0); + Value output = cast1.getResult(0); + + if (input.getType() == output.getType()) { + output.replaceAllUsesWith(input); + eraseCastOpsFromQueue({cast1, cast0}); + } else { + OpBuilder builder(cast1.getOperation()); + auto cvt = ttg::ConvertLayoutOp::create(builder, cast1.getLoc(), + output.getType(), input); + output.replaceAllUsesWith(cvt.getResult()); + eraseCastOpsFromQueue({cast1, cast0}); + } +} + +bool CTAPlanner::isLoadStoreOp(Operation *op) const { + return llvm::isa(op); +} + +bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) { + // Special logic for: + // LoadOp -> SliceLayout + // Transform to: + // LoadOp -> originalLayout -> ConvertLayout(DSmem) -> SliceLayout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = ttg::getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] > 1) { + // Find an input or output value of LoadOp or StoreOp to get its layout + Value val = + op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0); + Attribute originalLayout = + cast(val.getType()).getEncoding(); + // Insert casts using originalLayout. Adjacent casts will be eliminated + // and generate a ConvertLayoutOp with DSmem access + return processLoadStore(op, originalLayout); + } + } + + auto CGALayout = ttg::getCGALayout(layout); + auto numWarps = ttg::lookupNumWarps(op); + + llvm::SmallVector newOperandLayouts; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto type = op->getOperand(i).getType(); + if (auto ptrTy = dyn_cast(type)) + type = ptrTy.getPointeeType(); + auto tensorTy = dyn_cast(type); + if (!tensorTy) { + newOperandLayouts.push_back(Attribute()); + continue; + } + auto oldLayout = + cast(tensorTy.getEncoding()); + auto newLayout = + replaceCGALayout(oldLayout, tensorTy.getShape(), numWarps, CGALayout); + newOperandLayouts.push_back(newLayout); + } + + llvm::SmallVector newResultLayouts; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + auto type = op->getResult(i).getType(); + if (auto ptrTy = dyn_cast(type)) + type = ptrTy.getPointeeType(); + auto tensorTy = cast(type); + auto oldLayout = + cast(tensorTy.getEncoding()); + auto newLayout = + replaceCGALayout(oldLayout, tensorTy.getShape(), numWarps, CGALayout); + newResultLayouts.push_back(newLayout); + } + + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::isElementwiseOp(Operation *op) const { + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (auto externElementwiseOp = dyn_cast(op)) + return externElementwiseOp.getPure(); + if (llvm::isa(op)) + return true; + return false; +} + +bool CTAPlanner::processElementwise(Operation *op, Attribute layout) { + llvm::SmallVector newOperandLayouts(op->getNumOperands(), layout); + llvm::SmallVector newResultLayouts(op->getNumResults(), layout); + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) { + if (auto tensorTy = dyn_cast(constant.getType())) { + if (auto attr = dyn_cast(constant.getValue())) { + + auto newTensorTy = tensorTy.cloneWithEncoding(layout); + constant.setValueAttr( + SplatElementsAttr::get(newTensorTy, attr.getSplatValue())); + } + } + insertCasts(constant.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) { + insertCasts(splat.getOperation(), {{}}, {layout}); + return true; +} + +bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange, + Attribute layout) { + insertCasts(makeRange.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout) { + // All inputs of `makeTensorPtr` are scalar types + llvm::SmallVector dummyInAttrs(makeTensorPtr.getNumOperands(), {}); + insertCasts(makeTensorPtr.getOperation(), dummyInAttrs, {layout}); + return true; +} + +bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast, + Attribute layout) { + insertCasts(broadcast.getOperation(), {layout}, {layout}); + return true; +} + +bool CTAPlanner::processExpandDimsBackward( + triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newResultLayout) { + auto newSrcLayout = ttg::SliceEncodingAttr::get( + newResultLayout.getContext(), expandDims.getAxis(), newResultLayout); + insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout}); + return true; +} + +bool CTAPlanner::processExpandDimsForward( + triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newSrcLayout) { + llvm::report_fatal_error("processExpandDimsForward not implemented yet"); + return true; +} + +bool CTAPlanner::processConvertLayoutBackward( + ttg::ConvertLayoutOp convertLayout, CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(result) == 1 && + "Expect to call processMultiUsersBackward first"); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(src) == 1 && + "Expect to call processMultiUsersForward first"); + src.setType(result.getType()); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) { + // Check index + assert(index < ifOp.getNumResults() && "Invalid result index of IfOp"); + assert(index < ifOp.thenYield().getNumOperands() && + "Invalid operand index of YieldOp"); + assert(index < ifOp.elseYield().getNumOperands() && + "Invalid operand index of YieldOp"); + + Location loc = ifOp.getLoc(); + OpBuilder builder(ifOp.getContext()); + + // Insert forward cast after ifOp + Value result = ifOp.getResult(index); + builder.setInsertionPointAfter(ifOp.getOperation()); + auto newCast = + markForward(CastOp::create(builder, loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward casts before yield + for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) { + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(CastOp::create(builder, loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + } + + return true; +} + +bool CTAPlanner::processForOp(scf::ForOp forOp, int index, + const Type &newType) { + Block *body = forOp.getBody(); + auto yield = llvm::cast(forOp.getBody()->getTerminator()); + + // Check index + assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() && + "Invalid operand index of ForOp"); + assert(index + forOp.getNumInductionVars() < body->getNumArguments() && + "Invalid block arg index of ForOp"); + assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp"); + assert(index < forOp.getNumResults() && "Invalid result index of IfOp"); + + Location loc = forOp.getLoc(); + OpBuilder builder(forOp.getContext()); + + // Insert backward cast before forOp + OpOperand &operand = + forOp->getOpOperand(index + forOp.getNumControlOperands()); + builder.setInsertionPoint(forOp.getOperation()); + auto newCast = + markBackward(CastOp::create(builder, loc, newType, operand.get())); + operand.set(newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after block arg + Value arg = body->getArgument(index + forOp.getNumInductionVars()); + builder.setInsertionPointToStart(body); + newCast = markForward(CastOp::create(builder, loc, arg.getType(), arg)); + arg.setType(newType); + arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward cast before yield + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(CastOp::create(builder, loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after forOp + Value result = forOp.getResult(index); + builder.setInsertionPointAfter(forOp.getOperation()); + newCast = markForward(CastOp::create(builder, loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + return true; +} + +int findResultIndex(Operation *op, Value result) { + for (int i = 0; i < op->getNumResults(); ++i) + if (op->getResult(i) == result) + return i; + llvm::report_fatal_error("Invalid index of op result"); + return -1; +} + +bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) { + int index = findResultIndex(ifOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processIfOp(ifOp, index, newType); +} + +bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) { + int index = findResultIndex(forOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) { + if (auto forOp = llvm::dyn_cast(arg.getOwner()->getParentOp())) { + int index = int(arg.getArgNumber()) - forOp.getNumInductionVars(); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); + } else { + llvm::report_fatal_error("Unexpected parent op of block argument"); + return true; + } +} + +bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber() - + forOp.getNumControlOperands(); + auto newType = cast.getOperand(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber(); + auto newType = cast.getOperand(0).getType(); + if (auto ifOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processIfOp(ifOp, index, newType); + else if (auto forOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processForOp(forOp, index, newType); + else + llvm::report_fatal_error("Unexpected parent op of YieldOp"); + return true; +} + +bool CTAPlanner::processOpFallback(Operation *op) { + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + auto cast = + markBackward(CastOp::create(builder, loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + auto cast = markForward(CastOp::create(builder, loc, resultTy, result)); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } + + return true; +} + +bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) { + Location loc = input.getLoc(); + OpBuilder builder(input.getContext()); + + llvm::DenseMap> typeToIndices; + for (OpOperand &operand : input.getUses()) { + auto brotherCast = llvm::dyn_cast(operand.getOwner()); + if (!brotherCast) { + if (stepUnchanged <= queue.size()) + return false; + builder.setInsertionPoint(operand.getOwner()); + brotherCast = markBackward( + CastOp::create(builder, loc, cast.getResult(0).getType(), input)); + auto newCast = markForward(CastOp::create(builder, loc, input.getType(), + brotherCast.getResult(0))); + operand.set(newCast.getResult(0)); + queue.push(brotherCast); + queue.push(newCast); + } + auto type = brotherCast.getResult(0).getType(); + typeToIndices[type].push_back(brotherCast); + } + + bool first = true; + for (auto it : typeToIndices) { + Type &type = it.first; + llvm::SmallVector &casts = it.second; + Value newInput = input; + if (!first) { + if (Operation *defOp = input.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + Operation *clonedOp = builder.clone(*defOp); + newInput = clonedOp->getResult(0); + } else { + llvm::report_fatal_error("Layout conflict for block arg"); // TODO + return false; + } + } + first = false; + if (Operation *defOp = newInput.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + } else { + assert(isa(newInput) && + "Unexpected Value without defining op"); + builder.setInsertionPointToStart( + llvm::cast(newInput).getOwner()); + } + auto newCast = markBackward(CastOp::create(builder, loc, type, newInput)); + queue.push(newCast); + auto newResult = newCast.getResult(0); + for (CastOp &brotherCast : casts) { + brotherCast.getResult(0).replaceAllUsesWith(newResult); + eraseCastOpFromQueue(brotherCast); + } + } + return true; +} + +bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) { + Value castSrc = cast.getOperand(0); + + Location loc = cast.getLoc(); + OpBuilder builder(cast.getContext()); + builder.setInsertionPointAfter(cast.getOperation()); + + while (!castResult.use_empty()) { + auto newCast = markForward( + CastOp::create(builder, loc, castResult.getType(), castSrc)); + castResult.use_begin()->set(newCast.getResult(0)); + queue.push(newCast); + } + + eraseCastOp(cast); + return true; +} + +} // anonymous namespace + +struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // Skip PlanCTAPass when numCTAs == 1 + if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1) + return; + + mod.walk([&](triton::FuncOp funcOp) { + CTAPlanner planner; + planner.run(funcOp); + + // FIXME: Clone funcOp so that the IR change can be identified after + // PlanCTAPass. Without this, the change after PlanCTAPass will not be + // displayed when MLIR_ENABLE_DUMP=1. This is not reasonable and should + // be fixed later. + OpBuilder builder(funcOp); + builder.clone(*funcOp.getOperation()); + funcOp.erase(); + }); + } +}; + +std::unique_ptr createTritonNvidiaGPUPlanCTAPass() { + return std::make_unique(); +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +/* TODO + * - Use ConvertLayoutOp instead of UnrealizedConversionCastOp. + * - Move PlanCTAPass to the front of CoalescePass. + * - Design better tiling strategy for DotOp and ReduceOp. + * - Consider cases where there are more than one DotOps. + * - Use better data structure for erasing CastOps from queue (linked list?). + * - Process eliminable CastOps in higher priority. + * - Fix the clone func bug in PlanCTAPass::runOnOperation. + * - Add some comments to introduce the overall idea of this pass. + * - Add some lit tests for this pass. + */ diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp new file mode 100644 index 0000000000..339e64e55b --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp @@ -0,0 +1,117 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUPROMOTELHSTOTMEMPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { +template +Attribute getLHSTMemLayout(MMAOpTy tcGen5MMAOp, gpu::MemDescType lhsTMEMType, + ttg::CGAEncodingAttr cgaLayout) { + int numWarps = ttg::lookupNumWarps(tcGen5MMAOp); + return nvidia_gpu::getDefaultLayoutForTmemLdSt(lhsTMEMType, numWarps, + cgaLayout); +} + +template class LHSToTMem : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MMAOpTy tcGen5MMAOp, + PatternRewriter &rewriter) const override { + MLIRContext *context = tcGen5MMAOp->getContext(); + Location loc = tcGen5MMAOp.getLoc(); + auto lhs = tcGen5MMAOp.getA(); + auto localAllocOp = lhs.template getDefiningOp(); + if (!localAllocOp) + return failure(); + // Limit the liverange of the TMem allocations to single block. + if (localAllocOp->getParentRegion() != tcGen5MMAOp->getParentRegion()) + return failure(); + Value src = localAllocOp.getSrc(); + auto srcType = cast(src.getType()); + auto srcLayout = srcType.getEncoding(); + auto accTMemEncoding = dyn_cast( + tcGen5MMAOp.getD().getType().getEncoding()); + auto CTASplitNum = triton::gpu::getCGALayout(srcLayout).getCTASplitNum(); + // TMem encoding for A operand is the same as for D (Acc), but packed for + // bitwidth=16 + unsigned elemBitWidth = + lhs.getType().getElementType().getIntOrFloatBitWidth(); + // We don't currently support fp8 (not sure if we can) + if (elemBitWidth != 16 && elemBitWidth != 32) { + return failure(); + } + const unsigned colStride = 1; + auto aTMemEncoding = TensorMemoryEncodingAttr::get( + context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1], + colStride, CTASplitNum[0], CTASplitNum[1], + accTMemEncoding.getTwoCTAs()); + Attribute tensorMemorySpace = + triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); + ttg::MemDescType lhsMemDescType = ttg::MemDescType::get( + lhs.getType().getShape(), lhs.getType().getElementType(), aTMemEncoding, + tensorMemorySpace, + /*mutableMemory=*/false); + bool layoutTmemCompatible = + isDistributedLayoutTMemCompatible(tcGen5MMAOp, srcType, lhsMemDescType); + Attribute newLayout = srcLayout; + if (!layoutTmemCompatible) { + if (!comesFromLoadOrBlockArg(src) || + triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) { + newLayout = getLHSTMemLayout(tcGen5MMAOp, lhsMemDescType, + ttg::getCGALayout(srcType.getEncoding())); + } else { + return failure(); + } + } + rewriter.setInsertionPointAfter(localAllocOp); + if (newLayout != srcLayout) { + auto ty = cast(src.getType()); + auto newTy = ty.cloneWithEncoding(newLayout); + src = ttg::ConvertLayoutOp::create(rewriter, loc, newTy, src); + } + Value tMemAlloc = TMEMAllocOp::create(rewriter, loc, lhsMemDescType, src); + tcGen5MMAOp.getAMutable().assign(tMemAlloc); + return success(); + } +}; +} // namespace + +class TritonNvidiaGPUPromoteLHSToTMemPass + : public impl::TritonNvidiaGPUPromoteLHSToTMemPassBase< + TritonNvidiaGPUPromoteLHSToTMemPass> { +public: + using TritonNvidiaGPUPromoteLHSToTMemPassBase< + TritonNvidiaGPUPromoteLHSToTMemPass>:: + TritonNvidiaGPUPromoteLHSToTMemPassBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet patterns(context); + patterns.add>(context); + patterns.add>(context); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/ProxyFenceInsertion.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/ProxyFenceInsertion.cpp new file mode 100644 index 0000000000..c89b392a03 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/ProxyFenceInsertion.cpp @@ -0,0 +1,193 @@ +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +//===----------------------------------------------------------------------===// +// +// On Hopper+, async proxy is separate from generic proxy, so when shared memory +// is the generic proxy to the async proxy we need to insert a fence to ensure +// memory consistency. +// This pass analyzes dependencies and will conservatively insert fences to +// avoid race conditions between proxies. Async proxy is defined here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#async-proxy +// +// This pass runs after shared memory allocation, to make sure we insert fences +// between ops accessing aliasing buffers if needed. +// +// We also run a fence insertion pass during optimization phase as it is easier +// to insert fences at optimial location based on structured control flow. +// +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONGPUPROXYFENCEINSERTION +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +bool isAsyncProxyWrite(Operation *op) { + return isa(op); +} + +Value getSmemDest(Operation *op) { + if (auto asyncTMACopyGlobalToLocalOp = + dyn_cast(op)) { + return asyncTMACopyGlobalToLocalOp.getResult(); + } + if (auto asyncTMAGatherOp = + dyn_cast(op)) { + return asyncTMAGatherOp.getResult(); + } + return Value(); +} + +bool isAsyncProxyRead(Operation *op) { + return isa(op); +} + +bool ignoreOpForProxyFence(Operation *op) { + return isAsyncProxyRead(op) || isAsyncProxyWrite(op) || + isa(op); +} + +bool filterFn(Operation *op, Operation *other, Allocation *allocation) { + return ignoreOpForProxyFence(other); +} + +//===----------------------------------------------------------------------===// +// Proxy Fence Analysis +//===----------------------------------------------------------------------===// +class ProxyFenceAnalysis : public MembarOrFenceAnalysis { + +public: + ProxyFenceAnalysis() = default; + explicit ProxyFenceAnalysis(Allocation *allocation, MembarFilterFn filter) + : MembarOrFenceAnalysis(allocation, filter) {} + +private: + /// Updates the BlockInfo operation based on the operation. + virtual void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) override; + + void insertFence(Operation *operation, OpBuilder *builder); +}; + +void ProxyFenceAnalysis::insertFence(Operation *op, OpBuilder *builder) { + OpBuilder::InsertionGuard g(*builder); + triton::nvidia_gpu::FenceAsyncSharedOp::create(*builder, op->getLoc(), false); +} + +void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + if (isa(op)) { + // If the current op is a fence, we clear previous reads and writes + blockInfo->sync(); + return; + } + BlockInfo curBlockInfo; + BlockInfo proxyBlockInfo; + + auto scratchBufferId = Allocation::InvalidBufferId; + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) + curBlockInfo = funcBlockInfoMap->lookup(callee); + } else { + // Intra-function dependencies + if (auto memoryEffectOpInterface = dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + // TODO: handle proxy read cases. Those are currently handled in + // FenceInsertionPass where it can generate better placement for + // the fence. But we should support a safe fallback here. + auto interval = allocation->getAllocatedInterval(bufferId); + auto slice = AllocationSlice(value, interval); + + if (isAsyncProxyWrite(op)) { + if (value == getSmemDest(op)) { + proxyBlockInfo.syncWriteSlices[slice].insert(op); + } + } else if (isa( + effectInstance.getEffect())) { + curBlockInfo.syncWriteSlices[slice].insert(op); + } else if (isa(effectInstance.getEffect())) { + curBlockInfo.syncReadSlices[slice].insert(op); + } + } + } + } + } + } + scratchBufferId = allocation->getBufferId(op); + } + + // Scratch buffer operations consist of a series of shared memory operations + // starting from a shared memory write, followed by a series of shared memory + // read/write operations, mark them as a read. + if (scratchBufferId != Allocation::InvalidBufferId) { + auto interval = allocation->getAllocatedInterval(scratchBufferId); + auto scratchSlice = AllocationSlice(interval); + curBlockInfo.syncReadSlices[scratchSlice].insert(op); + } + if (isAsyncProxyWrite(op) || isAsyncProxyRead(op)) { + if (proxyBlockInfo.isIntersected(*blockInfo, filter, allocation)) { + builder->setInsertionPoint(op); + insertFence(op, builder); + blockInfo->sync(); + } + } + + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + blockInfo->join(curBlockInfo); +} +} // namespace + +struct ProxyFenceInsertionPass + : public impl::TritonGPUProxyFenceInsertionBase { + +public: + using impl::TritonGPUProxyFenceInsertionBase< + ProxyFenceInsertionPass>::TritonGPUProxyFenceInsertionBase; + void runOnOperation() override { + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; + ModuleOp mod = getOperation(); + // This pass does not depend on the amount of shared memory allocated + // so we can use the default allocation analysis scratch size function + ModuleAllocation allocation(mod); + ModuleMembarOrFenceAnalysis analysis(&allocation, + filterFn); + analysis.run(); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/RemoveTMEMTokens.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/RemoveTMEMTokens.cpp new file mode 100644 index 0000000000..6c4120e979 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/RemoveTMEMTokens.cpp @@ -0,0 +1,85 @@ +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Pass/PassManager.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUREMOVETMEMTOKENSPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +void eraseResult(Operation *op, unsigned resultIdx, Value replacement) { + OperationState state(op->getLoc(), op->getName(), op->getOperands(), + op->getResultTypes(), op->getAttrs()); + state.types.erase(std::next(state.types.begin(), resultIdx)); + OpBuilder b(op); + + if (auto segmentSizes = + op->getAttrOfType("resultSegmentSizes")) { + // Update resultSegmentSizes attribute if it exists + SmallVector newSegmentSizes(segmentSizes.asArrayRef()); + int pos = 0; + for (auto &segmentSize : newSegmentSizes) { + if (pos == resultIdx) { + segmentSize = 0; + break; + } + pos += segmentSize; + } + state.attributes.set("resultSegmentSizes", + b.getDenseI32ArrayAttr(newSegmentSizes)); + } + Operation *newOp = b.create(state); + SmallVector replacements = newOp->getResults(); + replacements.insert(std::next(replacements.begin(), resultIdx), replacement); + op->replaceAllUsesWith(replacements); + op->erase(); +} + +void removeTMEMToken(Operation *op, Value dummy) { + if (auto mmaOp = dyn_cast(op)) { + mmaOp.getAccDepMutable().clear(); + if (mmaOp.getToken()) + eraseResult(mmaOp, 0, dummy); + } else if (auto store = dyn_cast(op)) { + store.getDepMutable().clear(); + if (store.getToken()) + eraseResult(store, 0, dummy); + } else if (auto alloc = dyn_cast(op)) { + if (alloc.getToken()) + eraseResult(alloc, 1, dummy); + } else if (auto load = dyn_cast(op)) { + load.getDepMutable().clear(); + if (load.getToken()) + eraseResult(load, 1, dummy); + } +} + +} // anonymous namespace + +class TritonNvidiaGPURemoveTMEMTokensPass + : public impl::TritonNvidiaGPURemoveTMEMTokensPassBase< + TritonNvidiaGPURemoveTMEMTokensPass> { +public: + using TritonNvidiaGPURemoveTMEMTokensPassBase:: + TritonNvidiaGPURemoveTMEMTokensPassBase; + + void runOnOperation() override { + for (auto func : getOperation().getOps()) { + auto b = OpBuilder::atBlockBegin(&func.getBody().front()); + // Placeholder value that will get DCE'd by the canonicalizer. + Value dummy = ub::PoisonOp::create( + b, func.getLoc(), b.getType()); + func.walk([&](Operation *op) { removeTMEMToken(op, dummy); }); + } + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp new file mode 100644 index 0000000000..9c7769c714 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -0,0 +1,204 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/Support/ErrorHandling.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +#define GEN_PASS_DEF_TRITONNVIDIAGPUTMALOWERINGPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +static void +lowerTMALoad(Operation *op, RankedTensorType tensorType, Value desc, + function_ref createLoad, + PatternRewriter &rewriter) { + MLIRContext *ctx = op->getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); + auto loc = op->getLoc(); + auto encoding = getEncodingFromDescriptor(op, tensorType, desc); + gpu::MemDescType memDescType = gpu::MemDescType::get( + tensorType.getShape(), tensorType.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory=*/true); + auto alloc = + gpu::LocalAllocOp::create(rewriter, loc, memDescType).getResult(); + auto numCTAs = gpu::lookupNumCTAs(op); + auto barrierCGALayout = + gpu::CGAEncodingAttr::get1DLayout(tensorType.getContext(), numCTAs); + auto barrierEncoding = gpu::SwizzledSharedEncodingAttr::get( + tensorType.getContext(), 1, 1, 1, {0}, barrierCGALayout); + gpu::MemDescType barrierMemDescType = + gpu::MemDescType::get({numCTAs}, rewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = + gpu::LocalAllocOp::create(rewriter, loc, barrierMemDescType); + InitBarrierOp::create(rewriter, loc, barrierAlloc, 1); + auto shapePerCTA = getShapePerCTA(encoding, tensorType.getShape()); + int sizeInBytes = product(shapePerCTA) * + tensorType.getElementType().getIntOrFloatBitWidth() / 8; + Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + triton::nvidia_gpu::BarrierExpectOp::create(rewriter, loc, barrierAlloc, + sizeInBytes, pred); + createLoad(desc, barrierAlloc, alloc, pred); + Value phase = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + WaitBarrierOp::create(rewriter, loc, barrierAlloc, phase); + InvalBarrierOp::create(rewriter, loc, barrierAlloc); + replaceUsesWithLocalLoad(rewriter, op->getResult(0), alloc); + op->erase(); +} + +class TMALoadLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DescriptorLoadOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, + Value pred) { + triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( + rewriter, op.getLoc(), desc, op.getIndices(), barrierAlloc, alloc, + pred); + }; + lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); + return success(); + } +}; + +struct TMAGatherLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DescriptorGatherOp op, + PatternRewriter &rewriter) const override { + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, + Value pred) { + triton::nvidia_gpu::AsyncTMAGatherOp::create( + rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(), + barrierAlloc, alloc, pred); + }; + lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); + return success(); + } +}; + +static void lowerTMAStore(Operation *op, mlir::TypedValue src, + Value desc, + function_ref createStore, + PatternRewriter &rewriter) { + MLIRContext *ctx = op->getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); + auto loc = op->getLoc(); + auto tensorType = src.getType(); + auto encoding = getEncodingFromDescriptor(op, src.getType(), desc); + assert(isa(encoding)); + gpu::MemDescType memDescType = gpu::MemDescType::get( + tensorType.getShape(), tensorType.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory=*/false); + Value alloc = gpu::LocalAllocOp::create(rewriter, loc, memDescType, src); + triton::nvidia_gpu::FenceAsyncSharedOp::create(rewriter, loc, false); + createStore(desc, alloc); + triton::nvidia_gpu::TMAStoreWaitOp::create(rewriter, loc, 0); + rewriter.eraseOp(op); +} + +struct TMAStoreLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DescriptorStoreOp op, + PatternRewriter &rewriter) const override { + auto createStore = [&](Value desc, Value alloc) { + triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create( + rewriter, op.getLoc(), desc, op.getIndices(), alloc); + }; + lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); + return success(); + } +}; + +struct TMAReduceLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DescriptorReduceOp op, + PatternRewriter &rewriter) const override { + auto createStore = [&](Value desc, Value alloc) { + triton::nvidia_gpu::AsyncTMAReduceOp::create( + rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc); + }; + lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); + return success(); + } +}; + +struct TMAScatterLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DescriptorScatterOp op, + PatternRewriter &rewriter) const override { + auto createStore = [&](Value desc, Value alloc) { + triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc, + op.getXOffsets(), + op.getYOffset(), alloc); + }; + lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); + return success(); + } +}; + +class TMACreateDescLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MakeTensorDescOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto alloc = triton::gpu::GlobalScratchAllocOp::create( + rewriter, loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, + TMA_ALIGN); + if (failed(createTMADesc(alloc, op, rewriter))) { + return failure(); + } + TensormapFenceproxyAcquireOp::create(rewriter, loc, alloc.getResult()); + auto newDesc = ReinterpretTensorDescOp::create(rewriter, loc, op.getType(), + alloc.getResult()); + rewriter.replaceOp(op, newDesc); + return success(); + } +}; + +} // anonymous namespace + +class TritonNvidiaGPUTMALoweringPass + : public impl::TritonNvidiaGPUTMALoweringPassBase< + TritonNvidiaGPUTMALoweringPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns.add( + context); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp new file mode 100644 index 0000000000..034295db51 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -0,0 +1,290 @@ +#include +#include +#include + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir::triton::nvidia_gpu { + +ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout, + ArrayRef shape) { + auto rank = shape.size(); + if (cgaLayout.getRank() == rank) + return cgaLayout; + + auto ctx = cgaLayout.getContext(); + if (cgaLayout.getRank() > rank) { + auto ll = cgaLayout.getLinearLayout(); + // Broadcast over the first rankDiff dims + unsigned rankDiff = cgaLayout.getRank() - rank; + for (int i = 0; i < rankDiff; ++i) { + ll = removeStandardDim(ll, 0); + } + return ttg::CGAEncodingAttr::get(ctx, std::move(ll)); + } + // For rank-reducing loads, we need to rank-increase the CTA Layout + auto rankDiff = rank - cgaLayout.getRank(); + for (unsigned i = 0; i < rankDiff; ++i) { + assert(shape[i] == 1 && "Should only happen for rank-reducing loads"); + } + auto ll = cgaLayout.getLinearLayout(); + auto kBlock = *ll.getInDimNames().begin(); + auto standardOuts = standardOutDimNames(ctx, rank); + // Append to front + for (int i = cgaLayout.getRank(); i < rank; ++i) { + ll = LinearLayout::identity1D(1, kBlock, standardOuts[i]) * ll; + } + // Rename out dims to dim0..dimn-1 + auto dimSizes = ll.getOutDims(); + for (auto [i, dim] : llvm::enumerate(standardOuts)) { + dimSizes[i].first = dim; + } + ll = LinearLayout(ll.getBases(), dimSizes, false); + return ttg::CGAEncodingAttr::get(ctx, std::move(ll)); +} + +ttg::SharedEncodingTrait +updateEncodingForShape(Operation *op, ttg::SharedEncodingTrait encoding, + RankedTensorType tensorType) { + auto ctx = encoding.getContext(); + auto cgaLayout = ttg::getCGALayout(encoding); + if (auto nvmmaEnc = dyn_cast(encoding)) { + auto existingCga = nvmmaEnc.getCGALayout(); + if (!existingCga) + return nvmmaEnc; + + auto newCgaEnc = updateCGALayoutForShape(cgaLayout, tensorType.getShape()); + return ttg::NVMMASharedEncodingAttr::get( + ctx, nvmmaEnc.getSwizzlingByteWidth(), nvmmaEnc.getTransposed(), + nvmmaEnc.getElementBitWidth(), nvmmaEnc.getFp4Padded(), newCgaEnc); + } + if (auto swizEnc = dyn_cast(encoding)) { + auto existingCga = swizEnc.getCGALayout(); + if (!existingCga) + return swizEnc; + + auto rank = tensorType.getRank(); + auto oldOrder = swizEnc.getOrder(); + SmallVector order; + for (int i = 0; i + oldOrder.size() < rank; ++i) + order.push_back(rank - i - 1); + for (int i = 0; i < oldOrder.size(); ++i) { + // If it is a rank-reducing load, we need to drop the last dimensions. + if (oldOrder[i] >= rank) + continue; + order.push_back(oldOrder[i]); + } + auto newCgaEnc = updateCGALayoutForShape(cgaLayout, tensorType.getShape()); + return ttg::SwizzledSharedEncodingAttr::get( + ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(), + order, newCgaEnc); + } + + constexpr auto msg = "Internal Error: Unhandled tensor descriptor encoding"; + if (op) + op->emitError() << msg; + llvm::report_fatal_error(msg); +} + +ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op, + RankedTensorType tensorType, + Value desc) { + auto descBlockType = cast(desc.getType()).getBlockType(); + Attribute encoding = descBlockType.getEncoding(); + if (!encoding) { + constexpr auto msg = + "Internal Error: Tensor descriptor should have encoding set"; + if (op) + op->emitError() << msg; + llvm::report_fatal_error(msg); + } + auto sharedEnc = cast(encoding); + if (descBlockType.getShape() == tensorType.getShape()) + return sharedEnc; + + return updateEncodingForShape(op, sharedEnc, tensorType); +} + +FailureOr getTMASwizzleMode(Location loc, TensorDescType ty) { + auto encoding = ty.getBlockType().getEncoding(); + auto mmaEncoding = dyn_cast(encoding); + unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0; + if (!mmaEncoding) { + auto swizzledEnc = dyn_cast(encoding); + if (!swizzledEnc || swizzledEnc.getVec() != 1 || + swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) { + return emitError(loc) + << "unhandled shared memory layout for TMA descriptor: " + << encoding; + } + } + + bool fp4Padded = isFp4Padded(encoding); + if (fp4Padded && swizzleBytes != 128) { + return emitError(loc) << "fp4 padded operands (elem type .b4x16_p64) only " + "supports 128-byte swizzling, but got " + << swizzleBytes; + } + + int32_t swizzleMode = 0; + if (swizzleBytes == 128) { + swizzleMode = 3; + } else if (swizzleBytes == 64) { + swizzleMode = 2; + } else if (swizzleBytes == 32) { + swizzleMode = 1; + } else { + assert(swizzleBytes == 0); + } + return swizzleMode; +} + +enum TMA_ELEMENT_TYPES { + TMA_U8 = 0, + TMA_U16 = 1, + TMA_U32 = 2, + TMA_S32 = 3, + TMA_U64 = 4, + TMA_S64 = 5, + TMA_F16 = 6, + TMA_F32 = 7, + TMA_F32_FTZ = 8, + TMA_F64 = 9, + TMA_BF16 = 10, + TMA_TF32 = 11, + TMA_TF32_FTZ = 12, + TMA_B4X16 = 13, + TMA_B4X16_P64 = 14, + TMA_B6X16_P32 = 15, + TMA_B6P2X16 = 15, +}; + +FailureOr getTMAElementType(Location loc, TensorDescType ty) { + auto encoding = ty.getBlockType().getEncoding(); + auto mmaEncoding = dyn_cast(encoding); + bool fp4Padded = isFp4Padded(encoding); + + if (fp4Padded) + return TMA_B4X16_P64; + + auto elemTy = ty.getBlockType().getElementType(); + if (elemTy.isBF16()) { + return TMA_BF16; + } else if (elemTy.isF16()) { + return TMA_F16; + } else if (elemTy.isF32()) { + return TMA_F32; + } else if (elemTy.isF64()) { + return TMA_F64; + } + + auto elemSize = elemTy.getIntOrFloatBitWidth() / 8; + switch (elemSize) { + case 1: + return TMA_U8; + case 2: + return TMA_U16; + case 4: + return elemTy.isSignedInteger() ? TMA_S32 : TMA_U32; + case 8: + return elemTy.isSignedInteger() ? TMA_S64 : TMA_U64; + default: + break; + } + return emitError(loc) + << "Tensor descriptor element type must have size 1, 2, or 4 but got " + << elemSize; +} + +LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op, + OpBuilder &builder) { + using namespace mlir; + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto mkI32Constant = [&](int32_t val) { + return arith::ConstantOp::create(builder, loc, builder.getI32Type(), + builder.getI32IntegerAttr(val)); + }; + + auto elemType = op.getBase().getType().getPointeeType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + auto encoding = op.getType().getBlockType().getEncoding(); + auto mmaEncoding = + llvm::dyn_cast_or_null(encoding); + bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded(); + + int paddingScale = fp4Padded ? 2 : 1; + auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape()); + auto blockShape = + getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false); + auto contigDimSize = blockShape.back(); + + llvm::SmallVector boxDim; + if (fp4Padded && contigDimSize != 128) { + return op->emitError( + "FP4 padded loads require 128 elements or more in the last dim"); + } + boxDim.push_back(mkI32Constant(contigDimSize)); + for (int k = shapePerCTA.size() - 2; k >= 0; --k) + boxDim.push_back(mkI32Constant(blockShape[k])); + + unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0; + if (!mmaEncoding) { + auto swizzledEnc = dyn_cast( + op.getType().getBlockType().getEncoding()); + if (!swizzledEnc || swizzledEnc.getVec() != 1 || + swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) { + op->emitError() << "Unhandled encoding type"; + return failure(); + } + } + + auto maybeSwizzleMode = getTMASwizzleMode(loc, op.getType()); + if (failed(maybeSwizzleMode)) + return failure(); + auto swizzleMode = *maybeSwizzleMode; + + Value elemSizeVal = arith::ConstantOp::create( + builder, loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize)); + + SmallVector globalDim(llvm::reverse(op.getShape())); + SmallVector globalStride; + for (int k = op.getStrides().size() - 2; k >= 0; --k) { + globalStride.push_back(op.getStrides()[k]); + } + + if (fp4Padded) { + // Convert number of bytes to number of mxfp4 elements + globalDim[0] = + arith::MulIOp::create(builder, loc, globalDim[0], mkI32Constant(2)); + } + + SmallVector elementStride(globalDim.size(), mkI32Constant(1)); + + for (int i = 0; i < globalStride.size(); ++i) + globalStride[i] = + arith::MulIOp::create(builder, loc, globalStride[i], elemSizeVal); + + auto elemTypeEnum = getTMAElementType(loc, op.getType()); + if (failed(elemTypeEnum)) + return failure(); + + auto fillMode = (op.getPadding() == triton::PaddingOption::PAD_NAN) ? 1 : 0; + + TensormapCreateOp::create( + builder, loc, + /*desc_ptr=*/tmaPtr, + /*global_address=*/op.getBase(), + /*box_dim=*/boxDim, + /*global_dim=*/globalDim, + /*global_stride=*/globalStride, + /*element_strides=*/elementStride, + /*elem_type*/ builder.getI32IntegerAttr(*elemTypeEnum), + /*interleave_layout*/ builder.getI32IntegerAttr(0), + /*swizzle_mode=*/builder.getI32IntegerAttr(swizzleMode), + /*fill_mode=*/builder.getI32IntegerAttr(fillMode)); + return success(); +} + +} // namespace mlir::triton::nvidia_gpu diff --git a/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp new file mode 100644 index 0000000000..695c834063 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp @@ -0,0 +1,436 @@ +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Traits.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +namespace ttg = triton::gpu; + +#define GEN_PASS_DEF_TRITONTENSORMEMORYALLOCATIONPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +// Granularity of row allocations. +static constexpr int allocGranularity = 64; +struct TMemChunk { + int startRow; + int startCol; + int numCols; + int numRows; +}; + +// Use a simple bitmap to track memory usage. This is a slow but it allows us to +// handle 2D memory without extra algorithmic complexity. The number of +// allocations is expected to be small so the compile time is unlikely to be a +// problem. +struct MemoryBitMap { + MemoryBitMap() : elements(512 * kNumRows, false) {} + void free(const TMemChunk &chunk) { + for (int i = 0; i < chunk.numCols; i++) { + for (int j = 0; j < chunk.numRows; j++) { + setUsed(chunk.startRow + j, chunk.startCol + i, false); + } + } + } + void alloc(const TMemChunk &chunk) { + // Ensure the underlying data fits the allocation. + while ((chunk.startCol + chunk.numCols) * kNumRows >= elements.size()) + elements.resize(2 * elements.size(), false); + + for (int i = 0; i < chunk.numCols; i++) { + for (int j = 0; j < chunk.numRows; j++) { + setUsed(chunk.startRow + j, chunk.startCol + i, true); + } + } + } + + TMemChunk findFirstFit(TMemAllocation allocSize, + std::optional rowIdConstraint, + int columnAlignment) const { + int numRows = allocSize.numRows / allocGranularity; + assert(kNumRows - numRows >= 0); + assert(allocSize.numRows % allocGranularity == 0); + int startCol = 0; + while (1) { + // Skip to the next aligned address. + if (startCol % columnAlignment != 0) { + startCol = (startCol / columnAlignment + 1) * columnAlignment; + } + // Iterate over possible starting rows + for (int startRow = 0; startRow <= kNumRows - numRows; ++startRow) { + if (rowIdConstraint && *rowIdConstraint != startRow) + continue; + bool fits = true; + + // Check if the block starting at (startRow, startCol) is free + for (int i = 0; i < allocSize.numCols && fits; ++i) { + for (int j = 0; j < numRows; ++j) { + if (isUsed(startRow + j, startCol + i)) { + fits = false; + break; + } + } + } + + // If a suitable block is found, return it + if (fits) { + TMemChunk chunk; + chunk.startRow = startRow; + chunk.startCol = startCol; + chunk.numRows = numRows; + chunk.numCols = allocSize.numCols; + return chunk; + } + } + startCol++; + } + return TMemChunk(); + } + +private: + bool isUsed(int row, int col) const { + if (row + col * kNumRows >= elements.size()) + return false; + return elements[row + col * kNumRows]; + } + void setUsed(int row, int col, bool used) { + assert(row + col * kNumRows < elements.size()); + elements[row + col * kNumRows] = used; + } + + static constexpr int kNumRows = 2; + std::vector elements; +}; + +static Interval getLiveIntervals(Value value, Liveness &liveness, + DenseMap &operationId) { + auto liveOperations = liveness.resolveLiveness(value); + // Merge the alloc liverange with the liverange of any subview of the + // allocation. + SmallVector users(value.getUsers()); + while (!users.empty()) { + Operation *user = users.pop_back_val(); + if (!isa(user)) + continue; + auto usersLivness = liveness.resolveLiveness(user->getResult(0)); + liveOperations.insert(liveOperations.end(), usersLivness.begin(), + usersLivness.end()); + users.append(user->getResult(0).getUsers().begin(), + user->getResult(0).getUsers().end()); + } + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Interval(minId, maxId); +} + +static void updateMap(MemoryBitMap &memoryMap, Interval liveInterval, + std::multimap &intervalLiverangeEnd) { + int start = liveInterval.start(); + // Add any dead liverange to the list of free intervals. + for (auto it = intervalLiverangeEnd.begin(); + it != intervalLiverangeEnd.end();) { + if (it->first > start) + break; + memoryMap.free(it->second); + it = intervalLiverangeEnd.erase(it); + } +} + +static TMemChunk allocFirstFit(MemoryBitMap &memoryMap, + TMemAllocation allocSize, + std::optional rowIdConstraint, + ArrayRef coexistingChunks, + int columnAlignment) { + // `coexistingChunks` are all the allocations that might need to be live at + // the same time as the current allocation plus what is known to be currently + // live. Union those allocations with a copy of the current memory map and use + // that to find the actual offsets. + MemoryBitMap mapForAlloc = memoryMap; + for (const TMemChunk &chunk : coexistingChunks) + mapForAlloc.alloc(chunk); + TMemChunk chunk = + mapForAlloc.findFirstFit(allocSize, rowIdConstraint, columnAlignment); + + // Mark this chunk as allocated in the actual memory map. + memoryMap.alloc(chunk); + return chunk; +} + +static SmallVector getAlloc(Value value) { + SmallVector allocs; + DenseSet seen; + SmallVector worklist{value}; + + while (!worklist.empty()) { + Value v = worklist.pop_back_val(); + if (!seen.insert(v).second) + continue; + + // Handle block arguments. + if (auto arg = dyn_cast(v)) { + Block *block = arg.getOwner(); + Operation *parentOp = block->getParentOp(); + + // Handle block with predecessors. + if (!block->isEntryBlock()) { + for (Block *pred : block->getPredecessors()) { + Operation *predOp = pred->getTerminator(); + auto br = dyn_cast(predOp); + if (!br) { + llvm::report_fatal_error("unhandled branch op: " + + predOp->getName().getStringRef()); + } + SmallVector operands(br->getNumOperands()); + auto it = llvm::find(br->getSuccessors(), block); + unsigned idx = std::distance(br->getSuccessors().begin(), it); + SuccessorOperands args = br.getSuccessorOperands(idx); + Value operand = + args.getForwardedOperands()[arg.getArgNumber() - + args.getProducedOperandCount()]; + worklist.push_back(operand); + } + continue; + } + + // Handle region entry arguments. + if (auto wsOp = dyn_cast(parentOp)) { + worklist.push_back(wsOp.getExplicitCaptures()[arg.getArgNumber()]); + } else if (auto forOp = dyn_cast(parentOp)) { + unsigned idx = arg.getArgNumber() - 1; + worklist.push_back(forOp.getYieldedValues()[idx]); + worklist.push_back(forOp.getInits()[idx]); + } else if (auto whileOp = dyn_cast(parentOp)) { + unsigned idx = arg.getArgNumber(); + if (arg.getParentRegion() == &whileOp.getAfter()) { + worklist.push_back(whileOp.getConditionOp().getArgs()[idx]); + } else { + worklist.push_back(whileOp.getYieldedValues()[idx]); + worklist.push_back(whileOp.getInits()[idx]); + } + } else { + llvm::report_fatal_error( + "unhandled parent op when looking for TMEM alloc: " + + parentOp->getName().getStringRef()); + } + continue; + } + + Operation *defOp = v.getDefiningOp(); + unsigned idx = cast(v).getResultNumber(); + if (isa(defOp)) { + allocs.push_back(defOp); + } else if (defOp->hasTrait()) { + worklist.push_back(defOp->getOperand(0)); + } else if (auto sliceOp = dyn_cast(defOp)) { + worklist.push_back(sliceOp.getSrc()); + } else if (auto selectOp = dyn_cast(defOp)) { + worklist.push_back(selectOp.getTrueValue()); + worklist.push_back(selectOp.getFalseValue()); + } else if (auto ifOp = dyn_cast(defOp)) { + worklist.push_back(ifOp.thenYield().getOperand(idx)); + worklist.push_back(ifOp.elseYield().getOperand(idx)); + } else if (auto forOp = dyn_cast(defOp)) { + worklist.push_back(forOp.getYieldedValues()[idx]); + worklist.push_back(forOp.getInits()[idx]); + } else if (auto whileOp = dyn_cast(defOp)) { + worklist.push_back(whileOp.getConditionOp().getArgs()[idx]); + } else { + llvm::report_fatal_error("unhandled op when looking for TMEM alloc: " + + defOp->getName().getStringRef()); + } + } + + return allocs; +} + +class RowIdConstraints { + llvm::EquivalenceClasses dependentAllocs; + llvm::SmallDenseMap rowIndex; + +public: + void joinOps(Operation *op1, Operation *op2) { + dependentAllocs.unionSets(op1, op2); + } + + std::optional getRowIdConstraint(Operation *op) { + auto it = dependentAllocs.findLeader(op); + if (it == dependentAllocs.member_end()) + return std::nullopt; + auto rowIt = rowIndex.find(*it); + if (rowIt == rowIndex.end()) + return std::nullopt; + return rowIt->second; + } + + void addConstraints(Operation *op, int rowId) { + auto it = dependentAllocs.findLeader(op); + if (it == dependentAllocs.member_end()) + return; + rowIndex[*it] = rowId; + } +}; + +static int +allocateTMem(Operation *parentOp, + DenseMap &offsets) { + SmallVector allocs; + DenseMap operationId; + RowIdConstraints rowIdConstraints; + parentOp->walk([&](Operation *op) { + operationId[op] = operationId.size(); + if (auto alloc = dyn_cast(op)) { + allocs.push_back(alloc); + } + if (auto mmaOp = dyn_cast(op)) { + if (isa(mmaOp.getA().getType().getEncoding())) { + TMemAllocation allocSize = getTmemAllocSizes(mmaOp.getA().getType()); + if (allocSize.numRows == 64) { + // HW restriction, the A alloc and accumulator needs to be in the same + // rows. + SmallVector lhsAllocs = getAlloc(mmaOp.getA()); + SmallVector accAllocs = getAlloc(mmaOp.getAccumulator()); + for (Operation *lhsAlloc : lhsAllocs) + for (Operation *accAlloc : accAllocs) + rowIdConstraints.joinOps(lhsAlloc, accAlloc); + } else { + // TODO: we need to handle cases where the format is blockM and we + // have multiple blocks. + assert((cast( + mmaOp.getA().getType().getEncoding()) + .getBlockM() != 64 && + cast( + mmaOp.getAccumulator().getType().getEncoding()) + .getBlockM() != 64) && + "interleaved layout with TMEM operand is not supported yet."); + } + } + } + }); + int totalMemorySize = 0; + MemoryBitMap memoryMap; + Liveness liveness(parentOp); + std::multimap intervalLiverangeEnd; + DenseMap allocChunks; + // Implement a linear scan first fit algorithm. We expect that fragmentation + // won't be a problem, if it is this should be revisited. + for (auto it = allocs.begin(), e = allocs.end(); it != e; ++it) { + TMEMAllocOp alloc = *it; + + // Find all allocations in code that may execute at the same time. Only look + // at processed allocations. + SmallVector coexistingChunks; + if (auto ws = alloc->getParentOfType()) { + for (auto prevIt = allocs.begin(); prevIt != it; ++prevIt) { + TMEMAllocOp prevAlloc = *prevIt; + auto prevWs = + prevAlloc->getParentOfType(); + if (prevWs && prevWs == ws && + alloc->getParentRegion() != prevAlloc->getParentRegion()) + coexistingChunks.push_back(allocChunks.at(prevAlloc)); + } + } + + Interval liveInterval = getLiveIntervals(alloc, liveness, operationId); + auto memDescType = alloc.getType(); + TMemAllocation allocSize = getTmemAllocSizes(memDescType); + updateMap(memoryMap, liveInterval, intervalLiverangeEnd); + + std::optional rowIdConstraint = + rowIdConstraints.getRowIdConstraint(alloc); + // TODO: clarify the alignment requirements for different allocations. For + // now enforce an alignment of 4 columns. + const int columnAlignment = 4; + TMemChunk chunkAllocated = + allocFirstFit(memoryMap, allocSize, rowIdConstraint, coexistingChunks, + columnAlignment); + allocChunks.insert({alloc, chunkAllocated}); + // currently naively constraint allocs based on the first one we find. + rowIdConstraints.addConstraints(alloc, chunkAllocated.startRow); + intervalLiverangeEnd.insert({liveInterval.end(), chunkAllocated}); + int colOffset = chunkAllocated.startCol; + int rowOffset = chunkAllocated.startRow * 16; + + alloc->setAttr( + "tensor_memory_col_offset", + IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32), + colOffset)); + alloc->setAttr( + "tensor_memory_row_offset", + IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32), + rowOffset)); + totalMemorySize = std::max(totalMemorySize, colOffset + allocSize.numCols); + } + return totalMemorySize; +} + +} // anonymous namespace + +class TritonTensorMemoryAllocationPass + : public impl::TritonTensorMemoryAllocationPassBase< + TritonTensorMemoryAllocationPass> { +public: + IntegerAttr getI32Attr(int32_t value) { + return Builder(&getContext()).getI32IntegerAttr(value); + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + DenseMap offsets; + // TODO: handle cases with multiple function with TMEMAllocOp. + int totalMemorySize = allocateTMem(mod, offsets); + + std::array possibleAllocations = {0, 32, 64, 128, 256, 512}; + // NOTE: if totalMemorySize > 512 we exceeded the maximum amount of tensor + // memory, but we let the compilation finish so that we can raise an + // exception in python for the auto-tuner. + if (totalMemorySize <= 512) { + for (int size : possibleAllocations) { + if (totalMemorySize <= size) { + totalMemorySize = size; + break; + } + } + } + if (totalMemorySize > 0) { + // We use a small smem allocation to get the tensor memory base address + // from tcgen05.alloc, ensure the block has at least 4 bytes of smem + int shared = 0; + if (auto sharedAttr = mod->getAttr("ttg.shared")) { + shared = cast(sharedAttr).getInt(); + } + if (shared < 4) { + mod->setAttr("ttg.shared", getI32Attr(4)); + } + } + mod->setAttr("ttg.tensor_memory_size", getI32Attr(totalMemorySize)); + } +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Target/CMakeLists.txt b/third_party/mthreads/lib/Target/CMakeLists.txt new file mode 100644 index 0000000000..39d31dc9b5 --- /dev/null +++ b/third_party/mthreads/lib/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/mthreads/lib/Target/LLVMIR/CMakeLists.txt b/third_party/mthreads/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000000..88a265cd01 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,31 @@ +add_triton_library(TritonLLVMIR + LLVMDIScope.cpp + LLVMDILocalVariable.cpp + LLVMIRBreakPhiStruct.cpp + LLVMDIUtils.cpp + + DEPENDS + LLVMIRIncGen + + LINK_LIBS + ${CMAKE_DL_LIBS} + PUBLIC + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRIndexToLLVM + MLIRIR + MLIRLLVMDialect + MLIRNVVMToLLVM + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow + MLIRSupport + MLIRTargetLLVMIRExport + TritonGPUToLLVM + ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMDILocalVariable.cpp b/third_party/mthreads/lib/Target/LLVMIR/LLVMDILocalVariable.cpp new file mode 100644 index 0000000000..b858cdec01 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMDILocalVariable.cpp @@ -0,0 +1,289 @@ +#include "lib/Target/LLVMIR/LLVMDIUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +// #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +//===----------------------------------------------------------------------===// +// This file implements a pass to add ... to LLVM operations, and ... +//===----------------------------------------------------------------------===// + +namespace mlir { +using namespace LLVMDIUtils; + +#define DEBUG_TYPE "name-preservation" + +#define GEN_PASS_DEF_LLVMDILOCALVARIABLE +#include "triton/Target/LLVMIR/Passes.h.inc" + +struct LLVMDILocalVariablePass + : public impl::LLVMDILocalVariableBase { + + void fuseDILocalVariable(Operation *op) { + if (op->getNumResults() == 0) { + return; + } + + MLIRContext *context = op->getContext(); + OpBuilder builder(context); + Location loc = op->getLoc(); + + // if the location is a NameLoc, a.k.a it defines a value, then insert a + // dbg-value intrinsic after the op + if (auto nameLoc = dyn_cast(loc)) { + Location childLoc = nameLoc.getChildLoc(); + StringAttr nameAttr = nameLoc.getName(); + + // also see reference of operation construction from + // mlir/lib/Target/LLVMIR/ModuleImport.cpp which translated llvm::Module + // into mlir::LLVM::Operation + + // TODO: Those instantiation using defult is necessary for first viable + // result, but no meaning for now + LLVM::DIFileAttr diFileAttr = + LLVM::DIFileAttr::get(context, "", ""); + + // Extracting type info into DITypeAttr + mlir::Type resultType = op->getResult(0).getType(); + if (isa(resultType)) { + // we cannot allow void type to be noted as data type, otherwise trigger + // later assertion fault + return; + } + LLVM::DITypeAttr diTypeAttr = convertType(context, resultType); + LLVM::DIFlags diFlags = LLVM::DIFlags::Zero; + + // LLVM Dialect to LLVM translation requires DILocalScope when + // DILocalVariable is present + LLVM::DILocalScopeAttr diLocalScopeAttr = + dyn_cast(diSubprogramAttr); + + // DILocalVariable of LLVM Dialect, which will be translated to LLVM IR's + // llvm::DILocalVariable + LLVM::DILocalVariableAttr diLocalVarAttr; + + // TODO: current parameter only for first viable result for now + diLocalVarAttr = LLVM::DILocalVariableAttr::get( + context, diLocalScopeAttr, nameAttr, diFileAttr, 0, 0, 0, diTypeAttr, + diFlags); + + LLVM::DIExpressionAttr diExprAttr = LLVM::DIExpressionAttr::get(context); + // Note: must set insertion point before calling create since it will + // automatically insert the op + builder.setInsertionPointAfter(op); + // a subclass of mlir::Value, which is the value defined by this operation + OpResult opResult = op->getResult(0); + // create and insert this call-dbg-value intrinsic after the op + Operation *dbgOp = LLVM::DbgValueOp::create(builder, childLoc, opResult, + diLocalVarAttr, diExprAttr); + } + } + + // Follow the same logic as LLVMDIScopePass to construct a subprogram scope + LLVM::DISubprogramAttr getDISubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (auto fusedSubprogramAttr = + loc->findInstanceOf>()) + return fusedSubprogramAttr.getMetadata(); + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + + DistinctAttr distinctId; + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + distinctId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, LLVM::DIEmissionKind::Full); + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + llvm::SmallVector types; + mlir::DataLayout dl( + funcOp.getOperation()->getParentOfType()); + for (auto resTy : funcOp.getResultTypes()) { + LLVM::DITypeAttr tyAttr = convertType(context, resTy); + types.push_back(tyAttr); + } + // If no return type then add a null type as a place holder for that. + if (types.empty()) + types.push_back(mlir::LLVM::DINullTypeAttr::get(context)); + + // Only pointer type and scalar types are supported for now + OpBuilder builder(context); + for (auto [idx, inTy] : llvm::enumerate(funcOp.getArgumentTypes())) { + if (auto ptrTy = dyn_cast(inTy)) { + auto pointeeTy = + funcOp.getArgAttrOfType(idx, "tt.pointee_type"); + auto sizeInBits = dl.getTypeSizeInBits(ptrTy); + // If no valid pointee type for this function argument, skip it. + mlir::Type elTy = + pointeeTy ? pointeeTy.getValue() : builder.getNoneType(); + LLVM::DITypeAttr tyAttr = convertPtrType(context, ptrTy, elTy, dl); + types.push_back(tyAttr); + } else if (auto structTy = dyn_cast(inTy)) { + LLVM::DITypeAttr tyAttr = + convertStructType(context, structTy, fileAttr, dl, line); + types.push_back(tyAttr); + } else if (auto arrayTy = dyn_cast(inTy)) { + LLVM::DITypeAttr tyAttr = + convertArrayType(context, arrayTy, fileAttr, dl, line); + types.push_back(tyAttr); + } else { + // Here assume remaining inTys are only scalar types + assert(inTy.isIntOrFloat() && "Expected scalar types"); + LLVM::DITypeAttr tyAttr = convertType(context, inTy); + types.push_back(tyAttr); + } + } + + auto subroutineTypeAttr = LLVM::DISubroutineTypeAttr::get( + context, llvm::dwarf::DW_CC_normal, types); + + StringAttr funcNameAttr = funcOp.getNameAttr(); + // Note that scopeline is set differently from LLVM's + // DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be + // the column offset + + auto recId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + auto id = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, recId, /*isRecSelf=*/true, id, compileUnitAttr, fileAttr, + funcNameAttr, funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line, + subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}, + /*annotations=*/{}); + + return subprogramAttr; + } + + // construct a subprogram of an operation by using its parent function's + // DISubprogramAttr construction + LLVM::DISubprogramAttr getDISubprogramAttr(Operation op) { + auto funcOp = op.getParentOfType(); + return getDISubprogramAttr(funcOp); + } + + LLVM::DISubprogramAttr + fuseFuncArgVariables(LLVM::LLVMFuncOp funcOp, + LLVM::DISubprogramAttr subprogramAttr) { + + MLIRContext *context = &getContext(); + OpBuilder builder(context); + builder.setInsertionPointToStart(&funcOp.getBody().front()); + llvm::SmallVector retainedNodes; + + LLVM::DIFileAttr fileAttr = subprogramAttr.getFile(); + LLVM::DISubroutineTypeAttr subroutineTypeAttr = subprogramAttr.getType(); + int64_t line = subprogramAttr.getLine(); + auto localScopeAttr = dyn_cast(subprogramAttr); + auto diFlag = LLVM::DIFlags::Zero; + + // Extract function arguments and add them to retainedNodes: + // 0. Extract function argument types from subroutineTypeAttr + // 1. Create DILocalVariable and DebugValueOp for each arg + // 2. Add each arg as DILocalVariableAttr to retainedNodes + auto argTypeAttrs = subroutineTypeAttr.getTypes(); + unsigned resNum = funcOp.getNumResults() ? funcOp.getNumResults() : 1; + for (unsigned idx = resNum; idx < argTypeAttrs.size(); idx++) { + LLVM::DITypeAttr argTypeAttr = argTypeAttrs[idx]; + unsigned argIdx = idx - resNum; + BlockArgument arg = funcOp.getArgument(argIdx); + + Location argLoc = arg.getLoc(); + auto nameLoc = dyn_cast(argLoc); + if (!nameLoc) + continue; + Location childLoc = nameLoc.getChildLoc(); + StringAttr nameAttr = nameLoc.getName(); + + auto argVarAttr = LLVM::DILocalVariableAttr::get( + context, localScopeAttr, nameAttr, fileAttr, line, argIdx + 1, 0, + argTypeAttr, diFlag); + + auto exprAttr = LLVM::DIExpressionAttr::get(context); + (void)LLVM::DbgValueOp::create(builder, childLoc, arg, argVarAttr, + exprAttr); + + retainedNodes.push_back(argVarAttr); + } + + mlir::DistinctAttr recId = subprogramAttr.getRecId(); + mlir::DistinctAttr id = + mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + LLVM::DICompileUnitAttr compileUnitAttr = subprogramAttr.getCompileUnit(); + StringAttr funcNameAttr = subprogramAttr.getName(); + LLVM::DISubprogramFlags subprogramFlags = + subprogramAttr.getSubprogramFlags(); + subprogramAttr = LLVM::DISubprogramAttr::get( + context, recId, /*isRecSelf=*/false, id, compileUnitAttr, fileAttr, + funcNameAttr, funcNameAttr, fileAttr, line, line, subprogramFlags, + subroutineTypeAttr, retainedNodes, /*annotations=*/{}); + + Location loc = funcOp.getLoc(); + // Reset the subprogramAttr with retainedNodes to the funcOp + funcOp->setLoc(mlir::FusedLoc::get(context, {loc}, subprogramAttr)); + return subprogramAttr; + } + + // set it while traversing into a function + LLVM::DISubprogramAttr diSubprogramAttr; + + void runOnOperation() override { + Operation *op = getOperation(); + + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) { + auto funcOp = cast(op); + diSubprogramAttr = getDISubprogramAttr(funcOp); + diSubprogramAttr = fuseFuncArgVariables(funcOp, diSubprogramAttr); + } else { + fuseDILocalVariable(op); + } + }); + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMDIScope.cpp b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIScope.cpp new file mode 100644 index 0000000000..e2675bd42f --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -0,0 +1,193 @@ +#include "lib/Target/LLVMIR/LLVMDIUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +//===----------------------------------------------------------------------===// +// This file implements a pass to add debug info scope to LLVM operations, and +// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the +// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions. +//===----------------------------------------------------------------------===// + +namespace mlir { + +#define GEN_PASS_DEF_LLVMDISCOPE +#include "triton/Target/LLVMIR/Passes.h.inc" + +using namespace LLVMDIUtils; + +/// Add a debug info scope to LLVMFuncOp that are missing it. +struct LLVMDIScopePass : public impl::LLVMDIScopeBase { + void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (loc->findInstanceOf>()) + return; + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + + // Figure out debug information (`subprogramFlags` and `compileUnitAttr`) to + // attach to the function definition / declaration. External functions are + // declarations only, and are defined in a different compile unit, so mark + // them appropriately in `subprogramFlags`, and set an empty + // `compileUnitAttr`. + bool extractDILocalVar = + triton::tools::getBoolEnv("LLVM_EXTRACT_DI_LOCAL_VARIABLES"); + bool disableLineInfo = + triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); + DistinctAttr recId; // Recursive ID to mark the DICompileUnitAttr and + // DISubprogramAttr that are recursively defined + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + recId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + recId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, + extractDILocalVar + ? LLVM::DIEmissionKind::Full + : LLVM::DIEmissionKind:: + LineTablesOnly); // DIEmissionKind::Full is required by + // emitting ptx with dbg-metadata + // (otherwise assertion fail) + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + llvm::SmallVector types; + mlir::DataLayout dl( + funcOp.getOperation()->getParentOfType()); + for (auto resTy : funcOp.getResultTypes()) { + LLVM::DITypeAttr tyAttr = convertType(context, resTy); + types.push_back(tyAttr); + } + // If no return type then add a null type as a place holder for that. + if (types.empty()) + types.push_back(mlir::LLVM::DINullTypeAttr::get(context)); + + // Only pointer type and scalar types are supported for now + OpBuilder builder(context); + for (auto [idx, inTy] : llvm::enumerate(funcOp.getArgumentTypes())) { + if (auto ptrTy = dyn_cast(inTy)) { + auto pointeeTy = + funcOp.getArgAttrOfType(idx, "tt.pointee_type"); + // If no valid pointee type for this function argument, use null type as + // unknown type. + mlir::Type elTy = + pointeeTy ? pointeeTy.getValue() : builder.getNoneType(); + LLVM::DITypeAttr tyAttr = convertPtrType(context, ptrTy, elTy, dl); + types.push_back(tyAttr); + } else if (auto structTy = dyn_cast(inTy)) { + LLVM::DITypeAttr tyAttr = + convertStructType(context, structTy, fileAttr, dl, line); + types.push_back(tyAttr); + } else if (auto arrayTy = dyn_cast(inTy)) { + LLVM::DITypeAttr tyAttr = + convertArrayType(context, arrayTy, fileAttr, dl, line); + types.push_back(tyAttr); + } else { + // Keep DI generation resilient when a backend introduces additional + // LLVM argument kinds (e.g. vectors/target-specific types). + LLVM::DITypeAttr tyAttr = convertType(context, inTy); + types.push_back(tyAttr); + } + } + + auto subroutineTypeAttr = LLVM::DISubroutineTypeAttr::get( + context, llvm::dwarf::DW_CC_normal, types); + + StringAttr funcNameAttr = funcOp.getNameAttr(); + + bool isRecSelf = !disableLineInfo && extractDILocalVar; + auto id = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, recId, isRecSelf, id, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, + /*line=*/line, /*scopeline=*/line, subprogramFlags, subroutineTypeAttr, + /*retainNodes=*/{}, /*annotations=*/{}); + + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + + void setLexicalBlockFileAttr(Operation *op) { + Location opLoc = op->getLoc(); + if (!isa(opLoc)) + return; + + auto funcOp = op->getParentOfType(); + auto funcOpLoc = mlir::cast(funcOp.getLoc()); + auto scopeAttr = + mlir::cast(funcOpLoc.getMetadata()); + + MLIRContext *ctx = op->getContext(); + std::function makeScoped = + [&](Location loc) -> Location { + if (auto cs = dyn_cast(loc)) { + Location newCallee = makeScoped(cs.getCallee()); + Location newCaller = makeScoped(cs.getCaller()); + return CallSiteLoc::get(newCallee, newCaller); + } + + // Build a DIFile for this leaf location + FileLineColLoc fileLine = extractFileLoc(loc, /*getCaller=*/false); + StringRef inputFilePath = fileLine.getFilename().getValue(); + LLVM::DIFileAttr fileAttr = + LLVM::DIFileAttr::get(ctx, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + + auto lexicalBlock = + LLVM::DILexicalBlockFileAttr::get(ctx, scopeAttr, fileAttr, + /*discriminator=*/0); + return FusedLoc::get(ctx, {loc}, lexicalBlock); + }; + + op->setLoc(makeScoped(opLoc)); + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) + setSubprogramAttr(cast(op)); + else + setLexicalBlockFileAttr(op); + }); + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMDIUtils.cpp b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIUtils.cpp new file mode 100644 index 0000000000..779d9c620d --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIUtils.cpp @@ -0,0 +1,159 @@ +#include "lib/Target/LLVMIR/LLVMDIUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" +#include "llvm/BinaryFormat/Dwarf.h" + +namespace mlir { + +// Note: mlir does not provided any built-in conversion from mlir::Type to +// mlir::LLVM::DITypeAttr +LLVM::DITypeAttr LLVMDIUtils::convertType(MLIRContext *context, + mlir::Type type) { + if (type.isInteger(1)) { + return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "bool"), + type.getIntOrFloatBitWidth(), + llvm::dwarf::DW_ATE_boolean); + } + if (type.isInteger()) { + return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "int"), + type.getIntOrFloatBitWidth(), + llvm::dwarf::DW_ATE_signed); + } else if (type.isF16()) { + return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "half"), + type.getIntOrFloatBitWidth(), + llvm::dwarf::DW_ATE_float); + } else if (type.isF32()) { + return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "float"), + type.getIntOrFloatBitWidth(), + llvm::dwarf::DW_ATE_float); + } else if (type.isF64()) { + return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "double"), + type.getIntOrFloatBitWidth(), + llvm::dwarf::DW_ATE_float); + } else if (mlir::isa(type)) { + if (auto vectorTypeSize = calcBitWidth(type); vectorTypeSize.has_value()) { + return LLVM::DIBasicTypeAttr::get( + context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "vector"), vectorTypeSize.value(), + llvm::dwarf::DW_ATE_float); + } else { + // TODO: falling back to unknown_type, perhaps theres a better way to + // handle when element type size is not determined + } + } + return LLVM::DIBasicTypeAttr::get( + context, llvm::dwarf::DW_TAG_base_type, + mlir::StringAttr::get(context, "unknown_type"), 0, + llvm::dwarf::DW_ATE_signed); +} + +LLVM::DITypeAttr LLVMDIUtils::convertPtrType(MLIRContext *context, + LLVM::LLVMPointerType pointerType, + mlir::Type pointeeType, + DataLayout datalayout) { + // LLVMPointerType does not include pointee info, need to pass from external + // source + unsigned addrSpace = pointerType.getAddressSpace(); + + unsigned sizeInBits = datalayout.getTypeSizeInBits(pointerType); + LLVM::DITypeAttr diElTypeAttr = convertType(context, pointeeType); + LLVM::DITypeAttr diTypeAttr = mlir::LLVM::DIDerivedTypeAttr::get( + context, llvm::dwarf::DW_TAG_pointer_type, + mlir::StringAttr::get(context, "pointer"), diElTypeAttr, sizeInBits, + /*alignInBits=*/0, /*offset=*/0, addrSpace, /*extra data=*/nullptr); + return diTypeAttr; +} + +LLVM::DITypeAttr LLVMDIUtils::convertStructType(MLIRContext *context, + LLVM::LLVMStructType structType, + LLVM::DIFileAttr fileAttr, + DataLayout datalayout, + int64_t line) { + + assert(!structType.isPacked() && !structType.isIdentified() && + "Only accepts NON-Packed and Literal struct type"); + + unsigned sizeInBits = datalayout.getTypeSizeInBits(structType); + SmallVector elTypes; + for (auto [idx, element] : llvm::enumerate(structType.getBody())) { + LLVM::DITypeAttr tyAttr = convertType(context, element); + elTypes.push_back(tyAttr); + } + + return LLVM::DICompositeTypeAttr::get( + context, llvm::dwarf::DW_TAG_structure_type, + mlir::StringAttr::get(context, "struct"), fileAttr, /*line=*/line, + /*scope=*/fileAttr, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, + sizeInBits, /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elTypes); +} + +LLVM::DITypeAttr LLVMDIUtils::convertArrayType(MLIRContext *context, + LLVM::LLVMArrayType arrayType, + LLVM::DIFileAttr fileAttr, + DataLayout datalayout, + int64_t line) { + unsigned sizeInBits = datalayout.getTypeSizeInBits(arrayType); + + mlir::Type elementType = arrayType.getElementType(); + LLVM::DITypeAttr baseType = convertType(context, elementType); + SmallVector elTypes(arrayType.getNumElements(), + convertType(context, elementType)); + + return LLVM::DICompositeTypeAttr::get( + context, llvm::dwarf::DW_TAG_array_type, + mlir::StringAttr::get(context, "array"), fileAttr, /*line=*/line, + /*scope=*/fileAttr, /*baseType=*/baseType, mlir::LLVM::DIFlags::Zero, + sizeInBits, /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elTypes); +} + +std::optional LLVMDIUtils::calcBitWidth(mlir::Type type) { + if (type.isIntOrFloat()) { + return type.getIntOrFloatBitWidth(); + } else if (mlir::isa(type)) { + auto vectorType = dyn_cast(type); + llvm::ArrayRef shape = vectorType.getShape(); + mlir::Type elementType = vectorType.getElementType(); + llvm::ArrayRef scalableDims = vectorType.getScalableDims(); + unsigned size = 1; + for (auto i : shape) { + size *= i; + } + + if (auto elementTypeSize = calcBitWidth(elementType); + elementTypeSize.has_value()) { + return size * elementTypeSize.value(); + } + } + + return std::nullopt; +} + +/// Attempt to extract a filename for the given loc. +FileLineColLoc LLVMDIUtils::extractFileLoc(Location loc, bool getCaller) { + if (auto fileLoc = dyn_cast(loc)) + return fileLoc; + if (auto nameLoc = dyn_cast(loc)) + return extractFileLoc(nameLoc.getChildLoc()); + if (auto opaqueLoc = dyn_cast(loc)) + return extractFileLoc(opaqueLoc.getFallbackLocation()); + if (auto fusedLoc = dyn_cast(loc)) + return extractFileLoc(fusedLoc.getLocations().front()); + if (auto callerLoc = dyn_cast(loc)) + return getCaller ? extractFileLoc(callerLoc.getCaller()) + : extractFileLoc(callerLoc.getCallee()); + StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), ""); + return mlir::FileLineColLoc::get(unknownFile, 0, 0); +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMDIUtils.h b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIUtils.h new file mode 100644 index 0000000000..ffaf77c925 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIUtils.h @@ -0,0 +1,25 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace LLVMDIUtils { +LLVM::DITypeAttr convertType(MLIRContext *context, mlir::Type type); +LLVM::DITypeAttr convertPtrType(MLIRContext *context, + LLVM::LLVMPointerType pointerType, + mlir::Type pointeeType, DataLayout datalayout); +LLVM::DITypeAttr convertStructType(MLIRContext *context, + LLVM::LLVMStructType structType, + LLVM::DIFileAttr fileAttr, + DataLayout datalayout, int64_t line); +LLVM::DITypeAttr convertArrayType(MLIRContext *context, + LLVM::LLVMArrayType arrayType, + LLVM::DIFileAttr fileAttr, + DataLayout datalayout, int64_t line); +FileLineColLoc extractFileLoc(Location loc, bool getCaller = true); +std::optional calcBitWidth(mlir::Type type); +} // namespace LLVMDIUtils +} // namespace mlir diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/third_party/mthreads/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 0000000000..a3c6d69959 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHIIt()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMPasses.h b/third_party/mthreads/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 0000000000..1dcdb2992c --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/third_party/mthreads/lib/Tools/CMakeLists.txt b/third_party/mthreads/lib/Tools/CMakeLists.txt new file mode 100644 index 0000000000..611468b9a2 --- /dev/null +++ b/third_party/mthreads/lib/Tools/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(TritonTools + GenericSwizzling.cpp + LayoutUtils.cpp + LinearLayout.cpp + PluginUtils.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + f2reduce +) diff --git a/third_party/mthreads/lib/Tools/GenericSwizzling.cpp b/third_party/mthreads/lib/Tools/GenericSwizzling.cpp new file mode 100644 index 0000000000..fedd25a3c3 --- /dev/null +++ b/third_party/mthreads/lib/Tools/GenericSwizzling.cpp @@ -0,0 +1,713 @@ +#include "triton/Tools/GenericSwizzling.h" + +#include "third_party/f2reduce/f2reduce.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" + +#define DEBUG_TYPE "generic-swizzling" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_ctzll(unsigned long long x) { + unsigned long r; + _BitScanForward64(&r, x); + return static_cast(r); +} + +#endif + +void printBasis(const llvm::SmallVector &basis, + const std::string &name) { + llvm::errs() << name << ": "; + for (int32_t b : basis) + llvm::errs() << b << " "; + llvm::errs() << "\n"; +} + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +// Goes from bases of the form [[1], [2], [4], [8]] to [1, 2, 4, 8] +SmallVector flatten(const LinearLayout &ll, StringAttr dim) { + assert(ll.getNumOutDims() == 1); + auto outDim = *ll.getOutDimNames().begin(); + SmallVector vec; + for (int i = 0; i < ll.getInDimSizeLog2(dim); ++i) + vec.push_back(ll.getBasis(dim, i, outDim)); + return vec; +}; + +SmallVector removeZeros(ArrayRef vec) { + SmallVector result; + for (int32_t r : vec) { + if (r != 0) { + result.push_back(r); + } + } + return result; +} + +// [1, 2, 4, 8] -> [[1], [2], [4], [8]] +std::vector> unflatten(ArrayRef basis) { + std::vector> unflattened; + for (int32_t b : basis) + unflattened.push_back({b}); + return unflattened; +} + +// Compute the nullspace basis of `vectors` +SmallVector nullspaceBasis(ArrayRef vectors, int32_t dim) { + // Solve A^T x = 0, where A is the matrix of vectors + // To do this, we form a matrix where each vector is a row + const int32_t nRows = vectors.size(); + auto mat = std::make_unique(nRows); + for (int i = 0; i < nRows; ++i) + mat[i] = static_cast(vectors[i]); + f2reduce::inplace_rref_strided(mat.get(), /*rows=*/nRows, /*cols=*/dim, + /*stride=*/1); + + llvm::SmallDenseSet pivotCols; + for (int32_t r = 0; r < nRows; ++r) + if (mat[r]) + pivotCols.insert(__builtin_ctzll(mat[r])); + + SmallVector basis; + for (int32_t freeCol = 0; freeCol < dim; ++freeCol) { + if (!pivotCols.contains(freeCol)) { + uint64_t vec = 1ull << freeCol; + for (int32_t r = 0; r < nRows; ++r) + if (mat[r] & (1ull << freeCol)) { + const int32_t pivot = __builtin_ctzll(mat[r]); + vec ^= (1ull << pivot); + } + basis.push_back(static_cast(vec)); + } + } + return basis; +} + +// Find the smallest tile that we can read and write to smem +// without sacrificing vectorisation and split it into its own +// `reps` dimension +LinearLayout buildReps(MLIRContext *ctx, const LinearLayout &src, + const LinearLayout &dst, const LinearLayout &smem, + int32_t leaveReps) { + auto kVec = StringAttr::get(ctx, "vector"); + auto kBank = StringAttr::get(ctx, "bank"); + auto kSegment = StringAttr::get(ctx, "segment"); + auto kReps = StringAttr::get(ctx, "reps"); + auto kReg = StringAttr::get(ctx, "register"); + // A basis is a rep if: + // 1) It is in registers in both src and dst + // 2) It is in the segment of smem (i.e., is not part of just one + // load/store) + SetVector srcRegs(llvm::from_range_t{}, flatten(src, kReg)); + SetVector dstRegs(llvm::from_range_t{}, flatten(dst, kReg)); + SetVector smemSegment(llvm::from_range_t{}, flatten(smem, kSegment)); + SetVector segment; + SetVector reps; + for (auto s : smemSegment) { + // Do not move the first leaveReps bases from reps to segment + // as we need them to vectorise the instructions (think .x2 and .x4 in + // ldmatrix) + if (srcRegs.contains(s) && dstRegs.contains(s)) { + if (leaveReps > 0) { + leaveReps--; + segment.insert(s); + } else { + reps.insert(s); + } + } else { + segment.insert(s); + } + } + + auto smemReps = LinearLayout({{kVec, smem.getBases().lookup(kVec)}, + {kBank, smem.getBases().lookup(kBank)}, + {kSegment, unflatten(to_vector(segment))}, + {kReps, unflatten(to_vector(reps))}}, + smem.getOutDims(), + /*requireSurjective=*/true); + return smemReps; +} + +SmallVector computeSegment(const SmallVector &bankSrc, + const SmallVector &bankDst, + int32_t dim, int32_t lenSegment) { + llvm::SmallDenseSet setSrc(bankSrc.begin(), bankSrc.end()); + llvm::SmallDenseSet setDst(bankDst.begin(), bankDst.end()); + // Remove the 0 as it's not a basis + setSrc.erase(0); + setDst.erase(0); + + SmallVector segment; + for (int32_t b = 0; b < dim; ++b) + if (!setSrc.contains(1 << b) && !setDst.contains(1 << b)) + segment.push_back(1 << b); + if (segment.size() >= lenSegment) { + segment.resize(lenSegment); + return segment; + } + + // A and B are the difference sets + SmallVector A, B; + for (int32_t v : setSrc) + if (!setDst.contains(v)) + A.push_back(v); + for (int32_t v : setDst) + if (!setSrc.contains(v)) + B.push_back(v); + if (A.size() > B.size()) { + std::swap(A, B); + } + llvm::sort(A); + llvm::sort(B); + // A is the smaller set now + auto logBankConflicts = std::min( + std::max(0, lenSegment - A.size() - segment.size()), A.size()); + // Conflict-free + for (int i = logBankConflicts; i < A.size(); ++i) + segment.push_back(A[i] ^ B[i]); + // Write conflicts + segment.append(A.begin(), A.begin() + logBankConflicts); + // Read conflicts + segment.append(B.begin(), B.begin() + logBankConflicts); + + if (segment.size() > lenSegment) + segment.resize(lenSegment); + return segment; +} + +SmallVector complementBasis(ArrayRef basis, int32_t dim) { + const int32_t nRows = basis.size(); + auto mat = std::make_unique(nRows); + for (int r = 0; r < nRows; ++r) + mat[r] = static_cast(basis[r]); + + f2reduce::inplace_rref_strided(mat.get(), /*rows=*/nRows, + /*cols=*/dim, /*stride=*/1); + + llvm::SmallDenseSet pivotCols; + for (int r = 0; r < nRows; ++r) { + if (mat[r]) { + pivotCols.insert(__builtin_ctzll(mat[r])); // leading-1 position + } + } + + SmallVector comp; + for (int i = 0; i < dim; ++i) + if (!pivotCols.contains(i)) + comp.push_back(1 << i); + + return comp; +} +} // namespace + +namespace mlir::triton::gpu { + +SmallVector intersectionBasis(ArrayRef b1, + ArrayRef b2, int32_t dim) { + // If needed to be generic, this can be done computing + // nullspaceBasis(concat(nullspaceBasis(b1), nullspaceBasis(b2))) + // but doing this returns the bases in an arbitrary order! + auto isPowerOf2 = [](int32_t x) { return llvm::isPowerOf2_32(x); }; + bool powerOf2 = llvm::all_of(b1, isPowerOf2) && llvm::all_of(b2, isPowerOf2); + if (powerOf2) { + SmallVector result; + // Heuristic: We choose to retain the order relative to b1 + SetVector set2(b2.begin(), b2.end()); + for (int32_t b : b1) { + if (b != 0 && set2.contains(b)) { + result.push_back(b); + } + } + return result; + } else { + auto ns1 = nullspaceBasis(b1, dim); + auto ns2 = nullspaceBasis(b2, dim); + auto joint = llvm::to_vector(llvm::concat(ns1, ns2)); + return nullspaceBasis(joint, dim); + } +} + +std::pair bankConflicts(ArrayRef tileSrc, + ArrayRef tileDst, + const LinearLayout &smem) { + auto *ctx = smem.getOutDimNames().begin()->getContext(); + auto smemFlat = smem.flattenOuts(); + auto inDim = *smem.getInDimNames().begin(); + // Look at the intersection between the segment bases and the tile bases + // We don't need to intersect with the bases that covert the bank (as in + // the first 32 / bitwidth bases) because if we hit any of those broadcasting + // will avoid the bank conflict + auto segment = StringAttr::get(ctx, "segment"); + auto segmentBases = flatten(smemFlat, segment); + + int32_t rank = smem.getTotalOutDimSizeLog2(); + // compute conflicts + int write = 1 << intersectionBasis(segmentBases, tileSrc, rank).size(); + int read = 1 << intersectionBasis(segmentBases, tileDst, rank).size(); + return {read - 1, write - 1}; +} + +std::pair bankConflictsLdSt(const LinearLayout &src, + const LinearLayout &dst, + const LinearLayout &smem, + int32_t bitwidth) { + auto srcFlat = src.flattenOuts(); + auto dstFlat = dst.flattenOuts(); + auto *ctx = smem.getOutDimNames().begin()->getContext(); + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + auto kVec = S("vector"); + auto srcLane = flatten(srcFlat, S("lane")); + auto dstLane = flatten(dstFlat, S("lane")); + auto log2Vec = + llvm::Log2_32(std::max(smem.getInDimSize(kVec) * bitwidth / 32, 1)); + srcLane.resize(srcLane.size() - log2Vec); + dstLane.resize(dstLane.size() - log2Vec); + return bankConflicts(srcLane, dstLane, smem); +} + +int bankConflictsMemDesc(const LinearLayout ®, const LinearLayout &smem, + int32_t bitwidth) { + auto *ctx = smem.getInDimNames().begin()->getContext(); + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + + assert(smem.hasInDim(S("offset")) && "shared layout must have an offset dim"); + assert(reg.hasInDim(S("register")) && + "register layout must have a register dim"); + auto regNoBroadcast = actionRemoveBroadcastedRegs(reg).apply(reg); + auto regToShared = regNoBroadcast.invertAndCompose(smem); + auto [elemsPerVec, permutation] = + largestVectorisation(ctx, regToShared, bitwidth); + regNoBroadcast = permutation.apply(regNoBroadcast); + + int32_t vecSize = elemsPerVec; + int32_t bankSize = + std::min(32 * 32 / (vecSize * bitwidth), smem.getTotalInDimSize()); + int32_t segmentSize = smem.getTotalInDimSize() / (bankSize * vecSize); + SmallVector> newInDims = { + {S("vector"), vecSize}, + {S("bank"), bankSize}, + {S("segment"), segmentSize}, + }; + auto smemReshaped = smem.reshapeIns(newInDims); + return bankConflictsLdSt(regNoBroadcast, regNoBroadcast, smemReshaped, + bitwidth) + .first; +} + +std::optional> optimalSwizzlingTile( + const LinearLayout &a, const LinearLayout &b, int32_t nRegA, int32_t nRegB, + ArrayRef laneIdTileA, ArrayRef laneIdTileB) { + // For now se just implement the .v4 variants for all the instructions + // We could generalise this in the future + assert(nRegA + laneIdTileA.size() == nRegB + laneIdTileB.size()); + // normalise nRegA >= nRegB + if (nRegA < nRegB) { + return optimalSwizzlingTile(b, a, nRegB, nRegA, laneIdTileB, laneIdTileA); + } + assert(nRegA >= nRegB); + + auto *ctx = a.getInDimNames().begin()->getContext(); + auto kReg = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + auto dim = a.getTotalOutDimSizeLog2(); + // map from b to a + LinearLayout cvt = b.invertAndCompose(a); + + // The contiguous tile of ld.shared.b32.v4 for a packed element of size + // bitwidth is composed of 128/bitwidth register elements + // The contiguous tile of ldmatrix.v4 for a packed element of size bitwidth + // is composed of 32/bitwidth register elements and the bases 0, 1st as given + // by the laneAddr + // The contiguous tile of ldmatrix.v4.trans for a packed element of size 16 + // is composed of the bases 2, 3, 4th as given by the laneAddr + + // Note that for register elements, we can choose any register basis we want, + // but the lane bases are fixed + + // In this function, we compute a tile (set of bases) such that it matches + // the tiles of A and B + + auto regA = flatten(a, kReg); + auto regB = flatten(b, kReg); + auto laneA = flatten(a, kLane); + auto laneB = flatten(b, kLane); + + // Compute the number of registers that start the tile + SmallVector vbasis = intersectionBasis(regA, regB, dim); + // We need to have at least nRegB vectorisation + if (vbasis.size() < nRegB) { + return std::nullopt; + } + vbasis.resize(nRegB); + + auto index = [](ArrayRef lane, ArrayRef laneIdTile) { + SmallVector ret; + for (auto id : laneIdTile) { + ret.push_back(lane[id]); + } + return ret; + }; + auto laneTileA = index(laneA, laneIdTileA); + auto laneTileB = index(laneB, laneIdTileB); + + // We need the tiles to be contiguous + auto isZero = [](int32_t b) { return b == 0; }; + if (llvm::any_of(laneTileA, isZero) || llvm::any_of(laneTileB, isZero)) { + return std::nullopt; + } + // The first lanes must map to registers in A + for (int i = 0; i < nRegA - nRegB; ++i) { + if (cvt.getBasis(kLane, laneIdTileB[i], kReg) == 0) { + return std::nullopt; + } + } + // The rest of the lanes must map to each other + for (auto [idxA, idxB] : + llvm::zip(laneIdTileA, laneIdTileB.take_back(laneIdTileA.size()))) { + if (cvt.getBasis(kLane, idxB, kLane) != (1 << idxA)) { + return std::nullopt; + } + } + vbasis.append(laneTileB.begin(), laneTileB.end()); + return vbasis; +} + +LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst, + int32_t bitwidth, ArrayRef vbasis, + ArrayRef tileSrc, + ArrayRef tileDst, + ArrayRef> outDims, + int32_t leaveReps = 0) { + // We work on the flattened tensors as the tensor dimensions are not relevant + assert(src.getNumOutDims() == 1 && dst.getNumOutDims() == 1 && + "src and dst must have a single output dimension"); + + const int32_t dim = src.getTotalOutDimSizeLog2(); + auto *ctx = src.getInDimNames().begin()->getContext(); + auto kReg = StringAttr::get(ctx, "register"); + + auto regsNotZero = [kReg](const LinearLayout &ll) { + return llvm::all_of( + ll.getBases().lookup(kReg), + [](const std::vector &basis) { return basis[0] != 0; }); + }; + assert( + regsNotZero(src) && + "Remove register broadcasting from src. See actionRemoveBroadcastedRegs"); + assert( + regsNotZero(dst) && + "Remove register broadcasting from dst. See actionRemoveBroadcastedRegs"); + + llvm::SmallVector bankSrc; + bankSrc.append(vbasis.begin(), vbasis.end()); + bankSrc.append(tileSrc.begin(), tileSrc.end()); + llvm::SmallVector bankDst; + bankDst.append(vbasis.begin(), vbasis.end()); + bankDst.append(tileDst.begin(), tileDst.end()); + + // Bits in a bank segment: 32 banks x 32 bits + constexpr int32_t bankBits = 32 * 32; + // Bases needed to cover a whole bank segment + const int32_t lenBbasis = std::min( + llvm::Log2_32(bankBits / ((1 << vbasis.size()) * bitwidth)), + dim - vbasis.size()); + // Bases to cover all the tensor + const int32_t lenSbasis = dim - lenBbasis - vbasis.size(); + + auto sbasis = computeSegment(bankSrc, bankDst, dim, lenSbasis); + + // The bank is the complement of the union of the vector and the start of the + // segments + SmallVector unionBasis; + unionBasis.append(vbasis.begin(), vbasis.end()); + unionBasis.append(sbasis.begin(), sbasis.end()); + SmallVector bbasis = complementBasis(unionBasis, dim); + + assert(bbasis.size() == lenBbasis + (lenSbasis - sbasis.size()) && + "bbasis size mismatch"); + + // Build the 1D result layout + StringAttr vecAttr = StringAttr::get(ctx, "vector"); + StringAttr bankAttr = StringAttr::get(ctx, "bank"); + StringAttr segAttr = StringAttr::get(ctx, "segment"); + + // src has just 1 outDim + LinearLayout basis1D({{vecAttr, unflatten(vbasis)}, + {bankAttr, unflatten(bbasis)}, + {segAttr, unflatten(sbasis)}}, + src.getOutDims(), /*requireSurjective=*/true); + basis1D = buildReps(ctx, src, dst, basis1D, leaveReps); + + return basis1D.reshapeOuts(outDims); +} +LinearLayout optimalSwizzlingLdSt(const LinearLayout &src, + const LinearLayout &dst, int32_t bitwidth) { + auto *ctx = src.getInDimNames().begin()->getContext(); + auto kReg = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + auto srcFlat = src.flattenOuts(); + auto dstFlat = dst.flattenOuts(); + auto regSrc = flatten(srcFlat, kReg); + auto regDst = flatten(dstFlat, kReg); + auto laneSrc = flatten(srcFlat, kLane); + auto laneDst = flatten(dstFlat, kLane); + auto dim = src.getTotalOutDimSizeLog2(); + SmallVector vbasis = intersectionBasis(regSrc, regDst, dim); + // Restrict the vectorisation to the maximum we can use + auto maxVecBases = llvm::Log2_32(128 / bitwidth); + if (vbasis.size() > maxVecBases) { + vbasis.resize(maxVecBases); + } + // We fill-up vbasis until it has 32 bits as best we can + std::optional srcFillsBank = std::nullopt; + if ((1 << vbasis.size()) * bitwidth < 32) { + auto basesPerBank = llvm::Log2_32(32 / bitwidth); + auto kWarp = StringAttr::get(ctx, "warp"); + auto warpSrc = removeZeros(flatten(srcFlat, kWarp)); + auto warpDst = removeZeros(flatten(dstFlat, kWarp)); + auto removeVec = [&vbasis](ArrayRef vec) { + SmallVector result; + for (int32_t r : vec) { + if (!llvm::is_contained(vbasis, r)) { + result.push_back(r); + } + } + return result; + }; + auto regSrcWarp = intersectionBasis(removeVec(regSrc), warpDst, dim); + auto regDstWarp = intersectionBasis(removeVec(regDst), warpSrc, dim); + // Maximise vectorisation in the load or the store without creating + // conflicts + SmallVector largest; + if (regSrcWarp.size() == regDstWarp.size() && regSrcWarp.size() > 0) { + // We choose the one with the lowest basis in the hope that it will + // avoid PRMTs. The comparison of the mins will be strict as the sets + // removeVec(regSrc) and removeVec(regDst) don't intersect + if (*llvm::min_element(regSrcWarp) < *llvm::min_element(regDstWarp)) { + largest = regSrcWarp; + srcFillsBank = true; + } else { + largest = regDstWarp; + srcFillsBank = false; + } + } else { + srcFillsBank = regSrcWarp.size() > regDstWarp.size(); + largest = srcFillsBank.value() ? regSrcWarp : regDstWarp; + } + vbasis.append(largest.begin(), largest.end()); + + if (vbasis.size() < basesPerBank) { + // Pad the vectorisation to 32 bits with warp bases + auto warpSrcWarp = intersectionBasis(warpSrc, warpDst, dim); + vbasis.append(warpSrcWarp.begin(), warpSrcWarp.end()); + } + + int i = 0; + while (vbasis.size() < basesPerBank && + (i < warpSrc.size() || i < warpDst.size())) { + // If we have not filled up a whole bank, we add more warp bases + // until we have 32 bits. They will at least avoid bank conflicts in one + // direction + if (i < warpSrc.size() && !llvm::is_contained(vbasis, warpSrc[i])) { + vbasis.push_back(warpSrc[i]); + } + if (vbasis.size() < basesPerBank && i < warpDst.size() && + !llvm::is_contained(vbasis, warpDst[i])) { + vbasis.push_back(warpDst[i]); + } + ++i; + } + + // Trim to basesPerBank if we have added more + // The idea here is that implementing asymmetric vectorisation without bank + // conflicts is a bit tricky. Basically, in this case, you need to use the + // vectorisation base in the swizzling pattern. As such, you would not be + // able to vectorise all the `ld.shared` instructions that you emit, but + // just about half of them (the ones that are not swizzled). We don't + // implement this yet + if (vbasis.size() > basesPerBank) { + vbasis.resize(basesPerBank); + } + } + auto log2Vec = llvm::Log2_32( + std::max(1, ((1 << vbasis.size()) * bitwidth) / 32)); + auto tileSrc = to_vector(ArrayRef(laneSrc).drop_back(log2Vec)); + auto tileDst = to_vector(ArrayRef(laneDst).drop_back(log2Vec)); + auto smem = optimalSwizzling(srcFlat, dstFlat, bitwidth, vbasis, tileSrc, + tileDst, src.getOutDims()); + + // We might be able to vectorise a bit more the load or the store + // This may happen when there is broadcasting + // e.g for fp32 + // src = {reg = [], lane = [1, 2, 4, 8, 16], warp = [32]} + // dst = {reg = [8, 32], lane = [0, 0, 1, 2, 4], warp = [16]} + if (log2Vec < 2) { + auto smemFlat = smem.flattenOuts(); + // For every bank line, find if it is in regSrc or regDst + // and if so, store the index in the vector + SmallVector idxBanksInRegSrc; + SmallVector idxBanksInRegDst; + auto kBank = StringAttr::get(ctx, "bank"); + const auto &banks = flatten(smemFlat, kBank); + for (auto [i, r] : llvm::enumerate(banks)) { + if (llvm::is_contained(regSrc, r)) { + idxBanksInRegSrc.push_back(i); + } + if (llvm::is_contained(regDst, r)) { + idxBanksInRegDst.push_back(i); + } + } + + // Choose src/dst if we used them to fill the bank + // Otherwise choose the max vectorisation + SmallVector bBasisOrder; + if (srcFillsBank.has_value() && srcFillsBank.value()) { + bBasisOrder = std::move(idxBanksInRegSrc); + } else if (srcFillsBank.has_value() && !srcFillsBank.value()) { + bBasisOrder = std::move(idxBanksInRegDst); + } else { + bBasisOrder = idxBanksInRegSrc.size() > idxBanksInRegDst.size() + ? std::move(idxBanksInRegSrc) + : std::move(idxBanksInRegDst); + } + for (int i = 0; i < banks.size(); ++i) { + if (!llvm::is_contained(bBasisOrder, i)) { + bBasisOrder.push_back(i); + } + } + smem = ColumnAction(bBasisOrder, kBank, smem.getInDimSizeLog2(kBank)) + .apply(smem); + } + + return smem; +} + +std::pair> +optimalSwizzling(const LinearLayout &src, const LinearLayout &dst, + ArrayRef srcTiles, + ArrayRef dstTiles, int32_t bitwidth) { + assert(bitwidth <= 128 && "bitwidth must be <= 128"); + auto srcFlat = src.flattenOuts(); + auto dstFlat = dst.flattenOuts(); + // Number of total bases needed to cover the necessary contiguous tile + // We assume using ld.shared.b32.v4 in the case of ld/st ops + const auto totalBases = llvm::Log2_32(128 / bitwidth); + + auto *ctx = src.getInDimNames().begin()->getContext(); + auto kReg = StringAttr::get(ctx, "register"); + + // Find the pairs of instructions that we can use to lower this converet + SmallVector, SmallVector>> + instr; + for (const auto &[idxSrc, instrSrc] : llvm::enumerate(srcTiles)) { + auto logRegSrc = totalBases - instrSrc.laneContig.size(); + for (const auto &[idxDst, instrDst] : llvm::enumerate(dstTiles)) { + auto logRegDst = totalBases - instrDst.laneContig.size(); + auto maybeTile = + optimalSwizzlingTile(srcFlat, dstFlat, logRegSrc, logRegDst, + instrSrc.laneContig, instrDst.laneContig); + if (maybeTile.has_value()) { + instr.push_back({{idxSrc, idxDst}, std::move(*maybeTile)}); + } + } + } + auto getTile = + [](const LocalMemOpTile &instr, ArrayRef regs, + ArrayRef lane, + ArrayRef vbasis) -> std::optional> { + // pick the first 3 - laneAddr.size() registers that are not in vbasis + SmallVector tile; + auto regNeeded = 3 - instr.laneAddr.size(); + assert(regNeeded >= 0 && "laneAddr.size() must be <= 3"); + for (int32_t r : regs) { + if (regNeeded == 0) { + break; + } + if (!llvm::is_contained(vbasis, r)) { + tile.push_back(r); + regNeeded--; + } + } + // Not enough registers to fill in the tile + if (regNeeded > 0) { + return std::nullopt; + } + for (auto i : instr.laneAddr) { + tile.push_back(lane[i]); + } + return tile; + }; + + auto kLane = StringAttr::get(ctx, "lane"); + auto regSrc = flatten(srcFlat, kReg); + auto regDst = flatten(dstFlat, kReg); + auto laneSrc = flatten(srcFlat, kLane); + auto laneDst = flatten(dstFlat, kLane); + + // Get the associated src/dst tiles for each instruction if they exist + SmallVector, SmallVector, + SmallVector, SmallVector, int32_t>> + tiles; + for (auto [instrs, vbasis] : instr) { + auto maybeTileSrc = + getTile(srcTiles[instrs.first], regSrc, laneSrc, vbasis); + auto maybeTileDst = + getTile(dstTiles[instrs.second], regDst, laneDst, vbasis); + if (!maybeTileSrc.has_value() || !maybeTileDst.has_value()) { + continue; + } + // Regs bases missing to get full vectorisation + auto regsMissing = [](const LocalMemOpTile &instr) { + return instr.laneContig.size() + instr.laneAddr.size() - 3; + }; + // We leave 2 reps for combinations of ldmatrix/stmatrix instructions + // to be able to fully vectorise them + int32_t leaveReps = std::min(regsMissing(srcTiles[instrs.first]), + regsMissing(dstTiles[instrs.second])); + assert((leaveReps == 0 || leaveReps == 2) && "leaveReps must be 0 or 2"); + tiles.push_back({instrs, std::move(vbasis), std::move(*maybeTileSrc), + std::move(*maybeTileDst), leaveReps}); + } + + if (tiles.empty()) { + // We lower to an ld / st, but can't use LDS128/STS128 + auto smem = optimalSwizzlingLdSt(src, dst, bitwidth); + return {smem, {0, 0}}; + } else { + SmallVector>> + smems; + // We choose the pair of instructions that minimises the total bank + // conflicts + for (auto [instrs, vbasis, tileSrc, tileDst, leaveReps] : tiles) { + auto smem = optimalSwizzling(srcFlat, dstFlat, bitwidth, vbasis, tileSrc, + tileDst, src.getOutDims(), leaveReps); + auto [read, write] = bankConflicts(tileSrc, tileDst, smem); + smems.push_back({read + write, smem, {instrs.first, instrs.second}}); + } + // Current heuristic: Minimise total bank conflicts + // We break ties looking at the number of rounds we do to move the data + auto kReps = StringAttr::get(ctx, "reps"); + auto it = llvm::min_element(smems, [kReps](const auto &a, const auto &b) { + return std::get<0>(a) < std::get<0>(b) || + (std::get<0>(a) == std::get<0>(b) && + std::get<1>(a).getInDimSize(kReps) > + std::get<1>(b).getInDimSize(kReps)); + }); + return {std::get<1>(*it), std::get<2>(*it)}; + } +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Tools/LayoutUtils.cpp b/third_party/mthreads/lib/Tools/LayoutUtils.cpp new file mode 100644 index 0000000000..815bf6d4b3 --- /dev/null +++ b/third_party/mthreads/lib/Tools/LayoutUtils.cpp @@ -0,0 +1,582 @@ +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/GenericSwizzling.h" + +namespace mlir::triton { + +static bool checkSquareSublayout(const LinearLayout &ll, + ArrayRef dimNames, + function_ref checkBasis) { + // The empty layout is the identity + if (dimNames.size() == 0) { + return true; + } + // Check that the input-output sizes are the same + LinearLayout sl = ll.sublayout(dimNames, dimNames); + for (StringAttr dim : dimNames) { + if (ll.getInDimSize(dim) != ll.getOutDimSize(dim)) { + return false; + } + } + // Once the inputs and output dimensions are the same, we can just check + // that the basis for the single remaining dimension is the identity. + sl = sl.flattenIns().flattenOuts(); + const auto &inDimBases = sl.getBases().begin()->second; + for (auto [b, basis] : llvm::enumerate(inDimBases)) { + if (!checkBasis(b, basis[0])) { + return false; + } + } + return true; +} + +bool squareSublayoutIsIdentity(const LinearLayout &ll, + ArrayRef dimNames) { + return checkSquareSublayout( + ll, dimNames, [](int b, int32_t basis) { return basis == (1 << b); }); +} + +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + auto bases = layout.getBases(); + + auto kRegister = StringAttr::get(ctx, "register"); + std::set broadcastedDims; + + for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { + auto outDimName = outDim.value(); + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // + std::vector> sortedBases; + for (auto [inDimName, basis] : bases) { + for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { + auto outValue = basis[basisIdx][outDim.index()]; + if (outValue == 0) { + continue; + } + assert(llvm::isPowerOf2_32(outValue)); + sortedBases.emplace_back(inDimName, basisIdx, outValue); + } + } + // From the largest basis to the smallest. + llvm::sort(sortedBases, + [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); + for (auto [inDimName, basisIdx, outValue] : sortedBases) { + if (actualSize <= desiredSize) { + break; + } + if (!broadcastRegisters && inDimName == kRegister) { + broadcastedDims.insert(basisIdx); + } else { + bases[inDimName][basisIdx][outDim.index()] = 0; + } + actualSize >>= 1; + } + } + if (!broadcastRegisters) { + // Remove broadcasted registers + std::vector> newBasesRegister; + for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) { + // Remove if it's broadcasted + if (broadcastedDims.find(idx) == broadcastedDims.end()) { + newBasesRegister.push_back(std::move(basis)); + } + } + bases[kRegister] = std::move(newBasesRegister); + } + auto outDims = layout.getOutDims(); + for (auto &[outDim, outDimSize] : outDims) { + outDimSize = std::min(outDimSize, shape.lookup(outDim)); + } + + return LinearLayout(std::move(bases), std::move(outDims), + /*requireSurjective=*/false); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + StringAttr kDim = *layout.getInDimNames().begin(); + assert(kDim == "register" || kDim == "offset"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + +// Returns ["dim0", "dim1", ..., "dim"]. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(i))); + } + return ret; +} + +// Returns [("dim0", dstShape[0]), ("dim1", dstShape[1]), ..., +// ("dim", dstShape[rank-1])]. +SmallVector> +standardOutDimPairs(MLIRContext *ctx, ArrayRef dstShape) { + auto newRank = dstShape.size(); + SmallVector> newOutDims; + for (auto [dim, size] : + llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { + newOutDims.emplace_back(dim, size); + } + return newOutDims; +} + +// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to +// creating a 1D -> 1D mapping of size product(shape) and then reshaping to +// permute(shape, order). +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order) { + assert(shape.size() == order.size()); + MLIRContext *ctx = inDimName.getContext(); + auto rank = shape.size(); + + // The order in triton is written wrt. [dim0, dim1, ...]. + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < shape.size(); i++) { + // Start with the most-minor dimension, which is order[0]. + int dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + +LinearLayout zerosLike(const LinearLayout &layout) { + auto bases = layout.getBases(); + for (auto &basis : bases) { + for (auto &vec : basis.second) { + for (auto &val : vec) { + val = 0; + } + } + } + + SmallVector> outDims; + for (auto outDim : layout.getOutDimNames()) { + outDims.emplace_back(outDim, layout.getOutDimSize(outDim)); + } + return LinearLayout(std::move(bases), std::move(outDims), + /*requireSurjective=*/false); +} + +std::optional regPermForDivide(const LinearLayout &A, + const LinearLayout &B, bool left) { + // We can implement this generically for any dimension, but for now we only do + // it for regs to keep the API simpler + assert(A.getNumInDims() != 0); + auto kReg = *A.getInDimNames().begin(); + assert(kReg.str() == "register"); + assert(B.getNumInDims() != 0); + assert(kReg == *B.getInDimNames().begin()); + + // We broadcast B to have the same number of out dims as A. + LinearLayout broadcast; + for (StringAttr out : A.getOutDimNames()) { + broadcast *= LinearLayout::identity1D(1, kReg, out); + } + auto BBroadcast = broadcast * B; + + // Retrieve the register bases from A and B. + const auto &ARegBases = A.getBases().lookup(kReg); + const auto &BRegBases = BBroadcast.getBases().lookup(kReg); + + llvm::DenseMap log2QuotSize; + for (StringAttr out : A.getOutDimNames()) { + log2QuotSize[out] = + A.getOutDimSizeLog2(out) - BBroadcast.getOutDimSizeLog2(out); + if (log2QuotSize[out] < 0) + return std::nullopt; + } + + auto multiplyByTileSize = + [&](ArrayRef bBasis) -> std::vector { + std::vector result; + size_t idx = 0; + assert(bBasis.size() == A.getNumOutDims()); + for (auto [dim, b] : llvm::zip(A.getOutDimNames(), bBasis)) { + result.push_back(b << log2QuotSize.lookup(dim)); + } + return result; + }; + + // Compute the permutation order: + // For each basis in B (in order), find its index in A (using each index at + // most once). We make sure we use each index at most once in case B + // broadcasts (weird case, but better safe than sorry). + SmallVector bIndices; + SmallVector used(ARegBases.size(), false); + for (auto bB : BRegBases) { + bool found = false; + if (!left) + bB = multiplyByTileSize(bB); + + for (size_t j = 0; j < ARegBases.size(); ++j) { + found = !used[j] && (ARegBases[j] == bB); + if (found) { + bIndices.push_back(j); + used[j] = true; + break; + } + } + if (!found) + return std::nullopt; // A basis from B not found in A. + } + // Append remaining indices from A (preserving their original order). + SmallVector remainingIndices; + for (size_t i = 0; i < ARegBases.size(); ++i) { + if (!used[i]) + remainingIndices.push_back(i); + } + SmallVector permOrder = to_vector(llvm::concat( + left ? bIndices : remainingIndices, left ? remainingIndices : bIndices)); + return ColumnAction(permOrder, kReg, ARegBases.size()); +} + +ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout) { + assert(layout.getNumInDims() != 0); + auto kReg = *layout.getInDimNames().begin(); + assert(kReg.str() == "register"); + + // Drop the bases that are zero + const auto &bases = layout.getBases().lookup(kReg); + SmallVector permOrder; + for (size_t i = 0; i < bases.size(); ++i) { + if (!llvm::all_of(bases[i], [](size_t x) { return x == 0; })) { + permOrder.push_back(i); + } + } + return ColumnAction(permOrder, kReg, bases.size()); +} +std::pair +actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout, + uint64_t maskSpanOffsets) { + // We are looking to put at the front (after any zeros) any basis that does + // not intersect with any bit moved by any basis in kLane / kWarp + // and that is not moved by any affine offset + + // Note this function assumes that if any registers are used in the addrLayout + // of the layout (as in ldmatrix/stmatrix) they will be the first non-zero + // registers within `layout` + assert(layout.getNumInDims() != 0); + auto kReg = *layout.getInDimNames().begin(); + assert(kReg.str() == "register"); + auto kLane = StringAttr::get(kReg.getContext(), "lane"); + auto kWarp = StringAttr::get(kReg.getContext(), "warp"); + assert(layout.getNumOutDims() == 1); + uint32_t bits = maskSpanOffsets; + llvm::SetVector tileBases; + for (auto bases : llvm::make_second_range(addrLayout.getBases())) { + for (auto basis : bases) { + bits |= basis[0]; + tileBases.insert(basis[0]); + } + } + SmallVector front, back; + for (auto [idx, basis] : llvm::enumerate(layout.getBases().lookup(kReg))) { + if ((basis[0] & bits) == 0 || tileBases.contains(basis[0])) { + front.push_back(idx); + } else { + back.push_back(idx); + } + } + auto permOrder = to_vector(llvm::concat(front, back)); + return {1 << front.size(), + ColumnAction(permOrder, kReg, layout.getInDimSizeLog2(kReg))}; +} + +SmallVector broadcastAs(const SmallVector &values, + const LinearLayout &layout) { + assert(layout.getNumInDims() != 0); + auto kReg = *layout.getInDimNames().begin(); + assert(kReg.str() == "register"); + uint32_t broadcastMask = layout.getFreeVariableMasks().lookup(kReg); + assert((layout.getInDimSize(kReg) / (1 << llvm::popcount(broadcastMask))) == + values.size()); + + std::vector> newBases; + int i = 0; + for (int j = 0; j < layout.getInDimSizeLog2(kReg); j++) { + if (broadcastMask & (1 << j)) { + newBases.push_back({0}); + } else { + newBases.push_back({1 << i}); + i++; + } + } + auto newLayout = LinearLayout({{kReg, std::move(newBases)}}, {kReg}); + SmallVector ret; + + ret.reserve(newLayout.getInDimSize(kReg)); + for (int i = 0; i < newLayout.getInDimSize(kReg); i++) { + int32_t srcIdx = newLayout.apply({{kReg, i}}).begin()->second; + ret.push_back(values[srcIdx]); + } + return ret; +} + +// Compute the supremum of two lists. +// If the supremum is not unique, we return the first list first +// Error out if the supremum does not exist +// e.g. sup([a, b], [a, c]) = [a, b, c], sup([a, b], [b, c]) = [a, b, c] +// sup([a, b], [b, a]) = error! Supremum does not exist. +SmallVector supremum(const SmallVector &x, + const SmallVector &y) { + llvm::SetVector result; + DenseMap posX, posY; + for (auto [idx, elem] : llvm::enumerate(x)) + posX[elem] = idx; + for (auto [idx, elem] : llvm::enumerate(y)) + posY[elem] = idx; + int i = 0, j = 0; + const int INF = std::numeric_limits::max(); + while (i < x.size() || j < y.size()) { + while (i < x.size() && result.contains(x[i])) + ++i; + while (j < y.size() && result.contains(y[j])) + ++j; + if (i >= x.size() && j >= y.size()) + break; + if (i < x.size() && j < y.size() && x[i] == y[j]) { + if (posY[x[i]] < j) + llvm_unreachable("Supremum does not exist"); + result.insert(x[i]); + ++i, ++j; + continue; + } + int candX = INF, candY = INF; + if (i < x.size()) { + if (posY.count(x[i]) && posY[x[i]] >= j) + candX = posY[x[i]]; + } + if (j < y.size()) { + if (posX.count(y[j]) && posX[y[j]] >= i) + candY = posX[y[j]]; + } + if (i < x.size() && candX == INF) { + result.insert(x[i]); + ++i; + continue; + } + if (j < y.size() && candY == INF) { + result.insert(y[j]); + ++j; + continue; + } + if (candX <= candY) { + if (posY[x[i]] < j) + llvm_unreachable("Supremum does not exist"); + result.insert(x[i]); + ++i; + } else { + if (posX[y[j]] < i) + llvm_unreachable("Supremum does not exist"); + result.insert(y[j]); + ++j; + } + } + return to_vector(result); +} + +LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout, + ArrayRef shape) { + int rank = shape.size(); + auto srcOutDims = to_vector(layout.getOutDimNames()); + std::reverse(srcOutDims.begin(), srcOutDims.end()); + auto newOutDims = standardOutDimPairs(ctx, shape); + std::reverse(newOutDims.begin(), newOutDims.end()); + return layout.transposeOuts(srcOutDims) + .reshapeOuts(newOutDims) + .transposeOuts(standardOutDimNames(ctx, rank)); +} + +LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef order) { + // Transpose the tile layout. + auto namedBases = layout.getBases(); + // move the most outer dimensions to the inner most position. + + for (auto &bases : llvm::make_second_range(namedBases)) { + for (auto &b : bases) { + std::vector newB; + for (auto i : order) { + newB.push_back(b[i]); + } + b = std::move(newB); + } + } + return LinearLayout(std::move(namedBases), + to_vector(layout.getOutDimNames())); +} + +std::pair +largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth, + std::optional maybeMaxVecElems) { + // Find the largest vectorisation we can use: + auto S = [ctx](StringRef str) { return StringAttr::get(ctx, str); }; + StringAttr kReg = S("register"); + StringAttr kOffset = S("offset"); + LinearLayout quot; + LinearLayout tile; + ColumnAction permutation; + // If there are restrictions on the vectorisation, we don't allow + // permutations. + auto allowPerm = !maybeMaxVecElems.has_value(); + auto maxVecElems = maybeMaxVecElems.value_or(128 / bitwidth); + for (int v = maxVecElems; v >= 1; v /= 2) { + tile = LinearLayout::identity1D(v, kReg, kOffset); + auto maybePerm = regPermForDivide(cvt, tile, /*left=*/true); + if (!maybePerm) { + continue; + } + permutation = *maybePerm; + if (!allowPerm && !permutation.isIdentity()) { + continue; + } + auto newCvt = permutation.apply(cvt); + auto maybeQuot = divideLeft(newCvt, tile); + if (!maybeQuot) { + continue; + } + return {v, permutation}; + } + llvm_unreachable("Vectorization < 1 is not valid"); +} + +std::optional getReps(const LinearLayout &cvt, + const LinearLayout &tile) { + + // Ensure tile out-dims are subset of cvt out-dims. + for (auto od : tile.getOutDimNames()) + assert(cvt.hasOutDim(od) && "tile out-dims must be contained in cvt"); + + // Precompute tile out-dim bit-widths. + llvm::SmallDenseMap outBLog2; + for (StringAttr od : cvt.getOutDimNames()) + outBLog2[od] = tile.hasOutDim(od) ? tile.getOutDimSizeLog2(od) : 0; + + // Build a per-out-dimension mask by OR-ing all tile bases that touch it. + llvm::SmallDenseMap tileMaskPerOutDim; + for (StringAttr od : cvt.getOutDimNames()) + tileMaskPerOutDim[od] = 0; + for (auto &[inDim, inBases] : tile.getBases()) { + (void)inDim; + for (auto &basis : inBases) { + int idx = 0; + for (StringAttr od : tile.getOutDimNames()) { + tileMaskPerOutDim[od] |= basis[idx++]; + } + } + } + + // Build reps with the same in/out dims as cvt, but zeroing out the leading + // inB bases (per in-dim) and keeping the remainder bases unchanged from cvt. + LinearLayout::BasesT repsBases; + for (StringAttr id : cvt.getInDimNames()) { + int inA = cvt.getInDimSizeLog2(id); + int inB = tile.hasInDim(id) ? tile.getInDimSizeLog2(id) : 0; + if (inB > inA) { + return std::nullopt; + } + + std::vector> basesForDim; + basesForDim.reserve(inA); + + // 1) Validate the starting bases match exactly. + for (int i = 0; i < inB; ++i) { + for (StringAttr od : cvt.getOutDimNames()) { + int a = cvt.getBasis(id, i, od); + int b = tile.getBasis(id, i, od); + if (a != b) { + return std::nullopt; + } + } + } + + // 2) Validate no overlap: the remaining cvt bases must have zeros in all + // tile-bit positions (computed as OR of all tile bases) for each + // out-dim. + for (int i = inB; i < inA; ++i) { + for (StringAttr od : cvt.getOutDimNames()) { + int32_t mask = tileMaskPerOutDim.lookup(od); + if (mask == 0) + continue; + int v = cvt.getBasis(id, i, od); + if ((v & mask) != 0) { + return std::nullopt; + } + } + } + + // 3) Emit reps bases: first inB as all-zeros; remainder copied from cvt. + for (int i = 0; i < inB; ++i) { + std::vector zero(cvt.getNumOutDims(), 0); + basesForDim.push_back(std::move(zero)); + } + for (int i = inB; i < inA; ++i) { + std::vector keep; + keep.reserve(cvt.getNumOutDims()); + for (StringAttr od : cvt.getOutDimNames()) + keep.push_back(cvt.getBasis(id, i, od)); + basesForDim.push_back(std::move(keep)); + } + + repsBases[id] = std::move(basesForDim); + } + + return LinearLayout(std::move(repsBases), cvt.getOutDims(), + /*requireSurjective=*/false); +} + +LinearLayout removeStandardDim(const LinearLayout &layout, int dim) { + auto rank = layout.getNumOutDims(); + assert(rank > 0); + assert(dim < rank); + auto *ctx = layout.getOutDimNames().begin()->getContext(); + auto dims = to_vector(layout.getOutDimNames()); + assert(dims == standardOutDimNames(ctx, rank)); + dims.erase(dims.begin() + dim); + auto newLayout = layout.sublayout(to_vector(layout.getInDimNames()), dims); + auto dimSizes = newLayout.getOutDims(); + auto newDims = standardOutDimNames(ctx, rank - 1); + for (auto [i, newDim] : llvm::enumerate(newDims)) { + dimSizes[i].first = newDim; + } + return LinearLayout(newLayout.getBases(), dimSizes, /*isSurjective*/ false); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Tools/LinearLayout.cpp b/third_party/mthreads/lib/Tools/LinearLayout.cpp new file mode 100644 index 0000000000..11b4367072 --- /dev/null +++ b/third_party/mthreads/lib/Tools/LinearLayout.cpp @@ -0,0 +1,1407 @@ +#include "triton/Tools/LinearLayout.h" + +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "third_party/f2reduce/f2reduce.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +#define DEBUG_TYPE "linear_layout" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +static int __builtin_ctzll(unsigned long long x) { + unsigned long r; + _BitScanForward64(&r, x); + return static_cast(r); +} + +#endif + +namespace mlir::triton { + +namespace { +using BasesT = LinearLayout::BasesT; +using llvm::SmallDenseSet; +using llvm::Twine; + +BasesT makeBasesMap( + ArrayRef>>> bases) { + BasesT ret; + for (const auto &[inDim, inDimBases] : bases) { + ret[inDim] = inDimBases; + } + return ret; +} + +// Dump the matrix to stderr in a human-readable format for debugging. +void dumpMatrix(uint64_t *m, int numRows, int numCols) { + assert(numCols <= 64); + for (int r = 0; r < numRows; r++) { + llvm::errs() << "0b"; + for (int c = 0; c < numCols; c++) { + llvm::errs() << ((m[r] & (1 << c)) != 0 ? "1" : "0"); + } + llvm::errs() << "\n"; + } +} + +// Compute the rank of the matrix formed by taking the bases for the given +// outDim as columns. In other words, finds the number of linearly-independent +// bases for this output dimension. +int getMatrixRank(std::unique_ptr m, int numRows, int numCols) { + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); + f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1); + + // The rank of the reduced matrix is simply the number of nonzero rows. + int rank = 0; + for (int i = 0; i < numRows; i++) { + if (m[i] != 0) + rank++; + } + return rank; +} + +template +void assertDimsEqualIgnoringOrder(T &&a, U &&b) { + SmallDenseSet as(a.begin(), a.end()); + SmallDenseSet bs(b.begin(), b.end()); + if (as != bs) { + llvm::report_fatal_error("Dimensions must match, ignoring order, but they " + "don't. Got dims: [" + + Twine(triton::join(a, ", ")) + "] and [" + + triton::join(b, ", ") + "]"); + } +} + +template +void assertDimsSubsetIgnoringOrder(T &&small, U &&big) { + SmallDenseSet smallSet(small.begin(), small.end()); + SmallDenseSet bigSet(big.begin(), big.end()); + if (!llvm::set_is_subset(smallSet, bigSet)) { + llvm::report_fatal_error("Dimensions must be a subset, ignoring order, but " + "they aren't. Got dims: [" + + Twine(triton::join(small, ", ")) + "] and [" + + triton::join(big, ", ") + "]"); + } +} +} // anonymous namespace + +/*static*/ std::optional +LinearLayout::tryCreate(BasesT bases, + ArrayRef> outDims, + bool requireSurjective) { + LinearLayout ll(std::move(bases), std::move(outDims), NoCheckInvariants{}); + std::optional error = ll.checkInvariants(requireSurjective); + if (error) { + return std::nullopt; + } + return ll; +} + +LinearLayout::LinearLayout(BasesT bases, + ArrayRef> outDims, + NoCheckInvariants) + : bases(std::move(bases)) { + for (auto [outDim, size] : outDims) { + this->outDims[outDim] = size; + } +} + +LinearLayout::LinearLayout(BasesT bases, ArrayRef outDimNames) + : bases(std::move(bases)) { + // Infer out-dim sizes. + for (StringAttr outDim : outDimNames) { + outDims[outDim] = 1; + } + for (const auto &[inDim, inDimBases] : this->bases) { + for (const auto &basis : inDimBases) { + for (int i = 0; i < basis.size(); i++) { + int32_t &size = outDims[outDimNames[i]]; + size = std::max(size, llvm::NextPowerOf2(basis[i])); + } + } + } + + std::optional error = + checkInvariants(/*requireSurjective=*/true); + if (error.has_value()) { + llvm::report_fatal_error(StringRef(*error)); + } +} + +LinearLayout::LinearLayout(BasesT bases, + ArrayRef> outDims, + bool requireSurjective) + : LinearLayout(std::move(bases), std::move(outDims), NoCheckInvariants{}) { + std::optional error = checkInvariants(requireSurjective); + if (error.has_value()) { + llvm::report_fatal_error(StringRef(*error)); + } +} + +std::optional +LinearLayout::checkInvariants(bool requireSurjective) { + LDBG("checkInvariants: " << toString()); + // Check that basis values are non-negative. + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { + return "Invalid bases passed to LinearLayout. Expected all basis " + "values to be non-negative, but found a negative value for " + "in dimension '" + + inDim.str() + "'. Full list of bases:" + toString() + "\n"; + } + } + } + + // Check that the bases all have length equal to outDimNames.size(). + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (basis.size() != outDims.size()) { + return "Invalid bases passed to LinearLayout. Expect all bases to " + "have the same size, equal to outDimNames.size() (" + + std::to_string(outDims.size()) + + "). But this failed for in dimension '" + inDim.str() + + "'. Full list of bases:" + toString() + "\n"; + } + } + } + + // Check that the out-dim sizes are powers of 2. + for (const auto &[outDim, size] : outDims) { + if (!llvm::isPowerOf2_32(size)) { + return "Invalid out-dim size " + std::to_string(size) + " for out-dim '" + + outDim.str() + "'. Out-dim sizes must be powers of 2.\n"; + } + } + + // Check that the bases are smaller than the out-dim sizes. + SmallVector outDimNames = llvm::to_vector(getOutDimNames()); + for (const auto &[inDim, inDimBases] : this->bases) { + for (const auto &basis : inDimBases) { + for (int i = 0; i < basis.size(); i++) { + if (basis[i] >= outDims[outDimNames[i]]) { + return "Invalid basis " + std::to_string(basis[i]) + " for in-dim '" + + inDim.str() + "' and out-dim '" + outDimNames[i].str() + + "'. Basis must be less than the out-dim size.\n"; + } + } + } + } + + // Determine whether the this layout is surjective, i.e. that every `out` + // coordinate can be reached by some `in` coordinate. + // + // It's prohibitively slow to calculate this naively, but thankfully, this + // is equivalent to checking that the number of linearly-independent bases + // is equal to sum(getOutDimSizeLog2). This can be computed by finding + // the rank of the matrix whose columns are those bases. We can compute + // the rank of our matrix using Gaussian elimination, which runs in O(n^3) + // for an n x n matrix. Our matrix size is sum(inDimSizeLog2) x + // sum(outDimSizeLog2), so this should be plenty fast. + this->rank = + getMatrixRank(getMatrix(*this), /*numRows=*/getTotalOutDimSizeLog2(), + /*numCols=*/getTotalInDimSizeLog2()); + + if (requireSurjective && !isSurjective()) { + return "Layout is expected to be surjective, i.e. every `out` coordinate " + "can be reached by some `in` coordinate, but was not:" + + toString(); + } + + return std::nullopt; +} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames) + : LinearLayout(makeBasesMap(bases), outDimNames) {} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef> outDims, bool requireSurjective) + : LinearLayout(makeBasesMap(bases), outDims, requireSurjective) {} + +/*static*/ LinearLayout LinearLayout::strided1D(int32_t size, int32_t stride, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> bases; + for (int32_t i = 1; i < size; i *= 2) { + bases.emplace_back(std::vector{i * stride}); + } + bool requiresSurjective = (stride == 1); + return LinearLayout({{inDimName, std::move(bases)}}, + {{outDimName, stride * size}}, requiresSurjective); +} + +/*static*/ LinearLayout LinearLayout::zeros1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName, + int32_t outDimSize) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> zeros; + for (int i = 1; i < size; i *= 2) { + zeros.emplace_back(std::vector{0}); + } + return LinearLayout({{inDimName, zeros}}, {{outDimName, outDimSize}}, + /*requiresSurjective=*/outDimSize == 1); +} + +int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const { + int i = 0; + for (auto [name, _] : outDims) { + if (name == outDim) { + return i; + } + i++; + } + llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout" + + toString()); +} + +int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const { + auto it = bases.find(inDim); + assert(it != bases.end() && "inDim not found in layout"); + return it->second.size(); +} + +int32_t LinearLayout::getTotalInDimSizeLog2() const { + return std::accumulate(getInDimNames().begin(), getInDimNames().end(), 0, + [&](int32_t acc, StringAttr inDim) { + return acc + getInDimSizeLog2(inDim); + }); +} + +int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const { + auto it = outDims.find(outDim); + assert(it != outDims.end() && "outDim not found in layout"); + return llvm::Log2_32(it->second); +} + +int32_t LinearLayout::getTotalOutDimSizeLog2() const { + return std::accumulate(getOutDimNames().begin(), getOutDimNames().end(), 0, + [&](int32_t acc, StringAttr outDim) { + return acc + getOutDimSizeLog2(outDim); + }); +} + +int32_t LinearLayout::getNumConsecutiveInOut() const { + if (bases.empty() || getNumOutDims() == 0) + return 1; + + // Count how many of the initial bases for the first in-dim are + // (2^i, 0, ..., 0). + const auto &firstInDimBases = bases.begin()->second; + int consec = 0; + for (; consec < firstInDimBases.size(); consec++) { + const auto &basis = firstInDimBases[consec]; + if (basis[0] != (1 << consec) || + !std::all_of(basis.begin() + 1, basis.end(), + [](int32_t x) { return x == 0; })) { + break; + } + } + + // `or` together all other bases' first out-dim. + int32_t otherBits = 0; + for (const auto &[inDim, inDimBases] : bases) { + for (int i = 0; i < inDimBases.size(); i++) { + if (inDim != bases.begin()->first || i >= consec) { + otherBits |= inDimBases[i][0]; + } + } + } + int32_t trailingZeros = otherBits != 0 ? __builtin_ctz(otherBits) : 31; + + return 1 << std::min(consec, trailingZeros); +} + +LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { + assertDimsEqualIgnoringOrder(newInDims, getInDimNames()); + + BasesT newBases; + for (const auto &inDim : newInDims) { + newBases[inDim] = bases.find(inDim)->second; + } + return LinearLayout(std::move(newBases), llvm::to_vector(outDims), + isSurjective()); +} + +LinearLayout +LinearLayout::transposeOuts(ArrayRef newOutDims) const { + assertDimsEqualIgnoringOrder(newOutDims, getOutDimNames()); + + std::vector permutation; + for (const auto &outDim : newOutDims) { + permutation.push_back(getOutDimIndex(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + std::vector newBasis; + for (int32_t i : permutation) { + newBasis.push_back(basis[i]); + } + newInDimBases.push_back(std::move(newBasis)); + } + } + + SmallVector> newOutDimSizes; + for (auto outDim : newOutDims) { + newOutDimSizes.push_back({outDim, getOutDimSize(outDim)}); + } + return LinearLayout(std::move(newBases), newOutDimSizes, isSurjective()); +} + +LinearLayout LinearLayout::reshapeIns( + ArrayRef> newInDims) const { + assert(llvm::all_of(newInDims, [&](auto &inDim) { + return llvm::isPowerOf2_32(inDim.second); + })); + assert(getTotalInDimSize() == std::accumulate(newInDims.begin(), + newInDims.end(), 1, + [&](int32_t acc, auto &inDim) { + return acc * inDim.second; + })); + + // First flatten into a single in-dimension. Then split it up according + // to `newInDims`. + SmallVector> flatBases; + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + flatBases.push_back(basis); + } + } + + BasesT newBases; + int i = 0; + for (const auto &[inDim, inDimSize] : newInDims) { + auto &newInDimBases = newBases[inDim]; + for (int j = 1; j < inDimSize; j *= 2) { + newInDimBases.push_back(flatBases[i++]); + } + } + return LinearLayout(std::move(newBases), llvm::to_vector(outDims), + isSurjective()); +} + +LinearLayout LinearLayout::reshapeOuts( + ArrayRef> newOutDims) const { + assert(llvm::all_of(newOutDims, [&](auto &outDim) { + return llvm::isPowerOf2_32(outDim.second); + })); + assert(getTotalOutDimSize() == + std::accumulate( + newOutDims.begin(), newOutDims.end(), 1, + [&](int32_t acc, auto &outDim) { return acc * outDim.second; })); + + SmallVector shifts; + shifts.push_back(0); + for (StringAttr outDim : getOutDimNames()) { + shifts.push_back(shifts.back() + getOutDimSizeLog2(outDim)); + } + + // Flatten into a single out-dimension. Then split it up according to + // `newOutDims`. + llvm::MapVector> flatBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &flatInBases = flatBases[inDim]; + for (const auto &basis : inDimBases) { + int b = 0; + for (int i = 0; i < basis.size(); i++) { + b += basis[i] << shifts[i]; + } + flatInBases.push_back(b); + } + } + + BasesT newBases; + for (const auto &[inDim, flatInBases] : flatBases) { + std::vector> &newInDimBases = newBases[inDim]; + for (int32_t b : flatInBases) { + std::vector multiDimBasis; + for (int32_t newSize : llvm::make_second_range(newOutDims)) { + multiDimBasis.push_back(b % newSize); + b /= newSize; + } + newInDimBases.push_back(std::move(multiDimBasis)); + } + } + + return LinearLayout(std::move(newBases), newOutDims, isSurjective()); +} + +LinearLayout LinearLayout::resizeInDim(StringAttr inDim, + int32_t newSize) const { + assert(llvm::isPowerOf2_32(newSize)); + assert(newSize <= getInDimSize(inDim)); + auto newBases = bases; + newBases[inDim].resize(llvm::Log2_32(newSize)); + return LinearLayout(std::move(newBases), getOutDims(), + /*requiresSurjective=*/false); +} + +LinearLayout LinearLayout::resizeOutDim(StringAttr outDim, + int32_t newSize) const { + assert(llvm::isPowerOf2_32(newSize)); + assert(newSize <= getOutDimSize(outDim)); + auto newBases = bases; + // Zero-out the basis vectors that are greater than or equal to the new size + for (auto &[inDim, inDimBases] : newBases) { + for (auto &basis : inDimBases) { + auto &b = basis[getOutDimIndex(outDim)]; + if (b >= newSize) { + b = 0; + } + } + } + auto outDims = getOutDims(); + for (auto &[outDim, outDimSize] : outDims) { + if (outDim == outDim) { + outDimSize = newSize; + } + } + return LinearLayout(std::move(newBases), outDims, + /*requiresSurjective=*/false); +} + +LinearLayout LinearLayout::concatIns(const LinearLayout &other) const { + assert(llvm::to_vector(getOutDimNames()) == + llvm::to_vector(other.getOutDimNames()) && + "layouts must have the same output dimensions"); + for (StringAttr outDim : getOutDimNames()) { + assert(getOutDimSize(outDim) == other.getOutDimSize(outDim) && + "layouts must have the same output dimension sizes"); + } + + LinearLayout::BasesT resultBases = getBases(); + for (auto &bases : other.getBases()) + resultBases.insert(bases); + SmallVector> newOutDims; + for (auto &[outDim, outDimSize] : outDims) + newOutDims.emplace_back(outDim, outDimSize); + return LinearLayout(std::move(resultBases), newOutDims, + /*requiresSurjective=*/false); +} + +LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const { + assert(llvm::to_vector(getInDimNames()) == + llvm::to_vector(other.getInDimNames()) && + "layouts must have the same input dimensions"); + for (StringAttr inDim : getInDimNames()) { + assert(getInDimSize(inDim) == other.getInDimSize(inDim) && + "layouts must have the same input dimension sizes"); + } + + LinearLayout::BasesT result; + for (auto [lhsBases, rhsBases] : llvm::zip(getBases(), other.getBases())) { + auto &resultBases = result[lhsBases.first]; + assert(lhsBases.first == rhsBases.first); + for (auto [lhsBasis, rhsBasis] : + llvm::zip(lhsBases.second, rhsBases.second)) { + std::vector resultBasis; + llvm::append_range(resultBasis, lhsBasis); + llvm::append_range(resultBasis, rhsBasis); + resultBases.push_back(std::move(resultBasis)); + } + } + SmallVector> newOutDims; + for (auto &[outDim, outDimSize] : outDims) + newOutDims.emplace_back(outDim, outDimSize); + for (auto &[outDim, outDimSize] : other.outDims) + newOutDims.emplace_back(outDim, outDimSize); + return LinearLayout(std::move(result), newOutDims, + /*requiresSurjective=*/false); +} + +std::optional divideLeft(const LinearLayout &A, + const LinearLayout &B) { + // Compute a C such that A = B * C if it exists. + // Note that such a C exists iff (every pair of input/output dim of) A is of + // the form + // [[B, 0], + // [0, C]] + // as a matrix, whenever those dimensions are present in B. + for (StringAttr dim : B.getInDimNames()) { + if (!llvm::is_contained(A.getInDimNames(), dim)) + return std::nullopt; + } + for (StringAttr dim : B.getOutDimNames()) { + if (!llvm::is_contained(A.getOutDimNames(), dim)) + return std::nullopt; + } + // Compute candidate C's log-sizes for output dimensions. + llvm::MapVector cOutDimSizes; + for (StringAttr outDim : A.getOutDimNames()) { + int outA = A.getOutDimSizeLog2(outDim); + int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0; + int outC = outA - outB; + if (outC < 0) + return std::nullopt; + cOutDimSizes[outDim] = 1 << outC; + } + + LinearLayout::BasesT cBases; + for (StringAttr inDim : A.getInDimNames()) { + int inA = A.getInDimSizeLog2(inDim); + int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0; + int inC = inA - inB; + if (inC < 0) + return std::nullopt; + + std::vector> basesForDim; + // Check that A’s first inB entries agree with B. + for (int i = 0; i < inB; ++i) { + for (StringAttr outDim : A.getOutDimNames()) { + int expected = B.hasOutDim(outDim) ? B.getBasis(inDim, i, outDim) : 0; + int actual = A.getBasis(inDim, i, outDim); + if (actual != expected) + return std::nullopt; + } + } + + // Extract the candidate C bases from the remaining (shifted) entries in A. + for (int i = inB; i < inA; ++i) { + std::vector candidateBasis; + for (StringAttr outDim : llvm::make_first_range(cOutDimSizes)) { + int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0; + int v = A.getBasis(inDim, i, outDim); + + // The lower outB bits must be zero. + if ((v & ((1 << outB) - 1)) != 0) + return std::nullopt; + candidateBasis.push_back(v >> outB); + } + basesForDim.push_back(std::move(candidateBasis)); + } + cBases[inDim] = basesForDim; + } + + SmallVector> COutDims; + for (auto [outDim, outC] : cOutDimSizes) { + COutDims.push_back({outDim, outC}); + } + // If the layout A and B are surjective, then C should also be surjective. + LinearLayout C(std::move(cBases), COutDims, + /*requireSurjective=*/A.isSurjective() && B.isSurjective()); + assert(B * C == A); + return C; +} + +std::optional divideRight(const LinearLayout &A, + const LinearLayout &B) { + // Compute a C such that A = C * B if it exists. + // Note that such a C exists iff (every pair of input/output dim of) A is of + // the form + // [[C, 0], + // [0, B]] + // as a matrix, whenever those dimensions are present in B. + + // Check that B's in-dimensions and out-dimensions are contained in A. + for (StringAttr dim : B.getInDimNames()) { + if (!llvm::is_contained(A.getInDimNames(), dim)) + return std::nullopt; + } + for (StringAttr dim : B.getOutDimNames()) { + if (!llvm::is_contained(A.getOutDimNames(), dim)) + return std::nullopt; + } + + // Compute candidate C's log-sizes for output dimensions. + llvm::MapVector cOutDimSizes; + for (StringAttr outDim : A.getOutDimNames()) { + int outA = A.getOutDimSizeLog2(outDim); + int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0; + int outC = outA - outB; + if (outC < 0) + return std::nullopt; + cOutDimSizes[outDim] = 1 << outC; + } + + // For candidate C, its in-dim sizes come from subtracting B's in-dim sizes + // from A's. + LinearLayout::BasesT cBases; + for (StringAttr inDim : A.getInDimNames()) { + int inA = A.getInDimSizeLog2(inDim); + int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0; + int inC = inA - inB; + if (inC < 0) + return std::nullopt; + + std::vector> basesForDim; + // The first inC basis vectors come directly from C. + for (int i = 0; i < inC; ++i) { + std::vector candidate; + for (StringAttr outDim : llvm::make_first_range(cOutDimSizes)) { + candidate.push_back(A.getBasis(inDim, i, outDim)); + } + basesForDim.push_back(std::move(candidate)); + } + + // The remaining inB basis vectors in A should correspond to B after being + // shifted. + for (int i = inC; i < inA; ++i) { + int j = i - inC; // Index into B's basis vectors for this inDim. + for (StringAttr outDim : B.getOutDimNames()) { + int outA = A.getOutDimSizeLog2(outDim); + int outB = B.getOutDimSizeLog2(outDim); + int outC = outA - outB; // Expected log2 size for C in this output. + int shift = outC; + int v = A.getBasis(inDim, i, outDim); + // The lower shift bits must be zero. + if ((v & ((1 << shift) - 1)) != 0) + return std::nullopt; + int recovered = v >> shift; + int expected = B.getBasis(inDim, j, outDim); + if (recovered != expected) + return std::nullopt; + } + } + cBases[inDim] = basesForDim; + } + + SmallVector> COutDims; + for (auto [outDim, size] : cOutDimSizes) + COutDims.push_back({outDim, size}); + // If A and B are surjective, then C should also be surjective. + LinearLayout C(std::move(cBases), COutDims, + /*requireSurjective=*/A.isSurjective() && B.isSurjective()); + assert(C * B == A); + return C; +} + +LinearLayout operator*(LinearLayout inner, LinearLayout outer) { + // Check that dims common to outer and inner have the same relative order. + auto inDims = supremum(llvm::to_vector(inner.getInDimNames()), + llvm::to_vector(outer.getInDimNames())); + auto outDims = supremum(llvm::to_vector(inner.getOutDimNames()), + llvm::to_vector(outer.getOutDimNames())); + + // Get the sizeLog2 of all input and output dimensions we're going to + // consider, in order. `inner` is more minor, so its dimensions come + // first. + llvm::MapVector inDimSizesLog2; + llvm::MapVector outDimSizesLog2; + for (const auto &dim : inDims) + inDimSizesLog2.insert({dim, 0}); + for (const auto &dim : outDims) + outDimSizesLog2.insert({dim, 0}); + for (const auto &layout : {inner, outer}) { + for (StringAttr inDim : layout.getInDimNames()) { + inDimSizesLog2[inDim] += layout.getInDimSizeLog2(inDim); + } + for (StringAttr outDim : layout.getOutDimNames()) { + outDimSizesLog2[outDim] += layout.getOutDimSizeLog2(outDim); + } + } + + BasesT allBases; + for (auto [inDimName, inDimSizeLog2] : inDimSizesLog2) { + std::vector> &inDimBases = allBases[inDimName]; + + // Fill with zeros. + inDimBases = std::vector>( + inDimSizeLog2, std::vector(outDimSizesLog2.size(), 0)); + + for (auto [outDimIdx, outDimNameAndSize] : + llvm::enumerate(outDimSizesLog2)) { + auto [outDimName, outDimSize] = outDimNameAndSize; + if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) { + for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) { + inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName); + } + } + if (outer.hasInDim(inDimName) && outer.hasOutDim(outDimName)) { + int offset = + inner.hasInDim(inDimName) ? inner.getInDimSizeLog2(inDimName) : 0; + int shift = inner.hasOutDim(outDimName) + ? inner.getOutDimSizeLog2(outDimName) + : 0; + for (int i = 0; i < outer.getInDimSizeLog2(inDimName); i++) { + inDimBases[offset + i][outDimIdx] = + outer.getBasis(inDimName, i, outDimName) << shift; + } + } + } + } + + llvm::SmallVector> outDimSizes; + for (auto [outDim, sizeLog2] : outDimSizesLog2) { + outDimSizes.push_back({outDim, 1 << sizeLog2}); + } + return LinearLayout(std::move(allBases), outDimSizes, + inner.isSurjective() && outer.isSurjective()); +} + +bool LinearLayout::isTrivialOver(ArrayRef dimNames) const { + for (StringAttr dim : dimNames) { + if (!llvm::is_contained(getInDimNames(), dim) && + !llvm::is_contained(getOutDimNames(), dim)) { + return false; + } + } + + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); + } + } + return remainingDimNames; + }; + SmallVector remainingInDimNames = + getRemainingDimNames(getInDimNames()); + SmallVector remainingOutDimNames = + getRemainingDimNames(getOutDimNames()); + + // Think of this as a block-matrix multiplying a vector: + // [[A, B], * [v_1, + // [C, D]] v_2] + // where v_2 is the dimNames and v_1 is the remainingInDimNames + // We can quotient out dimNames iff they don't affect the remainingInDimNames + // in the result. In other words, we want to check that B is zero, and C is + // zero, and D is the identity + return squareSublayoutIsIdentity(*this, dimNames) && + sublayoutIsZero(remainingInDimNames, dimNames) && + sublayoutIsZero(dimNames, remainingOutDimNames); +} + +std::optional +LinearLayout::quotient(ArrayRef dimNames) const { + if (!isTrivialOver(dimNames)) { + return std::nullopt; + } + + // This should probably be even less general, where we ask inDimNames == + // outDimNames + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); + } + } + return remainingDimNames; + }; + + SmallVector inDimNames = getRemainingDimNames(getInDimNames()); + SmallVector outDimNames = getRemainingDimNames(getOutDimNames()); + + return sublayout(inDimNames, outDimNames); +} + +LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, + ArrayRef outDimNames) const { + assertDimsSubsetIgnoringOrder(inDimNames, getInDimNames()); + assertDimsSubsetIgnoringOrder(outDimNames, getOutDimNames()); + SmallDenseSet inDimSet(inDimNames.begin(), inDimNames.end()); + SmallDenseSet outDimSet(outDimNames.begin(), outDimNames.end()); + + SmallVector outDimIndicesToKeep; + for (auto [i, outDim] : llvm::enumerate(getOutDimNames())) { + if (outDimSet.contains(outDim)) { + outDimIndicesToKeep.push_back(i); + } + } + BasesT newBases; + for (auto [inDim, inDimBases] : bases) { + if (!inDimSet.contains(inDim)) { + continue; + } + auto &newInDimBases = newBases[inDim]; + for (auto &basis : inDimBases) { + auto &newBasis = newInDimBases.emplace_back(); + for (int i : outDimIndicesToKeep) { + newBasis.push_back(basis[i]); + } + } + } + + SmallVector> newOutDims; + for (auto [outDim, outDimSize] : outDims) { + if (outDimSet.contains(outDim)) { + newOutDims.push_back({outDim, outDimSize}); + } + } + return LinearLayout(std::move(newBases), std::move(newOutDims), + /*requireSurjective=*/false); +} + +bool LinearLayout::sublayoutIsZero(ArrayRef inDimNames, + ArrayRef outDimNames) const { + LinearLayout ss = sublayout(inDimNames, outDimNames); + for (auto [inDim, inDimBases] : ss.bases) { + for (auto basis : inDimBases) { + if (!llvm::all_of(basis, [](int32_t b) { return b == 0; })) { + return false; + } + } + } + return true; +} + +SmallVector> +LinearLayout::apply(ArrayRef> ins) const { + assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames()); + + SmallVector> ret; + for (StringAttr outDim : getOutDimNames()) { + int32_t outVal = 0; + for (auto &[inDim, val] : ins) { + for (int i = 0; i < getInDimSizeLog2(inDim); i++) { + if (val & (1 << i)) + outVal ^= getBasis(inDim, i, outDim); + } + } + ret.push_back({outDim, outVal}); + } + return ret; +} + +LinearLayout LinearLayout::compose(const LinearLayout &outer) const { + assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames()); + for (StringAttr outDim : getOutDimNames()) { + assert(getOutDimSize(outDim) <= outer.getInDimSize(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + SmallVector> bases; + for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) { + bases.push_back({outDim, b}); + } + auto newBases = outer.apply(bases); + auto newBasesRange = llvm::make_second_range(newBases); + newInDimBases.push_back( + std::vector(newBasesRange.begin(), newBasesRange.end())); + } + } + + bool compositionIsSurjective = + isSurjective() && outer.isSurjective() && + llvm::all_of(getOutDimNames(), [&](StringAttr outDim) { + return getOutDimSize(outDim) == outer.getInDimSize(outDim); + }); + return LinearLayout(std::move(newBases), llvm::to_vector(outer.outDims), + compositionIsSurjective); +} + +namespace { +std::unique_ptr concatMatrices(const LinearLayout &A, + const LinearLayout &B) { + // conv + assert(A.getTotalOutDimSizeLog2() >= B.getTotalOutDimSizeLog2() && + "A must have at least as many output bits as B"); + int numColsA = A.getTotalInDimSizeLog2(); + + // rref expects the lower bits to be the lower indices of the matrix + auto concat = getMatrix(A); + auto BMat = getMatrix(B); + int rowA = 0; + int rowB = 0; + for (auto [outDim, outDimSize] : A.getOutDims()) { + for (int r = 0; r < llvm::Log2_32(outDimSize); r++) { + if (r < llvm::Log2_32(B.getOutDimSize(outDim))) { + concat[rowA] |= BMat[rowB] << numColsA; + rowB++; + } + rowA++; + } + } + return concat; +} + +LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) { + // Solve the least square system AX = B + // and return the least square solution X by computing RREF and setting + // the free variables to zero. + // A and B may not be surjective, but we assume that Im(B) \subset Im(A) + // Sketch of the algorithm: + // https://github.com/triton-lang/triton/pull/5309#discussion_r1869084111 + int numRows = A.getTotalOutDimSizeLog2(); + assert(numRows >= B.getTotalOutDimSizeLog2() && + "A.lstsq(B) called with incompatible output shapes"); + int numColsA = A.getTotalInDimSizeLog2(); + int numColsB = B.getTotalInDimSizeLog2(); + int numCols = numColsA + numColsB; + std::unique_ptr combinedMat = concatMatrices(A, B); + f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols, + /*stride=*/1); + + // Compute the pivot columns + // Since A and B have the same image, each row will either have a pivot + // or will be all zeros + SmallVector pivotRowOfCol(numColsA, -1); + for (int r = 0; r < numRows; r++) { + auto row = combinedMat[r]; + if (row == 0) { + continue; + } + int c = __builtin_ctzll(row); + assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)"); + assert(pivotRowOfCol[c] == -1 && + "duplicate pivot => matrix not in RREF or A not injective"); + pivotRowOfCol[c] = r; + } + + // Extract A^{-1}B and complete the matrix using zeros + std::unique_ptr retMat(new uint64_t[numColsA]()); + for (int c = 0; c < numColsA; ++c) { + int row = pivotRowOfCol[c]; + retMat[c] = (row == -1) ? 0 : (combinedMat[row] >> numColsA); + } + + // We need names for the in/out dim of the flattened layout we're going to + // read off from `m`. These could be anything, doesn't matter. + assert(!A.getInDimNames().empty() && + "attempt to solve lstsq for empty layout"); + StringAttr inDim1D = *A.getInDimNames().begin(); + StringAttr outDim1D = *A.getOutDimNames().begin(); + + // Read off the new bases. These are for a flattened 1D -> 1D + LinearLayout::BasesT retBases; + auto &bs = retBases[inDim1D]; + for (int c = 0; c < numColsB; c++) { + int32_t basis = 0; + for (int r = 0; r < numColsA; r++) { + basis |= (retMat[r] >> c & 1) << r; + } + bs.push_back({basis}); + } + + LinearLayout retFlattened(std::move(retBases), + {{outDim1D, A.getTotalInDimSize()}}, + /*requireSurjective=*/false); + + SmallVector> retInDims; + SmallVector> retOutDims; + for (StringAttr dim : B.getInDimNames()) { + retInDims.push_back({dim, B.getInDimSize(dim)}); + } + for (StringAttr dim : A.getInDimNames()) { + retOutDims.push_back({dim, A.getInDimSize(dim)}); + } + return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims); +} + +} // namespace + +LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { + // TODO(Lezcano) Make friend and perhaps rename to `convertFrom` or `lstsq` + // For this, we need to implement our LLVM lowerings by inverting the "outer" + // layout, and then iterating over the elements from the "this" layout and + // fetching the corresponding element from the "outer" layout. This exercises + // the broadcasting that we incentivise via choosing the minimum norm solution + // in lstsq. + + // The order of dims does not matter. We choose to transpose outer + auto outDims = llvm::to_vector(getOutDimNames()); + assertDimsEqualIgnoringOrder(outDims, outer.getOutDimNames()); + const auto &B = *this; + const auto A = outer.transposeOuts(outDims); + for (auto dim : outDims) { + assert(A.getOutDimSize(dim) >= B.getOutDimSize(dim) && + ("A.invertAndCompose(B) called with incompatible output shapes in " + + dim.str() + ": " + std::to_string(A.getOutDimSize(dim)) + + " >= " + std::to_string(B.getOutDimSize(dim))) + .c_str()); + } + + // Broadcasting heuristic + // Imagine we have two layouts with `warps = [[0, 0],  [0, 0]]` + // (broadcasting) on both layouts. We could map any warp to any warp in the + // conversion. Now, we want to map them as the identity map, to mark that + // nothing needs to be done there (`lstsq` would map all the warps to the + // zero warp, minimum norm solution). The heuristic here is as follows: + // - If a dimension is the same for both layouts, we want to map it as the + // identity + // Equivalently, we don't add it to the conversion + // - Otherwise, we just call lstsq (i.e. map all the equivalent elements + // to the same input element) to take advantage of broadcasting in shared + // memory and avoid saving repeated elements in shared memory + + // FIXME: We should check that the other dimensions don't touch the image of + // this dimension. + SmallVector identityDims; + for (auto dim : A.getInDimNames()) { + if (B.hasInDim(dim) && + A.sublayout(dim, outDims) == B.sublayout(dim, outDims)) { + identityDims.push_back(dim); + } + } + SmallVector ANonIdentityInDims; + SmallVector BNonIdentityInDims; + for (auto dim : A.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + ANonIdentityInDims.push_back(dim); + } + } + for (auto dim : B.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + BNonIdentityInDims.push_back(dim); + } + } + + auto AReduced = A.sublayout(ANonIdentityInDims, outDims); + auto BReduced = B.sublayout(BNonIdentityInDims, outDims); + + // If one is empty, the other must be empty as well + assert((ANonIdentityInDims.empty()) == (BNonIdentityInDims.empty())); + bool isEmpty = ANonIdentityInDims.empty(); + + auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced); + + // TODO(Lezcano): We should return the reduced layout instead of re-adding the + // identity maps. With this, we'll be able to kill `minimalCvtLayout` + + // Add the identity maps for the dimensions that are the same for both layouts + for (auto dim : identityDims) { + ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim); + } + + // Reorder the dimensions in the result to match the order expected by the + // current and outer layouts. + return ret.transposeIns(llvm::to_vector(B.getInDimNames())) + .transposeOuts(llvm::to_vector(A.getInDimNames())); +} + +LinearLayout LinearLayout::invert() const { + assert(isInvertible() && + "A linear layout must be surjective and square to be invertible"); + return pseudoinvert(); +} + +LinearLayout LinearLayout::pseudoinvert() const { + LinearLayout identity = LinearLayout::empty(); + for (auto outDim : getOutDimNames()) { + identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim); + } + return identity.invertAndCompose(*this); +} + +LinearLayout LinearLayout::unsqueezeIn(StringAttr dim) const { + assert(getInDimSize(dim) == 1); + SmallVector> newInDims; + for (auto inDim : getInDimNames()) { + if (inDim != dim) { + newInDims.push_back({inDim, getInDimSize(inDim)}); + } + } + return reshapeIns(newInDims); +} + +LinearLayout LinearLayout::unsqueezeOut(StringAttr dim) const { + assert(getOutDimSize(dim) == 1); + SmallVector> newOutDims; + for (auto [outDim, outDimSize] : getOutDims()) { + if (outDim != dim) { + newOutDims.push_back({outDim, outDimSize}); + } + } + return LinearLayout(bases, newOutDims, isSurjective()); +} + +llvm::MapVector +LinearLayout::getFreeVariableMasks() const { + std::unique_ptr mat = getMatrix(*this); + int numRows = getTotalOutDimSizeLog2(); + int numCols = getTotalInDimSizeLog2(); + + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); + f2reduce::inplace_rref_strided(mat.get(), numRows, numCols, /*stride=*/1); + + // For each row in the RREF matrix, identify the column with the first "1". + // These columns correspond to the basic (i.e. non-free) variables. + std::set basicVars; + for (int r = 0; r < numRows; r++) { + if (mat[r] == 0) { + continue; + } + basicVars.insert(__builtin_ctzll(mat[r])); + } + + llvm::MapVector ret; + int c = 0; + for (StringAttr dim : getInDimNames()) { + int32_t mask = 0; + for (int i = 0; i < getInDimSizeLog2(dim); i++, c++) { + if (basicVars.count(c) == 0) { + mask |= (1 << i); + } + } + ret[dim] = mask; + } + return ret; +} + +LinearLayout LinearLayout::removeZeroBasesAlongDim(StringAttr stripDim) const { + LinearLayout::BasesT result; + for (auto &[inDim, inDimBases] : getBases()) { + auto &newInDimBases = result[inDim]; + if (inDim != stripDim) { + newInDimBases = inDimBases; + continue; + } + for (auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) { + newInDimBases.push_back(basis); + } + } + } + SmallVector> newOutDimSizes; + for (auto outDim : getOutDimNames()) { + newOutDimSizes.push_back({outDim, getOutDimSize(outDim)}); + } + auto newLayout = LinearLayout(std::move(result), ArrayRef(newOutDimSizes), + this->isSurjective()); + return newLayout; +} + +size_t hash_value(const LinearLayout &layout) { + size_t seed = 0; + + // Hash the bases + for (const auto &base : layout.getBases()) { + // Hash the input dimension name + seed = llvm::hash_combine(seed, base.first); + + // Hash the vectors in bases + for (const auto &vec : base.second) { + for (int32_t val : vec) { + seed = llvm::hash_combine(seed, val); + } + } + } + + // Hash the output dimensions and their sizes + for (const auto &outDim : layout.getOutDimNames()) { + seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim)); + } + // Don't hash the surjective flag as it's a cached property + return seed; +} + +bool operator==(const LinearLayout &lhs, const LinearLayout &rhs) { + if (!lhs.equalIgnoringOutDimSizes(rhs)) + return false; + + for (const auto &[lhsOutDimAndSize, rhsOutDimAndSize] : + llvm::zip(lhs.outDims, rhs.outDims)) { + if (lhsOutDimAndSize.second != rhsOutDimAndSize.second) + return false; + } + return true; +} + +bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { + // llvm::MapVector doesn't have an operator== :(. + if (llvm::to_vector(this->getOutDimNames()) != + llvm::to_vector(other.getOutDimNames())) + return false; + if (this->bases.size() != other.bases.size()) + return false; + for (auto it1 = this->bases.begin(), it2 = other.bases.begin(); + it1 != this->bases.end(); ++it1, ++it2) { + if (*it1 != *it2) + return false; + } + return true; +} + +std::string LinearLayout::toString() const { + // Start with a newline because we print out a bulleted list; it doesn't + // make sense for the first line of this list to be on the same line as + // any previous text. + std::string ret = "\n"; + std::string outDimsStr = + "[" + + join(outDims, ", ", + [](auto dimAndSize) { + auto [outDim, size] = dimAndSize; + return outDim.str() + " (size " + std::to_string(size) + ")"; + }) + + "]"; + + if (bases.empty()) { + if (outDims.empty()) { + return "\n(empty layout)"; + } else { + return "\n(empty layout with out-dims " + outDimsStr + ")"; + } + } + + // TODO: Add spaces for alignment. + for (const auto &[inDim, inDimBases] : bases) { + if (inDimBases.empty()) { + ret += " - " + inDim.str() + " is a size 1 dimension\n"; + continue; + } + + ret += " - " + + join(llvm::seq(inDimBases.size()), "\n ", + [&, &inDim = inDim, &inDimBases = inDimBases](int i) { + return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + + join(inDimBases[i], ", ") + ")"; + }) + + "\n"; + } + ret += "where out dims are: " + outDimsStr; + return ret; +} + +LinearLayout ColumnAction::apply(const LinearLayout &layout) const { + assert(layout.hasInDim(inDim)); + assert(layout.getInDimSizeLog2(inDim) == inSizeLog2 && + "Layout has a different size than the ColumnAction"); + if (m_isIdentity) { + return layout; + } + + auto bases = layout.getBases(); + const auto &basesInDim = bases[inDim]; + std::vector> newBases; + newBases.reserve(action.size()); + for (size_t a : action) { + newBases.push_back(basesInDim[a]); + } + bases[inDim] = std::move(newBases); + + SmallVector> outDims; + for (auto outDim : layout.getOutDimNames()) { + outDims.emplace_back(outDim, layout.getOutDimSize(outDim)); + } + return LinearLayout(std::move(bases), std::move(outDims), + /*requireSurjective=*/false); +} + +SmallVector ColumnAction::apply(ValueRange values) const { + assert(values.size() == (1 << inSizeLog2) && + "Values have a different size than the ColumnAction"); + assert(inDim.str() == "register" && "Values are in registers, so we can only " + "apply ColumnAction to registers"); + if (m_isIdentity) { + return values; + } + auto permLL = apply(LinearLayout::identity1D(values.size(), inDim, inDim)); + SmallVector ret; + ret.reserve(permLL.getInDimSize(inDim)); + for (int i = 0; i < permLL.getInDimSize(inDim); i++) { + int32_t srcIdx = permLL.apply({{inDim, i}}).begin()->second; + ret.push_back(values[srcIdx]); + } + return ret; +} + +ColumnAction ColumnAction::leftCompose(const ColumnAction &other) const { + assert(inDim == other.inDim); + assert(inSizeLog2 == other.inSizeLog2); + assert(action.size() == other.action.size()); + auto newAction = SmallVector(action.size()); + for (size_t i = 0; i < action.size(); i++) { + newAction[i] = action[other.action[i]]; + } + return ColumnAction(newAction, inDim, inSizeLog2); +} + +ColumnAction ColumnAction::inverse() const { + auto invPerm = SmallVector(action.size()); + for (size_t i = 0; i < action.size(); i++) { + invPerm[action[i]] = i; + } + return ColumnAction(invPerm, inDim, inSizeLog2); +} + +std::string ColumnAction::toString() const { + std::string ret = "ColumnAction(["; + ret += join(action, ", "); + ret += "], " + inDim.str() + ", " + std::to_string(inSizeLog2) + ")"; + return ret; +} + +// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing +// the bases of the given layout. This can then be used by f2reduce. +// +// This function is called from the constructor of LinearLayout, so be careful +// not to use any functions that create LLs in here. +std::unique_ptr getMatrix(const LinearLayout &layout) { + int numRows = layout.getTotalOutDimSizeLog2(); + int numCols = layout.getTotalInDimSizeLog2(); + + // Don't handle giant LLs. This makes some things easier; for example, each + // row can be a single uint64_t. + assert(numCols <= 64 && "LinearLayout too large"); + assert(numRows <= 64 && "LinearLayout too large"); + + // Suppose we have a layout specified by the following values. + // + // L(0,1) = (0b01, 0b1) + // L(0,2) = (0b10, 0b0) + // L(1,0) = (0b10, 0b0) + // L(2,0) = (0b11, 0b0) + // + // We will create one column per entry above. The max bit width of the + // codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix + // will be + // + // | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 | + // | ↓ ↓ ↓ ↓ | | 0b0111 | + // | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 | + // | ↓ ↓ ↓ ↓ | + // + // Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not. + std::unique_ptr m(new uint64_t[numRows]()); + int r = 0; + for (StringAttr outDim : layout.getOutDimNames()) { + int c = 0; + for (StringAttr inDim : layout.getInDimNames()) { + for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { + uint64_t basis = layout.getBasis(inDim, i, outDim); + for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) { + m[r + j] |= ((basis >> j) & 1) << c; + } + c++; + } + } + r += layout.getOutDimSizeLog2(outDim); + } + + return m; +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Tools/PluginUtils.cpp b/third_party/mthreads/lib/Tools/PluginUtils.cpp new file mode 100644 index 0000000000..6434990057 --- /dev/null +++ b/third_party/mthreads/lib/Tools/PluginUtils.cpp @@ -0,0 +1,162 @@ +#include "triton/Tools/PluginUtils.h" + +llvm::Error TritonPlugin::checkLibraryValid(const std::string &error) const { + if (!library.isValid()) { + auto msg = llvm::Twine("Failed to load plugin library: " + error + "\n"); + return llvm::createStringError(msg); + } + return llvm::Error::success(); +} + +llvm::Expected +TritonPlugin::getAddressOfSymbol(const std::string &symbol) const { + if (auto isValid = checkLibraryValid("not loaded")) + return isValid; + intptr_t getDetailsFn = (intptr_t)library.getAddressOfSymbol(symbol.c_str()); + if (!getDetailsFn) { + auto msg = llvm::Twine("Failed to get symbol: " + symbol + "\n"); + return llvm::createStringError(msg); + } + return getDetailsFn; +} + +llvm::Expected +TritonPlugin::checkAPIResult(TritonPluginResult result, + const char *handle) const { + if (result == TP_SUCCESS) + return TP_SUCCESS; + std::string msg; + llvm::raw_string_ostream os(msg); + os << "Failed to add/register plugin pass (" << handle + << ") to pass manager, error code: " << result; + return llvm::createStringError(msg); +} + +std::runtime_error TritonPlugin::err2exp(llvm::Error Err) { + std::string msg; + llvm::raw_string_ostream os(msg); + os << Err; + return std::runtime_error(msg); +} + +llvm::Error TritonPlugin::loadPlugin() { + if (isLoaded) + return llvm::Error::success(); + + std::string error; + library = + llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); + if (auto isValid = checkLibraryValid(error)) + return isValid; + + if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES)) { + auto enumeratePassesAPIOrErr = + getAPI( + ENUMERATE_PASSES); + auto addPassAPIOrErr = getAPI(ADD_PASS); + auto registerPassAPIOrErr = + getAPI(REGISTER_PASS); + + if (auto Err = enumeratePassesAPIOrErr.takeError()) + return Err; + if (auto Err = addPassAPIOrErr.takeError()) + return Err; + if (auto Err = registerPassAPIOrErr.takeError()) + return Err; + + addPassAPI = *addPassAPIOrErr; + registerPassAPI = *registerPassAPIOrErr; + enumeratePassesAPI = *enumeratePassesAPIOrErr; + } + + if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS)) { + auto enumerateDialectsAPIOrErr = + getAPI( + ENUMERATE_DIALECTS); + auto dialectPluginInfoAPIOrErr = + getAPI( + DIALECT_PLUGININFO); + + if (auto Err = enumerateDialectsAPIOrErr.takeError()) + return Err; + if (auto Err = dialectPluginInfoAPIOrErr.takeError()) + return Err; + enumerateDialectsAPI = *enumerateDialectsAPIOrErr; + dialectPluginInfoAPI = *dialectPluginInfoAPIOrErr; + } + + isLoaded = true; + return llvm::Error::success(); +} + +llvm::Expected TritonPlugin::enumeratePyBindHandles( + EnumeratePyBindHandlesType &enumeratePyBindHandles, + std::vector &handles) { + if (auto Err = loadPlugin()) + return Err; + + uint32_t passCount = 0; + handles.clear(); + auto result = enumeratePyBindHandles(&passCount, nullptr); + if (result == TP_SUCCESS) { + if (passCount == 0) + return TP_SUCCESS; + + handles.resize(passCount); + result = enumeratePyBindHandles(&passCount, handles.data()); + } + + if (result == TP_SUCCESS) + return TP_SUCCESS; + std::string msg; + llvm::raw_string_ostream os(msg); + os << "Failed to retrive plugin pass handles, error code: " << result; + return llvm::createStringError(msg); +} + +llvm::Expected +TritonPlugin::getPassHandles(std::vector &passNames) { + if (auto Err = loadPlugin()) + return Err; + // Do a check to see if the enumerate-passes api symbol is present, bail as + // if there are 0 passes if not + intptr_t isPassPluginSymbolPresent = + (intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES); + if (!isPassPluginSymbolPresent) + return TP_SUCCESS; + return enumeratePyBindHandles(enumeratePassesAPI, passNames); +} + +llvm::Expected +TritonPlugin::getDialectHandles(std::vector &dialectNames) { + if (auto Err = loadPlugin()) + return Err; + // Do a check to see if the enumerate-dialects api symbol is present, bail as + // if there are 0 dialects if not + intptr_t isDialectPluginSymbolPresent = + (intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS); + if (!isDialectPluginSymbolPresent) + return TP_SUCCESS; + return enumeratePyBindHandles(enumerateDialectsAPI, dialectNames); +} + +llvm::Expected +TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) { + if (auto Err = loadPlugin()) + return Err; + return checkAPIResult(addPassAPI(pm, passHandle), passHandle); +} + +llvm::Expected +TritonPlugin::registerPass(const char *passHandle) { + if (auto Err = loadPlugin()) + return Err; + return checkAPIResult(registerPassAPI(passHandle), passHandle); +} + +llvm::Expected<::mlir::DialectPluginLibraryInfo> +TritonPlugin::getDialectPluginInfo(const char *dialectName) { + if (auto Err = loadPlugin()) + return Err; + return dialectPluginInfoAPI(dialectName); +} diff --git a/third_party/mthreads/musa/CMakeLists.txt b/third_party/mthreads/musa/CMakeLists.txt new file mode 100644 index 0000000000..8a43d93a8b --- /dev/null +++ b/third_party/mthreads/musa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(include) +add_subdirectory(lib) diff --git a/third_party/mthreads/musa/include/CMakeLists.txt b/third_party/mthreads/musa/include/CMakeLists.txt new file mode 100644 index 0000000000..fa7e5fb35e --- /dev/null +++ b/third_party/mthreads/musa/include/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(Dialect) +add_subdirectory(MTGPUToLLVM) +add_subdirectory(TritonMUSAGPUToLLVM) +add_subdirectory(TritonMUSAGPUTransforms) diff --git a/third_party/mthreads/musa/include/Dialect/CMakeLists.txt b/third_party/mthreads/musa/include/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..4e7e1ea71d --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(MTGPU) +add_subdirectory(MUSA) diff --git a/third_party/mthreads/musa/include/Dialect/MTGPU/CMakeLists.txt b/third_party/mthreads/musa/include/Dialect/MTGPU/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MTGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/musa/include/Dialect/MTGPU/IR/CMakeLists.txt b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..4647bfabd2 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS MTGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=mtgpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=mtgpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(MTGPUDialect MTGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(MTGPUOps MTGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(MTGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS MTGPUTypes.td) +mlir_tablegen(MTGPUTypes.h.inc -gen-typedef-decls -typedefs-dialect=mtgpu) +mlir_tablegen(MTGPUTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=mtgpu) +add_public_tablegen_target(MTGPUTypesIncGen) diff --git a/third_party/mthreads/musa/include/Dialect/MTGPU/IR/Dialect.h b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/Dialect.h new file mode 100644 index 0000000000..eb1a6ed0dc --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/Dialect.h @@ -0,0 +1,28 @@ +#ifndef TRITON_DIALECT_MTGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_MTGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "llvm/ADT/StringRef.h" + +// clang-format off +#include "Dialect/MTGPU/IR/Dialect.h.inc" +#include "Dialect/MTGPU/IR/OpsEnums.h.inc" +// clang-format on + +#define GET_TYPEDEF_CLASSES +#include "Dialect/MTGPU/IR/MTGPUTypes.h.inc" + +#define GET_OP_CLASSES +#include "Dialect/MTGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace mtgpu {} // namespace mtgpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUDialect.td b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUDialect.td new file mode 100644 index 0000000000..b8228e9dce --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUDialect.td @@ -0,0 +1,21 @@ +#ifndef MTGPU_DIALECT +#define MTGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def MTGPUDialect : Dialect { + let name = "mtgpu"; + let cppNamespace = "::mlir::triton::mtgpu"; + let useDefaultTypePrinterParser = 1; + + let description = [{ + MUSA backend-private low-level hardware and ABI dialect. + }]; + + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + "mlir::LLVM::LLVMDialect" + ]; +} + +#endif diff --git a/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUOps.td b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUOps.td new file mode 100644 index 0000000000..38174fb1f5 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUOps.td @@ -0,0 +1,125 @@ +#ifndef MTGPU_OPS +#define MTGPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "MTGPUDialect.td" +include "MTGPUTypes.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" + +class MTGPU_Op traits = []> + : LLVM_OpBase; + +def MTGPU_SQMMA_LayoutAttr : I32EnumAttr<"SQMMALayout", + "sqmma layout, either 'row' or 'col'", + [ + I32EnumAttrCase<"row", 0>, + I32EnumAttrCase<"col", 1> + ]> { + let cppNamespace = "::mlir::triton::mtgpu"; +} + +def MTGPU_SQMMA_EltTypeAttr : I32EnumAttr<"SQMMAEltType", + "sqmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'", + [ + I32EnumAttrCase<"s8", 0>, + I32EnumAttrCase<"s32", 1>, + I32EnumAttrCase<"e4m3", 2>, + I32EnumAttrCase<"e5m2", 3>, + I32EnumAttrCase<"f16", 4>, + I32EnumAttrCase<"bf16", 5>, + I32EnumAttrCase<"tf32", 6>, + I32EnumAttrCase<"f32", 7> + ]> { + let cppNamespace = "::mlir::triton::mtgpu"; +} + +def MTGPU_SQMMA_AccumulationModeAttr : I32EnumAttr<"SQMMAAccumulationMode", + "sqmma accumulation contract", + [ + I32EnumAttrCase<"hardware", 0>, + I32EnumAttrCase<"partial", 1>, + I32EnumAttrCase<"software", 2> + ]> { + let cppNamespace = "::mlir::triton::mtgpu"; +} + +def MTGPU_SQMMATensor : Type< + And<[CPred<"::llvm::isa<::mlir::RankedTensorType>($_self)">, + CPred<"::mlir::isa_and_nonnull<::mlir::triton::gpu::MUSASqmmaEncodingAttr>(::llvm::cast<::mlir::RankedTensorType>($_self).getEncoding())">]>, + "MUSA SQMMA accumulator tensor">; +def MTGPU_SQMMACarrierType : Type< + CPred<"::llvm::isa<::mlir::triton::mtgpu::SqmmaAccumulatorType>($_self)">, + "MUSA SQMMA accumulator carrier">; +def MTGPU_SQMMACarrierOrTensorOrMemDesc : AnyTypeOf< + [MTGPU_SQMMACarrierType, TTG_TensorOrMemDesc], + "MUSA SQMMA carrier/tensor/memdesc wait value">; + +def MTGPU_SqmmaOp : MTGPU_Op<"sqmma", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self"> + ]> { + let arguments = (ins TTG_TensorOrMemDesc:$a, TTG_TensorOrMemDesc:$b, + MTGPU_SQMMACarrierType:$c, Optional:$useC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + MTGPU_SQMMA_EltTypeAttr:$eltTypeC, + MTGPU_SQMMA_EltTypeAttr:$eltTypeA, + MTGPU_SQMMA_EltTypeAttr:$eltTypeB, + MTGPU_SQMMA_LayoutAttr:$layoutA, + MTGPU_SQMMA_LayoutAttr:$layoutB, + DefaultValuedAttr:$isAsync, + DefaultValuedAttr:$accMode, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc); + let results = (outs MTGPU_SQMMACarrierType:$d); + let assemblyFormat = + "$a `,` $b `,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)"; + let extraClassDeclaration = [{ + bool needsPartialAccumulator(); + bool usesSoftwareAccumulator(); + bool usesHardwareAccumulator(); + }]; + let hasVerifier = 1; +} + +def MTGPU_PackSqmmaAccumulatorOp : MTGPU_Op<"pack_sqmma_accumulator", [Pure]> { + let summary = "Pack a SQMMA accumulator tensor into compact carrier form"; + let arguments = (ins MTGPU_SQMMATensor:$input); + let results = (outs MTGPU_SQMMACarrierType:$carrier); + let assemblyFormat = + "$input attr-dict `:` type($input) `->` type($carrier)"; + let hasVerifier = 1; +} + +def MTGPU_UnpackSqmmaAccumulatorOp : MTGPU_Op<"unpack_sqmma_accumulator", [Pure]> { + let summary = "Unpack a SQMMA accumulator carrier back to tensor form"; + let arguments = (ins MTGPU_SQMMACarrierType:$carrier); + let results = (outs MTGPU_SQMMATensor:$output); + let assemblyFormat = + "$carrier attr-dict `:` type($carrier) `->` type($output)"; + let hasVerifier = 1; +} + +def MTGPU_SqmmaWaitOp : MTGPU_Op<"sqmma_wait", [ + DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>, + PassthroughWaitLike, + MemoryEffects<[MemRead, MemWrite]> + ]> { + let summary = "Wait until pending SQMMA operations complete."; + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); + let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; + let hasCanonicalizeMethod = 1; + let hasVerifier = 1; +} + +#endif diff --git a/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUTypes.td b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUTypes.td new file mode 100644 index 0000000000..5c010c54e4 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MTGPU/IR/MTGPUTypes.td @@ -0,0 +1,32 @@ +#ifndef MTGPU_TYPES +#define MTGPU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "MTGPUDialect.td" + +class MTGPU_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def MTGPU_SqmmaAccumulatorType + : MTGPU_TypeDef<"SqmmaAccumulator", "sqmma_accumulator"> { + let summary = "backend-private SQMMA accumulator carrier"; + let description = [{ + A backend-private carrier type that preserves SQMMA loop-carried + accumulators in compact fragment form while keeping the wrapped tensor + type as the logical accumulator semantics. + }]; + + let parameters = (ins "Type":$tensorType); + let assemblyFormat = "`<` $tensorType `>`"; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + RankedTensorType getAccumulatorType() const { + return llvm::cast(getTensorType()); + } + }]; +} + +#endif diff --git a/third_party/mthreads/musa/include/Dialect/MUSA/CMakeLists.txt b/third_party/mthreads/musa/include/Dialect/MUSA/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MUSA/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/musa/include/Dialect/MUSA/IR/CMakeLists.txt b/third_party/mthreads/musa/include/Dialect/MUSA/IR/CMakeLists.txt new file mode 100644 index 0000000000..82158813a0 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MUSA/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS MUSAOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttmg) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttmg) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(MUSADialect MUSADialect dialects/ -gen-dialect-doc) +add_mlir_doc(MUSAOps MUSAOps dialects/ -gen-op-doc) +add_public_tablegen_target(MUSATableGen) + +set(LLVM_TARGET_DEFINITIONS MUSAAttrDefs.td) +mlir_tablegen(MUSAAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(MUSAAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MUSAAttrDefsIncGen) diff --git a/third_party/mthreads/musa/include/Dialect/MUSA/IR/Dialect.h b/third_party/mthreads/musa/include/Dialect/MUSA/IR/Dialect.h new file mode 100644 index 0000000000..ae1fd894e7 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MUSA/IR/Dialect.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_MUSA_IR_DIALECT_H_ +#define TRITON_DIALECT_MUSA_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "llvm/ADT/StringRef.h" + +// clang-format off +#include "Dialect/MUSA/IR/Dialect.h.inc" +#include "Dialect/MUSA/IR/OpsEnums.h.inc" +// clang-format on + +#define GET_ATTRDEF_CLASSES +#include "Dialect/MUSA/IR/MUSAAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "Dialect/MUSA/IR/Ops.h.inc" + +#endif // TRITON_DIALECT_MUSA_IR_DIALECT_H_ diff --git a/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSAAttrDefs.td b/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSAAttrDefs.td new file mode 100644 index 0000000000..dc066a2202 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSAAttrDefs.td @@ -0,0 +1,78 @@ +#ifndef MUSA_ATTRDEFS +#define MUSA_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" +include "MUSADialect.td" + +class MUSA_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +def TME_SwizzleGranularityAttr : I32EnumAttr<"TMESwizzleGranularity", + "TME swizzle granularity", + [ + I32EnumAttrCase<"SG_NONE", 0>, + I32EnumAttrCase<"SG_16B", 1>, + I32EnumAttrCase<"SG_32B", 2>, + I32EnumAttrCase<"SG_64B", 3>, + I32EnumAttrCase<"SG_128B", 4> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def TME_SwizzleStrideAttr : I32EnumAttr<"TMESwizzleStride", + "TME swizzle stride", + [ + I32EnumAttrCase<"SS_32B", 0>, + I32EnumAttrCase<"SS_64B", 1>, + I32EnumAttrCase<"SS_128B", 2>, + I32EnumAttrCase<"SS_256B", 3> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def TME_SwizzleLineAttr : I32EnumAttr<"TMESwizzleLine", + "TME swizzle line", + [ + I32EnumAttrCase<"SL_128B", 0>, + I32EnumAttrCase<"SL_256B", 1> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def TME_PrefetchSizeAttr : I32EnumAttr<"TMEPrefetchSize", + "TME prefetch size", + [ + I32EnumAttrCase<"SZ_NONE", 0>, + I32EnumAttrCase<"SZ_64B", 64>, + I32EnumAttrCase<"SZ_128B", 128>, + I32EnumAttrCase<"SZ_256B", 256> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def TME_PersistenceAttr : I32EnumAttr<"TMEPersistence", + "TME persistence policy", + [ + I32EnumAttrCase<"CACHE_NONE", 0>, + I32EnumAttrCase<"CACHE_ONCE", 1>, + I32EnumAttrCase<"CACHE_NORMAL", 2>, + I32EnumAttrCase<"CACHE_PERSIST", 3>, + I32EnumAttrCase<"NO_OVERRIDE", 4>, + I32EnumAttrCase<"LAST_USE", 5> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def TME_CachePolicyAttr : I32EnumAttr<"TMEL2CachePolicy", + "TME L2 cache policy", + [ + I32EnumAttrCase<"NEW_ALLOC", 0>, + I32EnumAttrCase<"BYPASS", 1> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +#endif diff --git a/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSADialect.td b/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSADialect.td new file mode 100644 index 0000000000..e4eadad024 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSADialect.td @@ -0,0 +1,19 @@ +#ifndef MUSA_DIALECT +#define MUSA_DIALECT + +include "mlir/IR/OpBase.td" + +def MUSA_Dialect : Dialect { + let name = "ttmg"; + let cppNamespace = "::mlir::triton::musa"; + + let description = [{ + Triton MUSA GPU middle dialect. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect" + ]; +} + +#endif diff --git a/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSAOps.td b/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSAOps.td new file mode 100644 index 0000000000..6d6502d966 --- /dev/null +++ b/third_party/mthreads/musa/include/Dialect/MUSA/IR/MUSAOps.td @@ -0,0 +1,243 @@ +#ifndef MUSA_OPS +#define MUSA_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "MUSADialect.td" +include "MUSAAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" + +class MUSA_Op traits = []> : + LLVM_OpBase; + +def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; + +def SQMMA_LayoutAttr : I32EnumAttr<"SQMMALayout", + "sqmma layout, either 'row' or 'col'", + [ + I32EnumAttrCase<"row", 0>, + I32EnumAttrCase<"col", 1> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def SQMMA_EltTypeAttr : I32EnumAttr<"SQMMAEltType", + "sqmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'", + [ + I32EnumAttrCase<"s8", 0>, + I32EnumAttrCase<"s32", 1>, + I32EnumAttrCase<"e4m3", 2>, + I32EnumAttrCase<"e5m2", 3>, + I32EnumAttrCase<"f16", 4>, + I32EnumAttrCase<"bf16", 5>, + I32EnumAttrCase<"tf32", 6>, + I32EnumAttrCase<"f32", 7> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def SQMMA_AccumulationModeAttr : I32EnumAttr<"SQMMAAccumulationMode", + "sqmma accumulation contract", + [ + I32EnumAttrCase<"hardware", 0>, + I32EnumAttrCase<"partial", 1>, + I32EnumAttrCase<"software", 2> + ]> { + let cppNamespace = "::mlir::triton::musa"; +} + +def TME_DimType : AnyTypeOf<[I32, LLVM_AnyVector], "TME block dim/pos type">; +def TME_DescType : AnyTypeOf<[I64, TT_TensorDescType], + "TME descriptor handle or tensor descriptor">; +def TME_LocalMemType : AnyTypeOf<[LLVM_PointerShared, TTG_MemDescType], + "shared memory pointer or memdesc for TME async copy">; + +def MUSA_SquadDotOp : MUSA_Op<"squad_dot", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self"> + ]> { + let arguments = (ins TTG_TensorOrMemDesc:$a, TTG_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, Optional:$useC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + SQMMA_EltTypeAttr:$eltTypeC, SQMMA_EltTypeAttr:$eltTypeA, + SQMMA_EltTypeAttr:$eltTypeB, + SQMMA_LayoutAttr:$layoutA, SQMMA_LayoutAttr:$layoutB, + DefaultValuedAttr:$isAsync, + DefaultValuedAttr:$accMode, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc); + let results = (outs TT_FpIntTensor:$d); + let assemblyFormat = + "$a `,` $b `,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)"; + let extraClassDeclaration = [{ + bool needsPartialAccumulator(); + bool usesSoftwareAccumulator(); + bool usesHardwareAccumulator(); + }]; + let hasVerifier = 1; +} + +def MUSA_SquadDotWaitOp : MUSA_Op<"squad_dot_wait", [ + DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>, + PassthroughWaitLike, + MemoryEffects<[MemRead, MemWrite]> + ]> { + let summary = "Wait until pending SQMMA operations complete."; + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); + let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; + let hasCanonicalizeMethod = 1; + let hasVerifier = 1; +} + +def MUSA_WmmaDotOp : MUSA_Op<"wmma_dot", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self"> + ]> { + let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, + TT_FpIntTensor:$c, Optional:$useC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + SQMMA_EltTypeAttr:$eltTypeA, SQMMA_EltTypeAttr:$eltTypeB, + SQMMA_LayoutAttr:$layoutA, SQMMA_LayoutAttr:$layoutB, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc); + let results = (outs TT_FpIntTensor:$d); + let assemblyFormat = + "$a `,` $b `,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)"; + let extraClassDeclaration = [{ + bool needsPartialAccumulator(); + }]; + let hasVerifier = 1; +} + +def MUSA_WmmaDotWaitOp : MUSA_Op<"wmma_dot_wait", []> { + let summary = "IR no-op marker for WMMA dependency edges."; + let assemblyFormat = "attr-dict"; +} + +def MUSA_BarRecordOp : MUSA_Op<"bar_record", []> { + let summary = "Record an async barrier id before use."; + let arguments = (ins I32:$barId); + let assemblyFormat = "$barId attr-dict `:` type($barId)"; + let hasVerifier = 1; +} + +def MUSA_InitArrivalOp : MUSA_Op<"init_arrival", []> { + let summary = "Initialize async barrier arrival count and phase id."; + let arguments = (ins I32:$barId, I32:$arriveCount, I32:$phaseId); + let assemblyFormat = + "$barId `,` $arriveCount `,` $phaseId attr-dict `:` type($barId)"; + let hasVerifier = 1; +} + +def MUSA_BarrierAddTransOp : MUSA_Op<"barrier_add_trans", []> { + let summary = "Explicitly add async barrier transaction bytes."; + let arguments = (ins I32:$barId, I32:$transBytes, I1:$pred); + let assemblyFormat = + "$barId `,` $transBytes `,` $pred attr-dict `:` type($barId)"; + let hasVerifier = 1; +} + +def MUSA_ArriveBarrierOp : MUSA_Op<"arrive_barrier", []> { + let summary = "Arrive on async barrier and return phase id."; + let arguments = (ins I32:$barId); + let results = (outs I32:$phaseId); + let assemblyFormat = "$barId attr-dict `:` type($barId) `->` type($phaseId)"; + let hasVerifier = 1; +} + +def MUSA_ArriveBarrierNoRetOp : MUSA_Op<"arrive_barrier_noret", []> { + let summary = "Arrive on async barrier without returning phase id."; + let arguments = (ins I32:$barId, I1:$pred); + let assemblyFormat = "$barId `,` $pred attr-dict `:` type($barId)"; + let hasVerifier = 1; +} + +def MUSA_WaitBarrierOp : MUSA_Op<"wait_barrier", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Wait on async barrier phase."; + let arguments = (ins I32:$barId, I32:$phaseId); + let assemblyFormat = + "$barId `,` $phaseId attr-dict `:` type($barId)"; + let hasVerifier = 1; +} + +def MUSA_AsyncTMECopyGlobalToLocalOp + : MUSA_Op<"async_tme_copy_global_to_local", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Asynchronously copy a TME tile from global to shared."; + let arguments = (ins + TME_DescType:$desc, + Variadic:$coord, + I32:$barId, + TTG_MemDescType:$result, + I1:$pred, + DenseI32ArrayAttr:$blockShape, + TME_SwizzleGranularityAttr:$swizzleGranularity, + TME_SwizzleStrideAttr:$swizzleStride, + TME_SwizzleLineAttr:$swizzleLine, + TME_PrefetchSizeAttr:$prefetchSize, + TME_CachePolicyAttr:$cachePolicy, + TME_PersistenceAttr:$innerPersistence, + TME_PersistenceAttr:$outerPersistence + ); + let assemblyFormat = [{ + $desc `[` $coord `]` `,` $barId `,` $result `,` $pred attr-dict + `:` qualified(type($desc)) `,` qualified(type($result)) + }]; + let hasVerifier = 1; +} + +def MUSA_AsyncTMECopyLocalToGlobalOp + : MUSA_Op<"async_tme_copy_local_to_global", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Asynchronously copy a TME tile from shared to global."; + let arguments = (ins + TME_DescType:$desc, + Variadic:$coord, + TTG_MemDescType:$src, + I1:$pred, + DenseI32ArrayAttr:$blockShape, + TME_SwizzleGranularityAttr:$swizzleGranularity, + TME_SwizzleStrideAttr:$swizzleStride, + TME_SwizzleLineAttr:$swizzleLine, + TME_CachePolicyAttr:$cachePolicy, + TME_PersistenceAttr:$innerPersistence, + TME_PersistenceAttr:$outerPersistence + ); + let assemblyFormat = [{ + $desc `[` $coord `]` `,` $src `,` $pred attr-dict + `:` qualified(type($desc)) `,` qualified(type($src)) + }]; + let hasVerifier = 1; +} + +def MUSA_TMEStoreCommitOp : MUSA_Op<"tme_store_commit", []> { + let summary = "Commit pending TME stores."; + let assemblyFormat = "attr-dict"; +} + +def MUSA_TMEStoreReadWaitOp : MUSA_Op<"tme_store_read_wait", []> { + let summary = "Wait until committed TME stores are visible to reads."; + let assemblyFormat = "attr-dict"; +} + +#endif diff --git a/third_party/mthreads/musa/include/MTGPUToLLVM/CMakeLists.txt b/third_party/mthreads/musa/include/MTGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..7691598c70 --- /dev/null +++ b/third_party/mthreads/musa/include/MTGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name MTGPUToLLVM) +add_public_tablegen_target(MTGPUConversionPassIncGen) diff --git a/third_party/mthreads/musa/include/MTGPUToLLVM/MTGPUToLLVMPass.h b/third_party/mthreads/musa/include/MTGPUToLLVM/MTGPUToLLVMPass.h new file mode 100644 index 0000000000..3a77047c3b --- /dev/null +++ b/third_party/mthreads/musa/include/MTGPUToLLVM/MTGPUToLLVMPass.h @@ -0,0 +1,40 @@ +#ifndef TRITON_CONVERSION_MTGPU_TO_LLVM_PASS_H +#define TRITON_CONVERSION_MTGPU_TO_LLVM_PASS_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +namespace mtgpu { + +void populateMTGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace mtgpu + +std::unique_ptr> createConvertMTGPUToLLVMPass(); +std::unique_ptr> +createConvertMTGPUToLLVMPass(int32_t computeCapability); + +#define GEN_PASS_DECL +#include "musa/include/MTGPUToLLVM/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "musa/include/MTGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_CONVERSION_MTGPU_TO_LLVM_PASS_H diff --git a/third_party/mthreads/musa/include/MTGPUToLLVM/Passes.h b/third_party/mthreads/musa/include/MTGPUToLLVM/Passes.h new file mode 100644 index 0000000000..330b23678d --- /dev/null +++ b/third_party/mthreads/musa/include/MTGPUToLLVM/Passes.h @@ -0,0 +1,35 @@ +#ifndef MTGPU_CONVERSION_MTGPUTOLLVM_PASSES_H +#define MTGPU_CONVERSION_MTGPUTOLLVM_PASSES_H + +#include "Dialect/MTGPU/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "musa/include/MTGPUToLLVM/Passes.h.inc" + +std::unique_ptr> createConvertMTGPUToLLVMPass(); +std::unique_ptr> +createConvertMTGPUToLLVMPass(int32_t computeCapability); + +#define GEN_PASS_REGISTRATION +#include "musa/include/MTGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/musa/include/MTGPUToLLVM/Passes.td b/third_party/mthreads/musa/include/MTGPUToLLVM/Passes.td new file mode 100644 index 0000000000..1589439750 --- /dev/null +++ b/third_party/mthreads/musa/include/MTGPUToLLVM/Passes.td @@ -0,0 +1,24 @@ +#ifndef MTGPU_CONVERSION_PASSES +#define MTGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertMTGPUToLLVM : Pass<"convert-mtgpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert MTGPU dialect to LLVM for MUSA"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::cf::ControlFlowDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect", + "mlir::triton::mtgpu::MTGPUDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"31", + "device compute capability">, + ]; +} + +#endif diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/BarrierUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/BarrierUtils.h new file mode 100644 index 0000000000..d06c2715e2 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/BarrierUtils.h @@ -0,0 +1,114 @@ +#ifndef TRITONMUSA_COMMON_BARRIER_UTILS_H +#define TRITONMUSA_COMMON_BARRIER_UTILS_H + +#include "Dialect/MUSA/IR/Dialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir::triton::musa { + +inline constexpr llvm::StringLiteral kNextBarrierIdAttr = "musa.next_bar_id"; +inline constexpr llvm::StringLiteral kMaxBarrierIdAttr = "musa.max_bar_id"; +inline constexpr int32_t kMaxBarrierId = 63; + +inline ModuleOp getEnclosingModule(Operation *op) { + if (auto mod = dyn_cast(op)) + return mod; + return op->getParentOfType(); +} + +inline FunctionOpInterface getEnclosingFunction(Operation *op) { + while (op) { + if (auto func = dyn_cast(op)) + return func; + op = op->getParentOp(); + } + return nullptr; +} + +inline int32_t getImplicitAsyncBarrierFloor(FunctionOpInterface func) { + if (!func) + return 0; + + bool hasAsyncCommitGroup = false; + func.walk( + [&](triton::gpu::AsyncCommitGroupOp) { hasAsyncCommitGroup = true; }); + return hasAsyncCommitGroup ? 1 : 0; +} + +inline FailureOr reserveBarrierIdRange(Operation *anchorOp, + int32_t numSlots) { + if (numSlots <= 0 || numSlots > kMaxBarrierId) + return failure(); + + auto func = getEnclosingFunction(anchorOp); + if (!func) + return failure(); + + auto *ctx = func->getContext(); + auto i32Ty = IntegerType::get(ctx, 32); + auto nextAttr = func->getAttrOfType(kNextBarrierIdAttr); + auto maxAttr = func->getAttrOfType(kMaxBarrierIdAttr); + int32_t current = nextAttr ? static_cast(nextAttr.getInt()) : 0; + if (!nextAttr && maxAttr) + current = static_cast(maxAttr.getInt()); + current = std::max(current, getImplicitAsyncBarrierFloor(func)); + int32_t base = current + 1; + int32_t next = current + numSlots; + if (next > kMaxBarrierId) + return failure(); + + func->setAttr(kNextBarrierIdAttr, IntegerAttr::get(i32Ty, next)); + int32_t funcMax = maxAttr ? static_cast(maxAttr.getInt()) : 0; + if (next > funcMax) + func->setAttr(kMaxBarrierIdAttr, IntegerAttr::get(i32Ty, next)); + return base; +} + +inline FailureOr reserveFreshBarrierId(Operation *anchorOp) { + return reserveBarrierIdRange(anchorOp, /*numSlots=*/1); +} + +inline int32_t getReservedBarrierCount(FunctionOpInterface func) { + auto attr = func->getAttrOfType(kMaxBarrierIdAttr); + return attr ? static_cast(attr.getInt()) : 0; +} + +inline void finalizeBarRecord(FunctionOpInterface func, + RewriterBase &rewriter) { + SmallVector records; + func->walk([&](BarRecordOp op) { records.push_back(op); }); + + int32_t barCount = std::max(getReservedBarrierCount(func), + getImplicitAsyncBarrierFloor(func)); + if (barCount <= 0) { + for (BarRecordOp record : records) + record.erase(); + func->removeAttr(kNextBarrierIdAttr); + func->removeAttr(kMaxBarrierIdAttr); + return; + } + + auto loc = func.getLoc(); + auto i32Ty = IntegerType::get(func->getContext(), 32); + func->setAttr(kMaxBarrierIdAttr, IntegerAttr::get(i32Ty, barCount)); + func->removeAttr(kNextBarrierIdAttr); + + for (BarRecordOp record : records) + record.erase(); + + OpBuilder::InsertionGuard guard(rewriter); + Block &entry = func.getFunctionBody().front(); + rewriter.setInsertionPointToStart(&entry); + auto count = arith::ConstantIntOp::create(rewriter, loc, barCount, 32); + BarRecordOp::create(rewriter, loc, count); +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_BARRIER_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/MMAContractUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/MMAContractUtils.h new file mode 100644 index 0000000000..149eff1200 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/MMAContractUtils.h @@ -0,0 +1,376 @@ +#ifndef TRITONMUSA_COMMON_MMA_CONTRACT_UTILS_H +#define TRITONMUSA_COMMON_MMA_CONTRACT_UTILS_H + +#include "Dialect/MUSA/IR/Dialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace mlir::triton::musa { + +inline bool isFloat8E4M3(Type ty) { + return llvm::isa(ty); +} + +inline bool isFloat8E5M2(Type ty) { + return llvm::isa(ty); +} + +inline std::optional getWmmaEltType(Type elemTy) { + if (elemTy.isInteger(8)) + return SQMMAEltType::s8; + if (elemTy.isF16()) + return SQMMAEltType::f16; + if (elemTy.isBF16()) + return SQMMAEltType::bf16; + if (elemTy.isF32()) + return SQMMAEltType::tf32; + if (isFloat8E4M3(elemTy)) + return SQMMAEltType::e4m3; + if (isFloat8E5M2(elemTy)) + return SQMMAEltType::e5m2; + return std::nullopt; +} + +struct WmmaIntrinsicSignature { + SQMMAEltType eltType; + unsigned m; + unsigned n; + unsigned k; + const char *name; + int32_t sat; +}; + +inline constexpr WmmaIntrinsicSignature kWmmaIntrinsics[] = { + {SQMMAEltType::s8, 8, 16, 16, "llvm.musa.imma.m8n16k16.mma", 1}, + {SQMMAEltType::s8, 16, 8, 16, "llvm.musa.imma.m16n8k16.mma", 1}, + {SQMMAEltType::s8, 16, 16, 16, "llvm.musa.imma.m16n16k16.mma", 1}, + {SQMMAEltType::s8, 16, 16, 32, "llvm.musa.imma.m16n16k32.mma", 1}, + {SQMMAEltType::s8, 16, 16, 64, "llvm.musa.imma.m16n16k64.mma", 1}, + {SQMMAEltType::f16, 16, 8, 8, "llvm.musa.ffmma.m16n8k8.mma", 1}, + {SQMMAEltType::f16, 16, 8, 16, "llvm.musa.ffmma.m16n8k16.mma", 1}, + {SQMMAEltType::f16, 8, 16, 16, "llvm.musa.ffmma.m8n16k16.mma", 1}, + {SQMMAEltType::f16, 16, 16, 16, "llvm.musa.ffmma.m16n16k16.mma", 1}, + {SQMMAEltType::f16, 16, 16, 32, "llvm.musa.ffmma.m16n16k32.mma", 1}, + {SQMMAEltType::bf16, 16, 8, 8, "llvm.musa.bfmma.m16n8k8.mma", 1}, + {SQMMAEltType::bf16, 16, 8, 16, "llvm.musa.bfmma.m16n8k16.mma", 1}, + {SQMMAEltType::bf16, 8, 16, 16, "llvm.musa.bfmma.m8n16k16.mma", 1}, + {SQMMAEltType::bf16, 16, 16, 16, "llvm.musa.bfmma.m16n16k16.mma", 1}, + {SQMMAEltType::bf16, 16, 16, 32, "llvm.musa.bfmma.m16n16k32.mma", 1}, + {SQMMAEltType::tf32, 16, 8, 4, "llvm.musa.tfmma.m16n8k4.mma", 1}, + {SQMMAEltType::tf32, 16, 8, 8, "llvm.musa.tfmma.m16n8k8.mma", 1}, + {SQMMAEltType::tf32, 16, 16, 16, "llvm.musa.tfmma.m16n16k16.mma", 1}, + {SQMMAEltType::e4m3, 8, 16, 16, "llvm.musa.e4m3.m8n16k16.mma", 1}, + {SQMMAEltType::e4m3, 16, 8, 16, "llvm.musa.e4m3.m16n8k16.mma", 1}, + {SQMMAEltType::e4m3, 16, 16, 16, "llvm.musa.e4m3.m16n16k16.mma", 1}, + {SQMMAEltType::e4m3, 16, 16, 32, "llvm.musa.e4m3.m16n16k32.mma", 1}, + {SQMMAEltType::e4m3, 16, 16, 64, "llvm.musa.e4m3.m16n16k64.mma", 1}, + {SQMMAEltType::e5m2, 8, 16, 16, "llvm.musa.e5m2.m8n16k16.mma", 1}, + {SQMMAEltType::e5m2, 16, 8, 16, "llvm.musa.e5m2.m16n8k16.mma", 1}, + {SQMMAEltType::e5m2, 16, 16, 16, "llvm.musa.e5m2.m16n16k16.mma", 1}, + {SQMMAEltType::e5m2, 16, 16, 32, "llvm.musa.e5m2.m16n16k32.mma", 1}, + {SQMMAEltType::e5m2, 16, 16, 64, "llvm.musa.e5m2.m16n16k64.mma", 1}, +}; + +inline std::optional +lookupWmmaIntrinsicName(SQMMAEltType eltType, unsigned m, unsigned n, + unsigned k) { + for (const auto &def : kWmmaIntrinsics) { + if (def.eltType == eltType && def.m == m && def.n == n && def.k == k) + return llvm::StringRef(def.name); + } + return std::nullopt; +} + +inline std::optional +lookupWmmaIntrinsic(Type elemTy, ArrayRef instrShape) { + if (instrShape.size() != 3) + return std::nullopt; + auto eltType = getWmmaEltType(elemTy); + if (!eltType) + return std::nullopt; + for (const auto &def : kWmmaIntrinsics) { + if (def.eltType == *eltType && def.m == instrShape[0] && + def.n == instrShape[1] && def.k == instrShape[2]) + return def; + } + return std::nullopt; +} + +inline int32_t encodeWmmaShape(SQMMALayout layoutA, SQMMALayout layoutB) { + return (layoutA == SQMMALayout::col ? 2 : 0) | + (layoutB == SQMMALayout::col ? 1 : 0); +} + +inline int32_t getWmmaFmt(Type) { return 1; } +inline int32_t getWmmaFmt(SQMMAEltType) { return 1; } + +inline bool needsWmmaScaleOperands(SQMMAEltType eltType) { + switch (eltType) { + case SQMMAEltType::e4m3: + case SQMMAEltType::e5m2: + return true; + default: + return false; + } +} + +inline SQMMALayout flipWmmaLayout(SQMMALayout layout) { + return layout == SQMMALayout::row ? SQMMALayout::col : SQMMALayout::row; +} + +inline SQMMALayout getDefaultWmmaFragmentLayout(unsigned opIdx) { + assert(opIdx < 2 && "WMMA operand index must be 0 or 1"); + return opIdx == 0 ? SQMMALayout::row : SQMMALayout::col; +} + +inline SQMMALayout inferWmmaFragmentLayout(Value value, unsigned opIdx) { + while (auto cvt = value.getDefiningOp()) + value = cvt.getSrc(); + while (auto bitcast = value.getDefiningOp()) + value = bitcast.getSrc(); + if (auto trans = value.getDefiningOp()) + return flipWmmaLayout(inferWmmaFragmentLayout(trans.getSrc(), opIdx)); + return getDefaultWmmaFragmentLayout(opIdx); +} + +struct WmmaDotOperandContract { + gpu::DotOperandEncodingAttr dotEncoding; + LinearLayout linearLayout; + unsigned opIdx; + unsigned rank; +}; + +inline FailureOr +resolveWmmaDotOperandContract(RankedTensorType tensorTy, + unsigned expectedOpIdx) { + auto dotEncoding = + dyn_cast(tensorTy.getEncoding()); + if (!dotEncoding) + return failure(); + if (dotEncoding.getOpIdx() != expectedOpIdx) + return failure(); + return WmmaDotOperandContract{ + dotEncoding, dotEncoding.toLinearLayout(tensorTy.getShape()), + expectedOpIdx, static_cast(tensorTy.getRank())}; +} + +inline Value extractWmmaOperandVectorFromContract( + Location loc, Value operand, const WmmaDotOperandContract &contract, + const LLVMTypeConverter *typeConverter, RewriterBase &rewriter, int batch, + int nonK, int kIdx, int kInst, int kBase, int kPadding, Type elemTy) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto elems = unpackLLElements(loc, operand, rewriter); + + auto outDimNames = contract.linearLayout.getOutDimNames(); + auto *ctx = (*outDimNames.begin()).getContext(); + StringAttr dim0 = StringAttr::get(ctx, "dim0"); + StringAttr dim1 = StringAttr::get(ctx, "dim1"); + StringAttr dim2 = StringAttr::get(ctx, "dim2"); + + const int kElemIdx = kIdx * kInst; + const int mCoord = (contract.opIdx == 0) ? nonK : kElemIdx; + const int nCoord = (contract.opIdx == 0) ? kElemIdx : nonK; + + SmallVector> outCoords; + if (contract.rank == 3) + outCoords = {{dim0, batch}, {dim1, mCoord}, {dim2, nCoord}}; + else + outCoords = {{dim0, mCoord}, {dim1, nCoord}}; + + auto inDims = contract.linearLayout.pseudoinvert().apply(outCoords); + const int startReg = inDims[0].second; + + Type llvmElemTy = typeConverter->convertType(elemTy); + Type vecTy = vec_ty(llvmElemTy, kBase); + Value vec = b.undef(vecTy); + const int validK = kBase - kPadding; + Value zero = LLVM::ZeroOp::create(rewriter, loc, llvmElemTy); + for (int k = 0; k < kBase; ++k) { + Value v = (k < validK) ? elems[startReg + k] : zero; + vec = b.insert_element(vecTy, vec, v, b.i32_val(k)); + } + return vec; +} + +inline SmallVector +buildWmmaIntrinsicArgs(Location loc, Value opA, Value opB, Value opC, + SQMMALayout layoutA, SQMMALayout layoutB, + const WmmaIntrinsicSignature &signature, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector args = {opA, + opB, + opC, + b.i32_val(signature.sat), + b.i32_val(getWmmaFmt(signature.eltType)), + b.i32_val(encodeWmmaShape(layoutA, layoutB))}; + if (needsWmmaScaleOperands(signature.eltType)) { + args.push_back(b.i32_val(0)); + args.push_back(b.i32_val(0)); + args.push_back(b.i32_val(0)); + } + return args; +} + +inline SmallVector +buildWmmaIntrinsicArgs(Location loc, Value opA, Value opB, Value opC, + SQMMALayout layoutA, SQMMALayout layoutB, + SQMMAEltType eltType, RewriterBase &rewriter) { + return buildWmmaIntrinsicArgs( + loc, opA, opB, opC, layoutA, layoutB, + WmmaIntrinsicSignature{eltType, /*m=*/0, /*n=*/0, /*k=*/0, + /*name=*/nullptr, /*sat=*/1}, + rewriter); +} + +inline SmallVector +buildWmmaIntrinsicArgs(Location loc, Value opA, Value opB, Value opC, + SQMMALayout layoutA, SQMMALayout layoutB, Type elemTy, + RewriterBase &rewriter) { + auto eltType = getWmmaEltType(elemTy); + assert(eltType && "WMMA element type must be validated before arg building"); + return buildWmmaIntrinsicArgs(loc, opA, opB, opC, layoutA, layoutB, *eltType, + rewriter); +} + +inline std::string getSqmmaTypeTag(SQMMAEltType type) { + switch (type) { + case SQMMAEltType::f16: + return "fmma"; + case SQMMAEltType::bf16: + return "bfmma"; + case SQMMAEltType::tf32: + return "tfmma"; + case SQMMAEltType::s8: + return "smma"; + case SQMMAEltType::e4m3: + return "e4m3"; + case SQMMAEltType::e5m2: + return "e5m2"; + default: + return ""; + } +} + +inline bool isSupportedSqmmaInstrMN(unsigned m, unsigned n) { + static constexpr std::pair kAllowedMN[] = { + {32, 32}, {32, 64}, {32, 128}, {16, 64}, {64, 16}, {64, 32}, + {64, 64}, {64, 128}, {128, 32}, {128, 64}, {128, 128}, + }; + for (const auto &[supportedM, supportedN] : kAllowedMN) { + if (supportedM == m && supportedN == n) + return true; + } + return false; +} + +inline bool isSupportedSqmmaInstrMN(SQMMAEltType eltType, unsigned m, + unsigned n) { + switch (eltType) { + case SQMMAEltType::f16: + case SQMMAEltType::bf16: + case SQMMAEltType::s8: + case SQMMAEltType::e4m3: + case SQMMAEltType::e5m2: + return isSupportedSqmmaInstrMN(m, n); + case SQMMAEltType::tf32: { + static constexpr std::pair kAllowedTf32MN[] = { + {16, 64}, {32, 32}, {32, 64}, {64, 16}, + {64, 32}, {64, 64}, {128, 64}, {128, 128}, + }; + for (const auto &[supportedM, supportedN] : kAllowedTf32MN) { + if (supportedM == m && supportedN == n) + return true; + } + return false; + } + default: + return false; + } +} + +inline bool isSupportedSqmma(SQMMAEltType eltTypeA, SQMMAEltType eltTypeB, + SQMMAEltType eltTypeC, unsigned m, unsigned n, + unsigned k) { + if (m == 0 || n == 0 || k == 0 || (m % 8) || (n % 8) || (k % 8)) + return false; + if (eltTypeA != eltTypeB) + return false; + if (!isSupportedSqmmaInstrMN(eltTypeA, m, n)) + return false; + + auto isValidPh1K = [&](SQMMAEltType type, unsigned mVal, unsigned nVal, + unsigned kVal) { + switch (type) { + case SQMMAEltType::f16: + case SQMMAEltType::bf16: + if (kVal == 16 || kVal == 32 || kVal == 64) + return true; + return kVal == 128 && + ((mVal == 16 && nVal == 64) || (mVal == 64 && nVal == 16)); + case SQMMAEltType::tf32: + return kVal == 8 || kVal == 16 || kVal == 32; + case SQMMAEltType::s8: + case SQMMAEltType::e4m3: + case SQMMAEltType::e5m2: + return kVal == 32 || kVal == 64 || kVal == 128; + default: + return false; + } + }; + + switch (eltTypeA) { + case SQMMAEltType::f16: + case SQMMAEltType::bf16: + case SQMMAEltType::tf32: + return eltTypeC == SQMMAEltType::f32 && isValidPh1K(eltTypeA, m, n, k); + case SQMMAEltType::s8: + return eltTypeC == SQMMAEltType::s32 && isValidPh1K(eltTypeA, m, n, k); + case SQMMAEltType::e4m3: + case SQMMAEltType::e5m2: + return eltTypeC == SQMMAEltType::f32 && isValidPh1K(eltTypeA, m, n, k); + default: + return false; + } +} + +inline std::string lookupSqmmaIntrinsic(SQMMAEltType type, unsigned m, + unsigned n, unsigned k) { + auto tag = getSqmmaTypeTag(type); + if (tag.empty()) + return ""; + return ("llvm.musa.sqmma." + tag + ".m" + std::to_string(m) + "n" + + std::to_string(n) + "k" + std::to_string(k) + ".mma"); +} + +inline Value materializeUseCFlag(Location loc, Value useC, + RewriterBase &rewriter) { + if (useC) + return useC; + return arith::ConstantIntOp::create(rewriter, loc, 1, 1); +} + +inline Value selectAccumulatorValue(Location loc, Value useC, Value acc, + Value zero, RewriterBase &rewriter) { + auto constUseC = ::mlir::triton::getBoolFromConstant(useC); + if (constUseC && *constUseC) + return acc; + if (constUseC && !*constUseC) + return zero; + return LLVM::SelectOp::create(rewriter, loc, acc.getType(), useC, acc, zero); +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_MMA_CONTRACT_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/MMAEncodingUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/MMAEncodingUtils.h new file mode 100644 index 0000000000..42eee4a57e --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/MMAEncodingUtils.h @@ -0,0 +1,24 @@ +#ifndef TRITONMUSA_COMMON_MMA_ENCODING_UTILS_H +#define TRITONMUSA_COMMON_MMA_ENCODING_UTILS_H + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::musa { +namespace ttg = mlir::triton::gpu; + +inline constexpr unsigned kMusaPH1VersionMajor = 3; + +// Keep backend support policy out of the public TTG verifier. The TTG encoding +// carries version metadata, while the MUSA backend decides which generations it +// can lower. +inline bool supportsMusaWmmaEncoding(ttg::MUSAWmmaEncodingAttr encoding) { + return encoding && encoding.getVersionMajor() == kMusaPH1VersionMajor; +} + +inline bool supportsMusaSqmmaEncoding(ttg::MUSASqmmaEncodingAttr encoding) { + return encoding && encoding.getVersionMajor() == kMusaPH1VersionMajor; +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_MMA_ENCODING_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/MMAOperandUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/MMAOperandUtils.h new file mode 100644 index 0000000000..ff026bba79 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/MMAOperandUtils.h @@ -0,0 +1,529 @@ +#ifndef TRITONMUSA_COMMON_MMA_OPERAND_UTILS_H +#define TRITONMUSA_COMMON_MMA_OPERAND_UTILS_H + +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" + +#include + +namespace mlir::triton::musa { +namespace ttg = mlir::triton::gpu; + +inline SmallVector +getMemDescPhysicalShape(ttg::MemDescType memDescTy) { + auto allocShape = memDescTy.getAllocShape(); + if (allocShape.empty()) + return SmallVector(memDescTy.getShape().begin(), + memDescTy.getShape().end()); + if (allocShape.size() >= static_cast(memDescTy.getRank())) + allocShape = allocShape.take_back(memDescTy.getRank()); + return SmallVector(allocShape.begin(), allocShape.end()); +} + +inline SmallVector getSharedOrder(Attribute encoding, + ArrayRef shape = {}) { + if (auto swizzled = + dyn_cast_or_null(encoding)) + return SmallVector(swizzled.getOrder().begin(), + swizzled.getOrder().end()); + if (auto padded = dyn_cast_or_null(encoding)) + return SmallVector(padded.getOrder().begin(), + padded.getOrder().end()); + if (auto shared = dyn_cast_or_null(encoding)) { + if (!shape.empty()) + return ttg::getOrder(shared, shape); + } + return {}; +} + +inline bool isSharedEncoding(Attribute encoding) { + return isa(encoding); +} + +inline bool inferSharedRowMajor(ttg::MemDescType memDescTy, + ArrayRef physicalShape = {}) { + SmallVector shape; + if (physicalShape.empty()) + shape = getMemDescPhysicalShape(memDescTy); + else + shape.assign(physicalShape.begin(), physicalShape.end()); + auto order = getSharedOrder(memDescTy.getEncoding(), shape); + return !order.empty() && order.front() + 1 == shape.size(); +} + +inline bool isMemDescViewLikeOp(Operation *op) { + return isa(op); +} + +inline bool isMemDescSqmmaContractBridgeOp(Operation *op) { + return isMemDescViewLikeOp(op) || isa(op); +} + +inline bool isTMEBackedMemDesc(Value memDesc) { + llvm::SmallVector worklist; + llvm::SmallPtrSet visited; + auto enqueue = [&](Value candidate) { + if (!candidate || !isa(candidate.getType())) + return; + if (visited.insert(candidate.getAsOpaquePointer()).second) + worklist.push_back(candidate); + }; + + enqueue(memDesc); + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + + if (Operation *defOp = current.getDefiningOp()) { + if (auto indexOp = dyn_cast(defOp)) + enqueue(indexOp.getSrc()); + else if (auto subsliceOp = dyn_cast(defOp)) + enqueue(subsliceOp.getSrc()); + else if (auto reinterpretOp = dyn_cast(defOp)) + enqueue(reinterpretOp.getSrc()); + else if (auto transOp = dyn_cast(defOp)) + enqueue(transOp.getSrc()); + else if (auto reshapeOp = dyn_cast(defOp)) + enqueue(reshapeOp.getSrc()); + } + + for (Operation *user : current.getUsers()) { + if (isa(user)) + return true; + if (!isMemDescViewLikeOp(user)) + continue; + for (Value result : user->getResults()) + enqueue(result); + } + } + + return false; +} + +inline bool needsSqmmaIssueBarrier(Value aMemDesc, Value bMemDesc) { + return !(isTMEBackedMemDesc(aMemDesc) && isTMEBackedMemDesc(bMemDesc)); +} + +inline std::optional +composeMusaOperandSharedLayout(ttg::DotOperandEncodingAttr dotEncoding, + ArrayRef operandShape, + ArrayRef sharedOrder, + ttg::CGAEncodingAttr cgaLayout, + unsigned elemBitWidth, bool needTrans) { + if (!dotEncoding || elemBitWidth == 0) + return std::nullopt; + + Attribute parent = dotEncoding.getParent(); + if (auto wmma = dyn_cast(parent)) { + return wmma.composeSharedLayoutForOperand( + cgaLayout, dotEncoding.getOpIdx(), operandShape, sharedOrder, + dotEncoding.getKWidth(), elemBitWidth, needTrans); + } + if (auto sqmma = dyn_cast(parent)) { + return sqmma.composeSharedLayoutForOperand( + cgaLayout, dotEncoding.getOpIdx(), operandShape, sharedOrder, + dotEncoding.getKWidth(), elemBitWidth, needTrans); + } + return std::nullopt; +} + +inline std::optional +composeMusaOperandSharedLayout(ttg::DotOperandEncodingAttr dotEncoding, + ArrayRef operandShape, + ArrayRef sharedOrder, + ttg::CGAEncodingAttr cgaLayout, Type elementType, + bool needTrans) { + unsigned elemBitWidth = elementType.getIntOrFloatBitWidth(); + return composeMusaOperandSharedLayout(dotEncoding, operandShape, sharedOrder, + cgaLayout, elemBitWidth, needTrans); +} + +inline std::optional inferElemBytesFromMemDesc(ttg::MemDescType type) { + int bitWidth = type.getElementTypeBitWidth(); + if (bitWidth <= 0) + return std::nullopt; + return static_cast((bitWidth + 7) / 8); +} + +struct RecoveredSqmmaConsumerContract { + int64_t sqmmaOpIdx = -1; + int64_t elemBytes = 0; + bool rowMajor = true; + + bool operator==(const RecoveredSqmmaConsumerContract &other) const { + return sqmmaOpIdx == other.sqmmaOpIdx && elemBytes == other.elemBytes && + rowMajor == other.rowMajor; + } +}; + +inline FailureOr> +getSqmmaContractFromAnnotatedOp(Operation *op, bool defaultRowMajor) { + if (!op || !hasSqmmaOpIdxAttr(op)) + return std::optional{}; + + auto opIdx = getSqmmaOpIdx(op); + auto elemBytes = getSqmmaElemBytes(op); + if (!opIdx || !elemBytes || *elemBytes <= 0) + return failure(); + + return std::optional( + RecoveredSqmmaConsumerContract{*opIdx, *elemBytes, + getSqmmaRowMajor(op, defaultRowMajor)}); +} + +inline FailureOr> +recoverSqmmaProducerContractFromMemDesc(Value memDesc) { + auto memDescTy = dyn_cast(memDesc.getType()); + if (!memDescTy) + return failure(); + + std::optional contract; + llvm::SmallVector worklist{memDesc}; + llvm::SmallPtrSet visited; + + auto mergeCandidate = + [&](std::optional candidate) + -> LogicalResult { + if (!candidate) + return success(); + if (contract && !(*contract == *candidate)) + return failure(); + contract = *candidate; + return success(); + }; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current.getAsOpaquePointer()).second) + continue; + auto currentTy = dyn_cast(current.getType()); + if (!currentTy) + return failure(); + + Operation *defOp = current.getDefiningOp(); + auto currentContract = + getSqmmaContractFromAnnotatedOp(defOp, inferSharedRowMajor(currentTy)); + if (failed(currentContract) || failed(mergeCandidate(*currentContract))) + return failure(); + if (!defOp) + continue; + + auto enqueueOperand = [&](Value operand) { + if (isa(operand.getType())) + worklist.push_back(operand); + }; + + if (isMemDescSqmmaContractBridgeOp(defOp)) { + for (Value operand : defOp->getOperands()) + enqueueOperand(operand); + continue; + } + + if (auto waitOp = dyn_cast(defOp)) { + for (auto [idx, result] : llvm::enumerate(waitOp->getResults())) { + if (result != current) + continue; + enqueueOperand(waitOp.getInputs()[idx]); + } + continue; + } + if (auto waitOp = dyn_cast(defOp)) { + for (auto [idx, result] : llvm::enumerate(waitOp->getResults())) { + if (result != current) + continue; + enqueueOperand(waitOp.getInputs()[idx]); + } + continue; + } + } + + return contract; +} + +inline FailureOr> +recoverUniqueSqmmaConsumerContract(Value memDesc); + +inline bool isTensorSqmmaContractBridgeOp(Operation *op) { + return isa(op); +} + +inline FailureOr> +recoverUniqueSqmmaConsumerContractFromTensor(Value tensor) { + if (!isa(tensor.getType())) + return failure(); + + std::optional contract; + bool sawNonSqmmaTerminal = false; + llvm::SmallVector worklist; + llvm::SmallPtrSet visited; + worklist.push_back(tensor); + + auto mergeCandidate = + [&](std::optional candidate) + -> LogicalResult { + if (!candidate) { + sawNonSqmmaTerminal = true; + return success(); + } + if (contract && !(*contract == candidate)) + return failure(); + contract = *candidate; + return success(); + }; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current.getAsOpaquePointer()).second) + continue; + if (!isa(current.getType())) + return failure(); + + for (Operation *user : current.getUsers()) { + if (isTensorSqmmaContractBridgeOp(user)) { + for (Value result : user->getResults()) + if (isa(result.getType())) + worklist.push_back(result); + continue; + } + + if (auto localAlloc = dyn_cast(user)) { + auto nestedContract = + recoverUniqueSqmmaConsumerContract(localAlloc.getResult()); + if (failed(nestedContract)) + return failure(); + if (failed(mergeCandidate(*nestedContract))) + return failure(); + continue; + } + + sawNonSqmmaTerminal = true; + } + } + + if (contract && sawNonSqmmaTerminal) + return failure(); + return contract; +} + +inline FailureOr> +recoverUniqueSqmmaConsumerContract(Value memDesc) { + if (!isa(memDesc.getType())) + return failure(); + + std::optional contract; + bool sawNonSqmmaTerminal = false; + llvm::SmallVector worklist; + llvm::SmallPtrSet visited; + worklist.push_back(memDesc); + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current.getAsOpaquePointer()).second) + continue; + auto currentTy = dyn_cast(current.getType()); + if (!currentTy) + return failure(); + auto currentContract = recoverSqmmaProducerContractFromMemDesc(current); + if (failed(currentContract)) + return failure(); + + for (Operation *user : current.getUsers()) { + if (isa(user)) + continue; + + if (isMemDescSqmmaContractBridgeOp(user)) { + for (Value result : user->getResults()) + if (isa(result.getType())) + worklist.push_back(result); + continue; + } + + if (auto waitOp = dyn_cast(user)) { + for (auto [idx, operand] : llvm::enumerate(waitOp.getInputs())) { + if (operand != current) + continue; + Value passthrough = waitOp.getResult(idx); + if (isa(passthrough.getType())) + worklist.push_back(passthrough); + } + continue; + } + if (auto waitOp = dyn_cast(user)) { + for (auto [idx, operand] : llvm::enumerate(waitOp.getInputs())) { + if (operand != current) + continue; + Value passthrough = waitOp.getResult(idx); + if (isa(passthrough.getType())) + worklist.push_back(passthrough); + } + continue; + } + + if (auto localLoad = dyn_cast(user)) { + auto tensorContract = + recoverUniqueSqmmaConsumerContractFromTensor(localLoad.getResult()); + if (failed(tensorContract)) + return failure(); + if (*tensorContract) { + if (contract && !(*contract == *tensorContract)) + return failure(); + contract = **tensorContract; + } else { + sawNonSqmmaTerminal = true; + } + continue; + } + + auto userContract = + getSqmmaContractFromAnnotatedOp(user, inferSharedRowMajor(currentTy)); + if (failed(userContract)) + return failure(); + + std::optional candidate = *userContract; + if ((isa(user)) && + !candidate) + candidate = *currentContract; + + if (!candidate) { + if (isa(user)) + return failure(); + sawNonSqmmaTerminal = true; + continue; + } + + if (contract && !(*contract == candidate)) + return failure(); + contract = *candidate; + } + } + + if (contract && sawNonSqmmaTerminal) + return failure(); + return contract; +} + +inline LogicalResult verifyGroupedTMELoadConsumerContract( + AsyncTMECopyGlobalToLocalOp op, + std::optional contract) { + if (!contract) + return success(); + + auto memDescTy = dyn_cast(op.getResult().getType()); + if (!memDescTy || memDescTy.getShape().size() != 2) { + return op.emitOpError("grouped SQMMA TME load requires a 2D shared " + "memdesc"); + } + + auto order = getSharedOrder(memDescTy.getEncoding(), memDescTy.getShape()); + if (order.empty()) { + return op.emitOpError("grouped SQMMA TME load requires a valid shared " + "order"); + } + + auto maybeElemBytes = inferElemBytesFromMemDesc(memDescTy); + if (!maybeElemBytes || *maybeElemBytes <= 0 || + *maybeElemBytes != contract->elemBytes) { + return op.emitOpError("grouped SQMMA TME load recovered element size does " + "not match destination memdesc"); + } + + int64_t maxLeadingElems = 256 / *maybeElemBytes; + if (maxLeadingElems <= 0) { + return op.emitOpError("grouped SQMMA TME load has invalid leading segment " + "size"); + } + + return success(); +} + +inline FailureOr> +recoverAndVerifyGroupedTMELoadConsumerContract(AsyncTMECopyGlobalToLocalOp op) { + auto producerContract = + recoverSqmmaProducerContractFromMemDesc(op.getResult()); + if (failed(producerContract)) { + op.emitOpError("grouped TME load segment requires a valid SQMMA " + "producer contract on the destination memdesc"); + return failure(); + } + + auto consumerContract = recoverUniqueSqmmaConsumerContract(op.getResult()); + if (failed(consumerContract)) { + op.emitOpError("grouped TME load segment requires a unique consistent " + "SQMMA consumer contract when high-level consumers are " + "still present"); + return failure(); + } + + std::optional contract = *producerContract; + if (*consumerContract) { + if (contract && !(*contract == **consumerContract)) { + op.emitOpError("grouped TME load producer contract does not match the " + "unique recovered SQMMA consumer contract"); + return failure(); + } + contract = *consumerContract; + } + + if (failed(verifyGroupedTMELoadConsumerContract(op, contract))) + return failure(); + return contract; +} + +struct ResolvedSharedOperand { + Value memDesc; + Value llvmMemDesc; + ttg::MemDescType memDescTy; + LLVM::SharedMemoryObject sharedMemObj; + Value affineBase; + SmallVector physicalShape; +}; + +inline FailureOr +resolveSharedOperandWithAffineBase(Value operand, Value adaptorOperand, + Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + Value memDesc = operand; + Value llvmMemDesc = adaptorOperand; + auto memDescTy = dyn_cast(operand.getType()); + + if (!memDescTy) { + Value source = operand; + while (auto cvt = source.getDefiningOp()) + source = cvt.getSrc(); + + auto localLoad = source.getDefiningOp(); + if (!localLoad) + return failure(); + + memDesc = localLoad.getSrc(); + memDescTy = dyn_cast(memDesc.getType()); + if (!memDescTy) + return failure(); + llvmMemDesc = rewriter.getRemappedValue(memDesc); + if (!llvmMemDesc) + return failure(); + } + + Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto sharedMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, llvmMemDesc, llvmElemTy, rewriter); + Value affineBase = sharedMemObj.getShmemAffineBase(loc, rewriter, memDescTy); + return ResolvedSharedOperand{memDesc, llvmMemDesc, + memDescTy, sharedMemObj, + affineBase, getMemDescPhysicalShape(memDescTy)}; +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_MMA_OPERAND_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/MatmulPolicy.h b/third_party/mthreads/musa/include/TritonMUSACommon/MatmulPolicy.h new file mode 100644 index 0000000000..882d7fce33 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/MatmulPolicy.h @@ -0,0 +1,4 @@ +#ifndef TRITONMUSA_COMMON_MATMUL_POLICY_H +#define TRITONMUSA_COMMON_MATMUL_POLICY_H + +#endif // TRITONMUSA_COMMON_MATMUL_POLICY_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/MemDescUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/MemDescUtils.h new file mode 100644 index 0000000000..0f5e16028b --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/MemDescUtils.h @@ -0,0 +1,479 @@ +#ifndef TRITONMUSA_COMMON_MEMDESC_UTILS_H +#define TRITONMUSA_COMMON_MEMDESC_UTILS_H + +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringRef.h" + +#include + +namespace mlir::triton::musa { +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +inline bool areMemDescTypesCompatible(ttg::MemDescType lhs, + ttg::MemDescType rhs) { + return lhs.getShape() == rhs.getShape() && + lhs.getElementType() == rhs.getElementType() && + lhs.getEncoding() == rhs.getEncoding() && + lhs.getMemorySpace() == rhs.getMemorySpace() && + lhs.getAllocShape() == rhs.getAllocShape(); +} + +inline bool +areMemDescTypesEquivalentForDescriptorContract(ttg::MemDescType lhs, + ttg::MemDescType rhs) { + return lhs.getShape() == rhs.getShape() && + lhs.getElementType() == rhs.getElementType() && + lhs.getEncoding() == rhs.getEncoding() && + lhs.getMemorySpace() == rhs.getMemorySpace(); +} + +inline bool areMemDescTypesLayoutEquivalent(ttg::MemDescType lhs, + ttg::MemDescType rhs) { + return lhs.getShape() == rhs.getShape() && + lhs.getElementType() == rhs.getElementType() && + lhs.getMemorySpace() == rhs.getMemorySpace() && + lhs.getAllocShape() == rhs.getAllocShape() && + getMUSASharedLinearLayoutOrGeneric(lhs) == + getMUSASharedLinearLayoutOrGeneric(rhs); +} + +inline bool isCanonicalLandingMemDescViewOp(Operation *op) { + return isa(op); +} + +inline SmallVector +collectCanonicalLandingLocalAllocUsers(tt::DescriptorLoadOp loadOp) { + SmallVector localAllocs; + llvm::SmallPtrSet seen; + for (Operation *user : loadOp.getResult().getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc || localAlloc.getSrc() != loadOp.getResult()) + continue; + if (seen.insert(localAlloc.getOperation()).second) + localAllocs.push_back(localAlloc); + } + return localAllocs; +} + +inline SmallVector +collectCanonicalLandingMemDescRoots(tt::DescriptorLoadOp loadOp) { + SmallVector roots; + llvm::SmallPtrSet seen; + for (Operation *user : loadOp.getResult().getUsers()) { + if (auto localAlloc = dyn_cast(user)) { + if (localAlloc.getSrc() != loadOp.getResult()) + continue; + if (seen.insert(localAlloc.getResult().getAsOpaquePointer()).second) + roots.push_back(localAlloc.getResult()); + continue; + } + auto localStore = dyn_cast(user); + if (!localStore || localStore.getSrc() != loadOp.getResult()) + continue; + if (seen.insert(localStore.getDst().getAsOpaquePointer()).second) + roots.push_back(localStore.getDst()); + } + return roots; +} + +inline SmallVector +collectCanonicalLandingTerminalMemDescValues(tt::DescriptorLoadOp loadOp) { + SmallVector terminals; + SmallVector worklist = collectCanonicalLandingMemDescRoots(loadOp); + llvm::SmallPtrSet visited; + llvm::SmallPtrSet seenTerminals; + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current.getAsOpaquePointer()).second) + continue; + if (!isa(current.getType())) + continue; + + bool hasViewUser = false; + bool hasNonViewUser = false; + for (Operation *user : current.getUsers()) { + if (isCanonicalLandingMemDescViewOp(user)) { + hasViewUser = true; + for (Value result : user->getResults()) + worklist.push_back(result); + } else { + hasNonViewUser = true; + } + } + + if ((!hasViewUser || hasNonViewUser) && + seenTerminals.insert(current.getAsOpaquePointer()).second) { + terminals.push_back(current); + } + } + return terminals; +} + +inline std::optional +getUniqueCanonicalLandingRootMemDescType(tt::DescriptorLoadOp loadOp) { + std::optional uniqueTy; + for (Value root : collectCanonicalLandingMemDescRoots(loadOp)) { + auto rootTy = dyn_cast(root.getType()); + if (!rootTy) + continue; + if (!uniqueTy) { + uniqueTy = rootTy; + continue; + } + if (!areMemDescTypesCompatible(*uniqueTy, rootTy)) + return std::nullopt; + } + return uniqueTy; +} + +inline std::optional +getUniqueCanonicalLandingTerminalMemDescType(tt::DescriptorLoadOp loadOp) { + std::optional uniqueTy; + for (Value terminal : collectCanonicalLandingTerminalMemDescValues(loadOp)) { + auto terminalTy = dyn_cast(terminal.getType()); + if (!terminalTy) + continue; + if (!uniqueTy) { + uniqueTy = terminalTy; + continue; + } + if (!areMemDescTypesCompatible(*uniqueTy, terminalTy)) + return std::nullopt; + } + return uniqueTy; +} + +inline std::optional +getUniqueCanonicalLandingMemDescType(tt::DescriptorLoadOp loadOp) { + return getUniqueCanonicalLandingTerminalMemDescType(loadOp); +} + +inline bool hasCanonicalSharedLanding(tt::DescriptorLoadOp loadOp) { + auto landingTy = getUniqueCanonicalLandingRootMemDescType(loadOp); + return landingTy && + isa_and_nonnull(landingTy->getEncoding()); +} + +inline SmallVector collectConnectedCanonicalMemDescChain(Value seed) { + SmallVector chain; + SmallVector worklist{seed}; + llvm::SmallPtrSet visited; + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current.getAsOpaquePointer()).second) + continue; + if (!isa(current.getType())) + continue; + chain.push_back(current); + if (Operation *defOp = current.getDefiningOp()) { + if (isCanonicalLandingMemDescViewOp(defOp)) { + if (auto srcOperand = + dyn_cast>(defOp->getOperand(0))) { + worklist.push_back(srcOperand); + } + } + } + for (Operation *user : current.getUsers()) { + if (!isCanonicalLandingMemDescViewOp(user)) + continue; + for (Value result : user->getResults()) + worklist.push_back(result); + } + } + return chain; +} + +inline std::optional +getUniqueCanonicalLandingSqmmaMemDescType(tt::DescriptorLoadOp loadOp) { + std::optional uniqueTy; + auto expectedShape = loadOp.getType().getShape(); + Type expectedElemTy = loadOp.getType().getElementType(); + + for (Value root : collectCanonicalLandingMemDescRoots(loadOp)) { + for (Value memDescValue : collectConnectedCanonicalMemDescChain(root)) { + auto memDescTy = dyn_cast(memDescValue.getType()); + Operation *defOp = memDescValue.getDefiningOp(); + if (!memDescTy || !defOp || !hasSqmmaOpIdxAttr(defOp)) + continue; + if (memDescTy.getShape() != expectedShape || + memDescTy.getElementType() != expectedElemTy) + continue; + + if (!uniqueTy) { + uniqueTy = memDescTy; + continue; + } + if (!areMemDescTypesCompatible(*uniqueTy, memDescTy)) + return std::nullopt; + } + } + return uniqueTy; +} + +inline Operation * +getUniqueCanonicalLandingSqmmaAttrSource(tt::DescriptorLoadOp loadOp) { + Operation *source = nullptr; + auto mergeSource = [&](Operation *candidate) -> bool { + if (!hasSqmmaOpIdxAttr(candidate)) + return true; + if (!source) { + source = candidate; + return true; + } + for (auto name : kSqmmaAttrNames) { + if (source->getAttr(name) != candidate->getAttr(name)) + return false; + } + return true; + }; + + for (Value root : collectCanonicalLandingMemDescRoots(loadOp)) { + if (auto rootOp = root.getDefiningOp()) { + if (!mergeSource(rootOp)) + return nullptr; + } + for (Value memDescValue : collectConnectedCanonicalMemDescChain(root)) { + for (Operation *user : memDescValue.getUsers()) { + if (!mergeSource(user)) + return nullptr; + } + } + } + return source; +} + +inline bool copyCanonicalLandingSqmmaAttrs(tt::DescriptorLoadOp loadOp, + Operation *dst) { + Operation *source = getUniqueCanonicalLandingSqmmaAttrSource(loadOp); + if (!source && hasSqmmaOpIdxAttr(loadOp.getOperation())) + source = loadOp.getOperation(); + if (!source) + return false; + copySqmmaAttrs(source, dst); + return true; +} + +inline unsigned getDescriptorContractNumCTAs(Operation *op) { + if (auto func = op ? op->getParentOfType() : tt::FuncOp()) + return std::max(1u, static_cast(ttg::lookupNumCTAs(func))); + return 1; +} + +inline FailureOr +buildDescriptorLandingMemDescType(Operation *op, tt::TensorDescType descTy, + RankedTensorType tensorTy, + bool mutableMemory) { + auto descBlockTy = descTy.getSignlessBlockType(); + auto swizzled = dyn_cast_or_null( + descBlockTy.getEncoding()); + if (!swizzled) { + if (op) + op->emitError("expected descriptor block type to be normalized to " + "canonical swizzled shared encoding"); + return failure(); + } + auto normalizedEncoding = + tryMapTMECompatibleSharedEncodingToCanonicalSwizzled( + op, tensorTy, swizzled, tensorTy.getShape(), + getDescriptorContractNumCTAs(op)); + if (!normalizedEncoding) { + if (op) + op->emitError("unable to project descriptor shared encoding onto " + "descriptor load/store tensor rank"); + return failure(); + } + return ttg::MemDescType::get( + tensorTy.getShape(), tensorTy.getElementType(), *normalizedEncoding, + ttg::SharedMemorySpaceAttr::get(op->getContext()), mutableMemory); +} + +inline std::optional tryNormalizeCanonicalLandingMemDescType( + Operation *op, ttg::MemDescType memDescTy, bool mutableMemory) { + auto tensorTy = + RankedTensorType::get(memDescTy.getShape(), memDescTy.getElementType()); + auto normalizedEncoding = + tryMapTMECompatibleSharedEncodingToCanonicalSwizzled( + op, tensorTy, memDescTy.getEncoding(), memDescTy.getShape(), + getDescriptorContractNumCTAs(op)); + if (!normalizedEncoding) + return std::nullopt; + return ttg::MemDescType::get(memDescTy.getShape(), memDescTy.getElementType(), + *normalizedEncoding, memDescTy.getMemorySpace(), + mutableMemory, memDescTy.getAllocShape()); +} + +inline FailureOr +resolveDescriptorLoadLandingMemDescType(tt::DescriptorLoadOp loadOp) { + auto authoritative = buildDescriptorLandingMemDescType( + loadOp.getOperation(), loadOp.getDesc().getType(), loadOp.getType(), + /*mutableMemory=*/true); + if (failed(authoritative)) + return failure(); + auto canonicalTy = getUniqueCanonicalLandingMemDescType(loadOp); + if (!canonicalTy) + return authoritative; + auto normalizedCanonical = tryNormalizeCanonicalLandingMemDescType( + loadOp.getOperation(), *canonicalTy, + /*mutableMemory=*/true); + if (!normalizedCanonical) + return authoritative; + if (!areMemDescTypesEquivalentForDescriptorContract(*normalizedCanonical, + *authoritative)) + return authoritative; + return *normalizedCanonical; +} + +inline FailureOr resolveDescriptorStoreLandingMemDescType( + tt::DescriptorStoreLikeOpInterface storeOp, bool mutableMemory) { + return buildDescriptorLandingMemDescType( + storeOp.getOperation(), storeOp.getDesc().getType(), + storeOp.getSrc().getType(), mutableMemory); +} + +inline Value adaptMemDescValue(RewriterBase &rewriter, Location loc, + Value value, ttg::MemDescType targetTy) { + auto srcTy = dyn_cast(value.getType()); + if (!srcTy) + return {}; + if (srcTy == targetTy) + return value; + if (!areMemDescTypesCompatible(srcTy, targetTy) && + !areMemDescTypesLayoutEquivalent(srcTy, targetTy)) + return {}; + return ttg::MemDescReinterpretOp::create(rewriter, loc, targetTy, value); +} + +inline Value findReusableLocalAllocForSource(Value source, + ttg::MemDescType targetTy) { + for (Operation *user : source.getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc || localAlloc.getSrc() != source) + continue; + auto allocTy = dyn_cast(localAlloc.getResult().getType()); + if (allocTy != targetTy) + continue; + return localAlloc.getResult(); + } + return {}; +} + +inline Value materializeTransformedMemDescForTarget(RewriterBase &rewriter, + tt::TransOp transOp, + Value sourceMemDesc, + ttg::MemDescType targetTy) { + SmallVector transposeOrder(transOp.getOrder().begin(), + transOp.getOrder().end()); + Value transformed = ttg::MemDescTransOp::create( + rewriter, transOp.getLoc(), sourceMemDesc, transposeOrder); + if (transformed.getType() == targetTy) + return transformed; + Value adapted = + adaptMemDescValue(rewriter, transOp.getLoc(), transformed, targetTy); + if (adapted) + return adapted; + transformed.getDefiningOp()->erase(); + return {}; +} + +inline Value materializeReshapedMemDescForTarget(RewriterBase &rewriter, + tt::ReshapeOp reshapeOp, + Value sourceMemDesc, + ttg::MemDescType targetTy) { + Value transformed = ttg::MemDescReshapeOp::create( + rewriter, reshapeOp.getLoc(), sourceMemDesc, targetTy.getShape()); + if (transformed.getType() == targetTy) + return transformed; + Value adapted = + adaptMemDescValue(rewriter, reshapeOp.getLoc(), transformed, targetTy); + if (adapted) + return adapted; + transformed.getDefiningOp()->erase(); + return {}; +} + +inline bool replaceTensorLocalAllocWithMemDesc(RewriterBase &rewriter, + Operation *user, + Value sourceMemDesc) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + return false; + auto targetTy = dyn_cast(localAlloc.getResult().getType()); + if (!targetTy) + return false; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + Value replacement = + adaptMemDescValue(rewriter, localAlloc.getLoc(), sourceMemDesc, targetTy); + if (!replacement) + return false; + rewriter.replaceOp(localAlloc, replacement); + return true; +} + +inline bool tryReplaceTensorUserWithMemDesc(RewriterBase &rewriter, + Value tensorValue, + Value sourceMemDesc, + Operation *user) { + if (replaceTensorLocalAllocWithMemDesc(rewriter, user, sourceMemDesc)) + return true; + + if (auto transOp = dyn_cast(user)) { + bool changed = false; + for (Operation *transUser : + llvm::make_early_inc_range(transOp->getUsers())) { + auto localAlloc = dyn_cast(transUser); + if (!localAlloc) + continue; + auto targetTy = dyn_cast(localAlloc.getType()); + if (!targetTy) + continue; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + Value replacement = materializeTransformedMemDescForTarget( + rewriter, transOp, sourceMemDesc, targetTy); + if (!replacement) + continue; + rewriter.replaceOp(localAlloc, replacement); + changed = true; + } + return changed; + } + + if (auto reshapeOp = dyn_cast(user)) { + bool changed = false; + for (Operation *reshapeUser : + llvm::make_early_inc_range(reshapeOp->getUsers())) { + auto localAlloc = dyn_cast(reshapeUser); + if (!localAlloc) + continue; + auto targetTy = dyn_cast(localAlloc.getType()); + if (!targetTy) + continue; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + Value replacement = materializeReshapedMemDescForTarget( + rewriter, reshapeOp, sourceMemDesc, targetTy); + if (!replacement) + continue; + rewriter.replaceOp(localAlloc, replacement); + changed = true; + } + return changed; + } + + (void)tensorValue; + return false; +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_MEMDESC_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/SqmmaAttrUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/SqmmaAttrUtils.h new file mode 100644 index 0000000000..11cb31e532 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/SqmmaAttrUtils.h @@ -0,0 +1,67 @@ +#ifndef TRITONMUSA_COMMON_SQMMA_ATTR_UTILS_H +#define TRITONMUSA_COMMON_SQMMA_ATTR_UTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir::triton::musa { + +inline constexpr llvm::StringLiteral kSqmmaOpIdxAttr = "sqmma.op_idx"; +inline constexpr llvm::StringLiteral kSqmmaElemBytesAttr = "sqmma.elem_bytes"; +inline constexpr llvm::StringLiteral kSqmmaRowMajorAttr = "sqmma.row_major"; + +inline constexpr std::array kSqmmaAttrNames = { + kSqmmaOpIdxAttr, kSqmmaElemBytesAttr, kSqmmaRowMajorAttr}; + +inline std::optional getIntAttr(Operation *op, StringRef name) { + if (auto intAttr = op->getAttrOfType(name)) + return intAttr.getInt(); + return std::nullopt; +} + +inline std::optional getSqmmaOpIdx(Operation *op) { + return getIntAttr(op, kSqmmaOpIdxAttr); +} + +inline std::optional getSqmmaElemBytes(Operation *op) { + return getIntAttr(op, kSqmmaElemBytesAttr); +} + +inline bool getSqmmaRowMajor(Operation *op, bool defaultValue) { + if (auto boolAttr = op->getAttrOfType(kSqmmaRowMajorAttr)) + return boolAttr.getValue(); + if (auto intAttr = op->getAttrOfType(kSqmmaRowMajorAttr)) + return intAttr.getInt() != 0; + return defaultValue; +} + +inline bool hasSqmmaOpIdxAttr(Operation *op) { + return getSqmmaOpIdx(op).has_value(); +} + +inline void copySqmmaAttrs(Operation *src, Operation *dst) { + for (auto name : kSqmmaAttrNames) { + if (Attribute attr = src->getAttr(name)) + dst->setAttr(name, attr); + } +} + +inline void setSqmmaAttrs(Operation *op, int64_t opIdx, int64_t elemBytes, + bool rowMajor) { + auto *ctx = op->getContext(); + auto i32Ty = IntegerType::get(ctx, 32); + auto opIdxAttr = IntegerAttr::get(i32Ty, opIdx); + auto elemBytesAttr = IntegerAttr::get(i32Ty, elemBytes); + op->setAttr(kSqmmaOpIdxAttr, opIdxAttr); + op->setAttr(kSqmmaElemBytesAttr, elemBytesAttr); + op->setAttr(kSqmmaRowMajorAttr, BoolAttr::get(ctx, rowMajor)); +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_SQMMA_ATTR_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSACommon/TMEUtils.h b/third_party/mthreads/musa/include/TritonMUSACommon/TMEUtils.h new file mode 100644 index 0000000000..96c7c66126 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSACommon/TMEUtils.h @@ -0,0 +1,779 @@ +#ifndef TRITONMUSA_COMMON_TME_UTILS_H +#define TRITONMUSA_COMMON_TME_UTILS_H + +#include "Dialect/MUSA/IR/Dialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include +#include +#include + +namespace mlir::triton::musa { + +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +enum class TMECopyKind { + GlobalToLocal, + LocalToGlobal, +}; + +struct ResolvedTMESwizzleConfig { + TMESwizzleGranularity swizzleGranularity = TMESwizzleGranularity::SG_NONE; + TMESwizzleStride swizzleStride = TMESwizzleStride::SS_256B; + TMESwizzleLine swizzleLine = TMESwizzleLine::SL_256B; +}; + +struct FinalTMECopyConfig { + DenseI32ArrayAttr blockShape; + TMESwizzleGranularity swizzleGranularity = TMESwizzleGranularity::SG_NONE; + TMESwizzleStride swizzleStride = TMESwizzleStride::SS_256B; + TMESwizzleLine swizzleLine = TMESwizzleLine::SL_256B; + TMEL2CachePolicy cachePolicy = TMEL2CachePolicy::NEW_ALLOC; + TMEPersistence innerPersistence = TMEPersistence::CACHE_NORMAL; + TMEPersistence outerPersistence = TMEPersistence::CACHE_NORMAL; + std::optional prefetchSize; +}; + +inline std::optional toI32(int64_t value) { + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) + return std::nullopt; + return static_cast(value); +} + +template +Value materializeI32Value(Value value, Location loc, BuilderT &builder) { + Type ty = value.getType(); + if (ty.isInteger(32)) + return value; + if (isa(ty)) + return arith::IndexCastOp::create(builder, loc, builder.getI32Type(), + value); + if (auto intTy = dyn_cast(ty)) { + if (intTy.getWidth() > 32) + return arith::TruncIOp::create(builder, loc, builder.getI32Type(), value); + if (intTy.getWidth() < 32) + return arith::ExtSIOp::create(builder, loc, builder.getI32Type(), value); + return value; + } + return {}; +} + +inline FailureOr +materializeTMEBlockShapeAttr(MLIRContext *ctx, ArrayRef dims) { + if (dims.empty() || dims.size() > 5) + return failure(); + + SmallVector i32Dims; + i32Dims.reserve(dims.size()); + for (int64_t dimVal : dims) { + if (dimVal <= 0) + return failure(); + auto dim = toI32(dimVal); + if (!dim) + return failure(); + i32Dims.push_back(*dim); + } + return DenseI32ArrayAttr::get(ctx, i32Dims); +} + +template +FailureOr> +materializeTMECoordValues(Location loc, ValueRange indices, BuilderT &builder) { + unsigned rank = indices.size(); + if (rank == 0 || rank > 5) + return failure(); + + SmallVector coords; + coords.reserve(rank); + for (Value val : indices) { + Value i32Val = materializeI32Value(val, loc, builder); + if (!i32Val) + return failure(); + coords.push_back(i32Val); + } + return coords; +} + +inline std::optional +inferElemBytesFromMemDescType(ttg::MemDescType ty) { + int bitWidth = ty.getElementTypeBitWidth(); + if (bitWidth <= 0) + return std::nullopt; + return static_cast((bitWidth + 7) / 8); +} + +inline std::optional inferRowMajorFromMemDescType(ttg::MemDescType ty) { + auto order = ttg::getOrder(ty); + if (order.empty()) + return std::nullopt; + return static_cast(order.front() + 1) == + static_cast(ty.getShape().size()); +} + +inline int64_t getSwizzleGranularityBytes(TMESwizzleGranularity value) { + switch (value) { + case TMESwizzleGranularity::SG_NONE: + return 0; + case TMESwizzleGranularity::SG_16B: + return 16; + case TMESwizzleGranularity::SG_32B: + return 32; + case TMESwizzleGranularity::SG_64B: + return 64; + case TMESwizzleGranularity::SG_128B: + return 128; + } + return 0; +} + +inline int64_t getSwizzleStrideBytes(TMESwizzleStride value) { + switch (value) { + case TMESwizzleStride::SS_32B: + return 32; + case TMESwizzleStride::SS_64B: + return 64; + case TMESwizzleStride::SS_128B: + return 128; + case TMESwizzleStride::SS_256B: + return 256; + } + return 0; +} + +inline int64_t getSwizzleLineBytes(TMESwizzleLine value) { + switch (value) { + case TMESwizzleLine::SL_128B: + return 128; + case TMESwizzleLine::SL_256B: + return 256; + } + return 0; +} + +inline bool isValidTMESwizzleConfig(TMESwizzleGranularity granularity, + TMESwizzleStride stride, + TMESwizzleLine line) { + int64_t granularityBytes = getSwizzleGranularityBytes(granularity); + int64_t strideBytes = getSwizzleStrideBytes(stride); + int64_t lineBytes = getSwizzleLineBytes(line); + return strideBytes >= granularityBytes && strideBytes <= lineBytes; +} + +inline SmallVector getDefaultTMEOrder(unsigned rank) { + SmallVector order; + order.reserve(rank); + for (int i = static_cast(rank) - 1; i >= 0; --i) + order.push_back(i); + return order; +} + +inline ttg::CGAEncodingAttr getDefaultTMECGALayout(MLIRContext *ctx, + unsigned rank, + unsigned numCTAs, + ArrayRef order) { + SmallVector ctasPerCGA(rank, 1); + if (!ctasPerCGA.empty()) + ctasPerCGA.back() = std::max(1u, numCTAs); + return ttg::CGAEncodingAttr::fromSplitParams(ctx, ctasPerCGA, ctasPerCGA, + order); +} + +inline ttg::CGAEncodingAttr canonicalizeTMECGALayoutForShape( + MLIRContext *ctx, ttg::CGAEncodingAttr cgaLayout, ArrayRef shape, + ArrayRef order, unsigned numCTAs) { + if (!cgaLayout) { + assert(ctx && "missing MLIR context for default TME CGA layout"); + return getDefaultTMECGALayout(ctx, shape.size(), numCTAs, order); + } + if (cgaLayout.getRank() != shape.size()) + return ttng::updateCGALayoutForShape(cgaLayout, shape); + return cgaLayout; +} + +inline ttg::SwizzledSharedEncodingAttr getDefaultTMECompatibleSharedEncoding( + RankedTensorType tensorTy, ttg::CGAEncodingAttr cgaLayout = {}, + ArrayRef usageShape = {}, unsigned numCTAs = 1) { + auto *ctx = tensorTy.getContext(); + SmallVector shape = + usageShape.empty() + ? SmallVector(tensorTy.getShape()) + : SmallVector(usageShape.begin(), usageShape.end()); + auto order = getDefaultTMEOrder(tensorTy.getRank()); + cgaLayout = + cgaLayout + ? canonicalizeTMECGALayoutForShape(ctx, cgaLayout, shape, order, + numCTAs) + : getDefaultTMECGALayout(ctx, tensorTy.getRank(), numCTAs, order); + return ttg::SwizzledSharedEncodingAttr::get(ctx, /*vec=*/1, /*perPhase=*/1, + /*maxPhase=*/1, order, cgaLayout); +} + +inline std::optional +tryMapTMECompatibleSharedEncodingToCanonicalSwizzled( + Operation *op, RankedTensorType tensorTy, Attribute encoding, + ArrayRef usageShape, unsigned numCTAs) { + auto *ctx = tensorTy.getContext(); + ttg::SwizzledSharedEncodingAttr candidate; + if (auto swizzled = + dyn_cast_or_null(encoding)) { + candidate = swizzled; + } else if (auto nvmma = + dyn_cast_or_null(encoding)) { + if (nvmma.getTransposed() || nvmma.getFp4Padded()) + return std::nullopt; + SmallVector order = ttg::getOrder(nvmma, usageShape); + candidate = ttg::SwizzledSharedEncodingAttr::get( + ctx, nvmma.getVec(), nvmma.getPerPhase(), nvmma.getMaxPhase(), order, + nvmma.getCGALayout()); + } else { + return std::nullopt; + } + + auto updated = + cast(ttng::updateEncodingForShape( + op, cast(candidate), tensorTy)); + auto shape = usageShape.empty() ? tensorTy.getShape() : usageShape; + auto cgaLayout = + updated.getCGALayout() + ? canonicalizeTMECGALayoutForShape(ctx, updated.getCGALayout(), shape, + updated.getOrder(), numCTAs) + : getDefaultTMECGALayout(ctx, tensorTy.getRank(), numCTAs, + updated.getOrder()); + return ttg::SwizzledSharedEncodingAttr::get( + ctx, updated.getVec(), updated.getPerPhase(), updated.getMaxPhase(), + updated.getOrder(), cgaLayout); +} + +inline int64_t applyPH1TMESwizzleToByteAddress(int64_t addrBytes, + TMESwizzleGranularity sg, + TMESwizzleStride ss, + TMESwizzleLine sl) { + if (sg == TMESwizzleGranularity::SG_NONE) + return addrBytes; + int64_t sgBytes = getSwizzleGranularityBytes(sg); + int64_t ssBytes = getSwizzleStrideBytes(ss); + int64_t slBytes = getSwizzleLineBytes(sl); + assert(sgBytes > 0 && ssBytes > 0 && slBytes > 0); + int64_t lineOffset = addrBytes % slBytes; + int64_t lineId = addrBytes / slBytes; + int64_t swizzleGroup = ssBytes / sgBytes; + int64_t swizzleLineId = lineId % swizzleGroup; + int64_t sectorInLine = lineOffset / sgBytes; + int64_t offsetInSector = lineOffset % sgBytes; + int64_t targetSectorInLine = sectorInLine ^ swizzleLineId; + return lineId * slBytes + targetSectorInLine * sgBytes + offsetInSector; +} + +inline SmallVector +decodeLinearOffsetToCoords(int64_t linearOffset, ArrayRef shape, + ArrayRef order) { + SmallVector coords(shape.size(), 0); + for (unsigned dim : order) { + int64_t dimSize = shape[dim]; + if (dimSize <= 0) + return {}; + coords[dim] = static_cast(linearOffset % dimSize); + linearOffset /= dimSize; + } + return coords; +} + +struct PH1TMELeadingDimGrouping { + int64_t numGroups = 1; + int64_t elemsPerGroupInLeadingDim = 0; + int64_t elemsPerGroup = 0; +}; + +inline FailureOr +getPH1TMELeadingDimGrouping(ArrayRef shape, ArrayRef order, + unsigned elemBytes) { + if (shape.size() != 2 || order.size() < 2) + return PH1TMELeadingDimGrouping{}; + + int64_t leadingDim = shape[order[0]]; + int64_t leadingWidthBytes = leadingDim * static_cast(elemBytes); + if (leadingDim <= 0 || leadingWidthBytes <= 0) + return failure(); + + if (leadingWidthBytes <= 256) { + return PH1TMELeadingDimGrouping{/*numGroups=*/1, + /*elemsPerGroupInLeadingDim=*/leadingDim, + /*elemsPerGroup=*/shape[order[1]] * + leadingDim}; + } + + int64_t maxColsPerGroup = 256 / static_cast(elemBytes); + if (maxColsPerGroup <= 0) + return failure(); + int64_t numGroups = (leadingDim + maxColsPerGroup - 1) / maxColsPerGroup; + if (numGroups <= 0 || (leadingDim % numGroups) != 0) + return failure(); + + int64_t elemsPerGroupInLeadingDim = leadingDim / numGroups; + int64_t totalElems = 1; + for (int64_t dim : shape) { + if (dim <= 0) + return failure(); + totalElems *= dim; + } + if ((totalElems % numGroups) != 0) + return failure(); + + return PH1TMELeadingDimGrouping{numGroups, elemsPerGroupInLeadingDim, + totalElems / numGroups}; +} + +inline SmallVector +decodePH1TMELinearOffsetToCoords(int64_t linearOffset, ArrayRef shape, + ArrayRef order, unsigned elemBytes) { + auto grouping = getPH1TMELeadingDimGrouping(shape, order, elemBytes); + if (failed(grouping)) + return {}; + if (grouping->numGroups == 1) + return decodeLinearOffsetToCoords(linearOffset, shape, order); + + SmallVector coords(shape.size(), 0); + int64_t groupId = linearOffset / grouping->elemsPerGroup; + int64_t offsetInGroup = linearOffset % grouping->elemsPerGroup; + int64_t row = offsetInGroup / grouping->elemsPerGroupInLeadingDim; + int64_t colInGroup = offsetInGroup % grouping->elemsPerGroupInLeadingDim; + int64_t col = groupId * grouping->elemsPerGroupInLeadingDim + colInGroup; + + if (groupId < 0 || row < 0 || col < 0 || row >= shape[order[1]] || + col >= shape[order[0]]) + return {}; + + coords[order[0]] = static_cast(col); + coords[order[1]] = static_cast(row); + return coords; +} + +inline FailureOr linearizePH1TMELinearCoords(ArrayRef coords, + ArrayRef shape, + ArrayRef order, + unsigned elemBytes) { + auto grouping = getPH1TMELeadingDimGrouping(shape, order, elemBytes); + if (failed(grouping)) + return failure(); + if (coords.size() != shape.size() || order.size() < 2) + return failure(); + + int64_t leading = coords[order[0]]; + int64_t row = coords[order[1]]; + if (leading < 0 || row < 0 || leading >= shape[order[0]] || + row >= shape[order[1]]) + return failure(); + + if (grouping->numGroups == 1) + return row * shape[order[0]] + leading; + + int64_t groupId = leading / grouping->elemsPerGroupInLeadingDim; + int64_t colInGroup = leading % grouping->elemsPerGroupInLeadingDim; + return groupId * grouping->elemsPerGroup + + row * grouping->elemsPerGroupInLeadingDim + colInGroup; +} + +inline FailureOr linearizePH1TMELinearCoords(TritonLLVMOpBuilder &b, + ArrayRef coords, + ArrayRef shape, + ArrayRef order, + unsigned elemBytes) { + auto grouping = getPH1TMELeadingDimGrouping(shape, order, elemBytes); + if (failed(grouping)) + return failure(); + if (coords.size() != shape.size() || order.size() < 2) + return failure(); + + Value leading = coords[order[0]]; + Value row = coords[order[1]]; + if (grouping->numGroups == 1) { + Value linearOffset = b.add( + b.mul(row, b.i32_val(static_cast(shape[order[0]]))), leading); + return linearOffset; + } + + Value elemsPerGroupInLeadingDim = + b.i32_val(static_cast(grouping->elemsPerGroupInLeadingDim)); + Value elemsPerGroup = + b.i32_val(static_cast(grouping->elemsPerGroup)); + Value groupId = b.udiv(leading, elemsPerGroupInLeadingDim); + Value colInGroup = b.urem(leading, elemsPerGroupInLeadingDim); + Value linearOffset = b.add(b.add(b.mul(groupId, elemsPerGroup), + b.mul(row, elemsPerGroupInLeadingDim)), + colInGroup); + return linearOffset; +} + +inline LinearLayout +combineTMECtaLayoutWithCGA(LinearLayout ctaLayout, + ttg::CGAEncodingAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + auto outDimNames = triton::standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) + labeledShape[dim] = size; + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(cgaLayoutAttr.getLinearLayout(), labeledShape) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + llvm::SmallDenseMap ctaShape; + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout result = + (ctaLayout * cgaLayout) + .transposeOuts(triton::standardOutDimNames(ctx, rank)); + return result; +} + +inline LinearLayout buildPH1TMESharedLinearLayout( + ArrayRef shape, ArrayRef allocShape, unsigned elemBytes, + ArrayRef order, ttg::CGAEncodingAttr cgaLayout, + TMESwizzleGranularity sg, TMESwizzleStride ss, TMESwizzleLine sl) { + assert(!shape.empty() && "PH1 TME shared layout requires non-zero rank"); + assert(cgaLayout && "PH1 TME shared layout requires canonical CGA layout"); + auto physicalShape = + allocShape.empty() + ? SmallVector(shape.begin(), shape.end()) + : SmallVector(allocShape.take_back(shape.size()).begin(), + allocShape.take_back(shape.size()).end()); + MLIRContext *ctx = cgaLayout.getContext(); + auto shapePerCTA = + ttg::getShapePerCTA(cgaLayout.getCTASplitNum(), physicalShape); + + int64_t totalElems = 1; + for (int64_t dim : shapePerCTA) + totalElems *= dim; + + auto outDimNames = triton::standardOutDimNames(ctx, shape.size()); + SmallVector> offsetBases; + for (int64_t elemOffset = 1; elemOffset < totalElems; elemOffset <<= 1) { + int64_t logicalByteOffset = + applyPH1TMESwizzleToByteAddress(elemOffset * elemBytes, sg, ss, sl); + assert((logicalByteOffset % elemBytes) == 0 && + "PH1 TME swizzle must preserve element alignment"); + auto coords = decodePH1TMELinearOffsetToCoords( + logicalByteOffset / elemBytes, shapePerCTA, order, elemBytes); + offsetBases.emplace_back(coords.begin(), coords.end()); + } + + SmallVector> outDims; + outDims.reserve(shape.size()); + for (auto [dimName, dimSize] : llvm::zip(outDimNames, shapePerCTA)) + outDims.emplace_back(dimName, static_cast(dimSize)); + + LinearLayout::BasesT ctaBases; + ctaBases[StringAttr::get(ctx, "offset")] = + std::vector>(offsetBases.begin(), offsetBases.end()); + LinearLayout ctaLayout(ctaBases, outDims, /*requireSurjective=*/false); + return combineTMECtaLayoutWithCGA(ctaLayout, cgaLayout, physicalShape); +} + +inline std::optional +getCanonicalPH1TMESwizzleConfigForGranularityBytes(int64_t sgBytes) { + switch (sgBytes) { + case 16: + return ResolvedTMESwizzleConfig{TMESwizzleGranularity::SG_16B, + TMESwizzleStride::SS_256B, + TMESwizzleLine::SL_256B}; + case 32: + return ResolvedTMESwizzleConfig{TMESwizzleGranularity::SG_32B, + TMESwizzleStride::SS_256B, + TMESwizzleLine::SL_256B}; + case 64: + return ResolvedTMESwizzleConfig{TMESwizzleGranularity::SG_64B, + TMESwizzleStride::SS_256B, + TMESwizzleLine::SL_256B}; + default: + return std::nullopt; + } +} + +inline FailureOr +resolveCanonicalPH1TMESharedCarrierConfig(ttg::MemDescType localType) { + auto swizzled = + dyn_cast(localType.getEncoding()); + auto maybeElemBytes = inferElemBytesFromMemDescType(localType); + if (!swizzled || !maybeElemBytes || *maybeElemBytes <= 0) + return failure(); + + auto order = swizzled.getOrder(); + if (localType.getShape().size() != 2 || order.size() < 2) + return failure(); + + if (swizzled.getVec() == 1 && swizzled.getPerPhase() == 1 && + swizzled.getMaxPhase() == 1) { + return ResolvedTMESwizzleConfig{TMESwizzleGranularity::SG_NONE, + TMESwizzleStride::SS_256B, + TMESwizzleLine::SL_256B}; + } + + constexpr int64_t kLineBytes = 256; + int64_t elemBytes = *maybeElemBytes; + int64_t leadingWidthBytes = localType.getShape()[order.front()] * elemBytes; + if (leadingWidthBytes <= 0) + return failure(); + + int64_t sgBytes = 0; + if (swizzled.getPerPhase() > 1) { + if ((kLineBytes % swizzled.getPerPhase()) != 0 || + leadingWidthBytes != (kLineBytes / swizzled.getPerPhase())) + return failure(); + sgBytes = kLineBytes / swizzled.getMaxPhase(); + int64_t expectedVec = sgBytes / elemBytes; + if (expectedVec != swizzled.getVec()) + return failure(); + } else { + if (leadingWidthBytes < kLineBytes || (leadingWidthBytes % kLineBytes) != 0) + return failure(); + int64_t factor = leadingWidthBytes / kLineBytes; + if (factor <= 0 || !llvm::isPowerOf2_64(static_cast(factor))) + return failure(); + sgBytes = kLineBytes / (factor * swizzled.getMaxPhase()); + int64_t expectedVec = factor * (sgBytes / elemBytes); + if (sgBytes <= 0 || expectedVec != swizzled.getVec()) + return failure(); + } + + auto config = getCanonicalPH1TMESwizzleConfigForGranularityBytes(sgBytes); + if (!config) + return failure(); + return *config; +} + +inline LinearLayout +getMUSASharedLinearLayoutOrGeneric(ttg::MemDescType localType) { + auto maybeElemBytes = inferElemBytesFromMemDescType(localType); + auto localEncoding = + dyn_cast(localType.getEncoding()); + if (!maybeElemBytes || *maybeElemBytes <= 0 || !localEncoding) + return ttg::toLinearLayout(localType); + + auto order = ttg::getOrder(localType); + if (order.empty()) + return ttg::toLinearLayout(localType); + auto cgaLayout = ttg::getCGALayout(localEncoding); + if (!cgaLayout) + cgaLayout = getDefaultTMECGALayout(localType.getContext(), + localType.getShape().size(), 1, order); + + auto carrierConfig = resolveCanonicalPH1TMESharedCarrierConfig(localType); + if (failed(carrierConfig)) + return ttg::toLinearLayout(localType); + + return buildPH1TMESharedLinearLayout( + localType.getShape(), localType.getAllocShape(), + static_cast(*maybeElemBytes), order, cgaLayout, + carrierConfig->swizzleGranularity, carrierConfig->swizzleStride, + carrierConfig->swizzleLine); +} + +inline FailureOr +resolveTMESwizzleConfigFromEncoding(ttg::MemDescType localType) { + auto localEncoding = + dyn_cast(localType.getEncoding()); + if (!localEncoding) + return failure(); + auto maybeElemBytes = inferElemBytesFromMemDescType(localType); + if (!maybeElemBytes || *maybeElemBytes <= 0) + return failure(); + auto order = ttg::getOrder(localType); + if (order.empty()) + return failure(); + auto cgaLayout = ttg::getCGALayout(localEncoding); + if (!cgaLayout) + cgaLayout = getDefaultTMECGALayout(localType.getContext(), + localType.getShape().size(), 1, order); + + if (auto carrierConfig = resolveCanonicalPH1TMESharedCarrierConfig(localType); + succeeded(carrierConfig)) + return *carrierConfig; + + auto targetLayout = getMUSASharedLinearLayoutOrGeneric(localType); + auto sharedSpace = ttg::SharedMemorySpaceAttr::get(localType.getContext()); + auto canonicalNoSwizzle = ttg::MemDescType::get( + localType.getShape(), localType.getElementType(), + ttg::SwizzledSharedEncodingAttr::get(localType.getContext(), 1, 1, 1, + order, cgaLayout), + sharedSpace, localType.getMutableMemory(), localType.getAllocShape()); + if (getMUSASharedLinearLayoutOrGeneric(canonicalNoSwizzle) == targetLayout) { + return ResolvedTMESwizzleConfig{TMESwizzleGranularity::SG_NONE, + TMESwizzleStride::SS_256B, + TMESwizzleLine::SL_256B}; + } + + SmallVector matches; + constexpr TMESwizzleGranularity granularityOptions[] = { + TMESwizzleGranularity::SG_16B, TMESwizzleGranularity::SG_32B, + TMESwizzleGranularity::SG_64B, TMESwizzleGranularity::SG_128B}; + constexpr TMESwizzleStride strideOptions[] = { + TMESwizzleStride::SS_32B, TMESwizzleStride::SS_64B, + TMESwizzleStride::SS_128B, TMESwizzleStride::SS_256B}; + constexpr TMESwizzleLine lineOptions[] = {TMESwizzleLine::SL_128B, + TMESwizzleLine::SL_256B}; + + for (TMESwizzleGranularity sg : granularityOptions) { + for (TMESwizzleStride ss : strideOptions) { + for (TMESwizzleLine sl : lineOptions) { + if (!isValidTMESwizzleConfig(sg, ss, sl)) + continue; + auto candidateLayout = buildPH1TMESharedLinearLayout( + localType.getShape(), localType.getAllocShape(), + static_cast(*maybeElemBytes), order, cgaLayout, sg, ss, + sl); + if (candidateLayout == targetLayout) { + matches.push_back(ResolvedTMESwizzleConfig{sg, ss, sl}); + if (matches.size() > 1) + return failure(); + } + } + } + } + if (matches.size() != 1) + return failure(); + return matches.front(); +} + +inline ttg::SwizzledSharedEncodingAttr +normalizeTMECompatibleSharedEncodingOrDefault( + Operation *op, RankedTensorType tensorTy, Attribute encoding, + ttg::CGAEncodingAttr preferredCGA, ArrayRef usageShape, + ArrayRef allocShape, unsigned numCTAs) { + auto shape = usageShape.empty() ? tensorTy.getShape() : usageShape; + auto tryCandidate = tryMapTMECompatibleSharedEncodingToCanonicalSwizzled( + op, tensorTy, encoding, shape, numCTAs); + if (tryCandidate) { + auto sharedSpace = ttg::SharedMemorySpaceAttr::get(tensorTy.getContext()); + auto candidateMemDesc = ttg::MemDescType::get( + shape, tensorTy.getElementType(), *tryCandidate, sharedSpace, + /*mutableMemory=*/true, allocShape.empty() ? shape : allocShape); + if (succeeded(resolveTMESwizzleConfigFromEncoding(candidateMemDesc))) + return *tryCandidate; + preferredCGA = preferredCGA ? preferredCGA : tryCandidate->getCGALayout(); + } + return getDefaultTMECompatibleSharedEncoding(tensorTy, preferredCGA, shape, + numCTAs); +} + +inline FailureOr +resolveFinalTMECopyConfig(ttg::MemDescType localType, + ArrayRef descBlockShape, TMECopyKind kind) { + auto blockShape = + materializeTMEBlockShapeAttr(localType.getContext(), descBlockShape); + if (failed(blockShape)) + return failure(); + + FinalTMECopyConfig config; + config.blockShape = *blockShape; + if (kind == TMECopyKind::GlobalToLocal) + config.prefetchSize = TMEPrefetchSize::SZ_NONE; + + auto swizzle = resolveTMESwizzleConfigFromEncoding(localType); + if (failed(swizzle)) + return failure(); + config.swizzleGranularity = swizzle->swizzleGranularity; + config.swizzleStride = swizzle->swizzleStride; + config.swizzleLine = swizzle->swizzleLine; + + if (!isValidTMESwizzleConfig(config.swizzleGranularity, config.swizzleStride, + config.swizzleLine)) + return failure(); + return config; +} + +template +AsyncTMECopyGlobalToLocalOp +createAsyncTMECopyGlobalToLocal(BuilderT &builder, Location loc, Value desc, + ValueRange coord, Value barId, Value result, + Value pred, const FinalTMECopyConfig &config) { + OperationState state(loc, AsyncTMECopyGlobalToLocalOp::getOperationName()); + state.addOperands(desc); + state.addOperands(coord); + state.addOperands({barId, result, pred}); + state.addAttribute("blockShape", config.blockShape); + state.addAttribute("swizzleGranularity", + TMESwizzleGranularityAttr::get(builder.getContext(), + config.swizzleGranularity)); + state.addAttribute( + "swizzleStride", + TMESwizzleStrideAttr::get(builder.getContext(), config.swizzleStride)); + state.addAttribute( + "swizzleLine", + TMESwizzleLineAttr::get(builder.getContext(), config.swizzleLine)); + state.addAttribute( + "prefetchSize", + TMEPrefetchSizeAttr::get(builder.getContext(), *config.prefetchSize)); + state.addAttribute( + "cachePolicy", + TMEL2CachePolicyAttr::get(builder.getContext(), config.cachePolicy)); + state.addAttribute( + "innerPersistence", + TMEPersistenceAttr::get(builder.getContext(), config.innerPersistence)); + state.addAttribute( + "outerPersistence", + TMEPersistenceAttr::get(builder.getContext(), config.outerPersistence)); + return cast(builder.create(state)); +} + +template +AsyncTMECopyLocalToGlobalOp +createAsyncTMECopyLocalToGlobal(BuilderT &builder, Location loc, Value desc, + ValueRange coord, Value src, Value pred, + const FinalTMECopyConfig &config) { + OperationState state(loc, AsyncTMECopyLocalToGlobalOp::getOperationName()); + state.addOperands(desc); + state.addOperands(coord); + state.addOperands({src, pred}); + state.addAttribute("blockShape", config.blockShape); + state.addAttribute("swizzleGranularity", + TMESwizzleGranularityAttr::get(builder.getContext(), + config.swizzleGranularity)); + state.addAttribute( + "swizzleStride", + TMESwizzleStrideAttr::get(builder.getContext(), config.swizzleStride)); + state.addAttribute( + "swizzleLine", + TMESwizzleLineAttr::get(builder.getContext(), config.swizzleLine)); + state.addAttribute( + "cachePolicy", + TMEL2CachePolicyAttr::get(builder.getContext(), config.cachePolicy)); + state.addAttribute( + "innerPersistence", + TMEPersistenceAttr::get(builder.getContext(), config.innerPersistence)); + state.addAttribute( + "outerPersistence", + TMEPersistenceAttr::get(builder.getContext(), config.outerPersistence)); + return cast(builder.create(state)); +} + +} // namespace mlir::triton::musa + +#endif // TRITONMUSA_COMMON_TME_UTILS_H diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Allocation.h b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Allocation.h new file mode 100644 index 0000000000..0b5f0925aa --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Allocation.h @@ -0,0 +1,25 @@ +#ifndef TRITON_CONVERSION_TRITONMUSAGPU_TO_LLVM_ALLOCATION_H +#define TRITON_CONVERSION_TRITONMUSAGPU_TO_LLVM_ALLOCATION_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" + +#include + +namespace mlir { +namespace triton { +class TargetInfoBase; + +namespace musa_gpu { +bool needsMusaRepDisjointGenericScratch(RankedTensorType srcTy, + RankedTensorType dstTy, + const TargetInfoBase &targetInfo); + +std::function +getMusaAllocationAnalysisScratchSizeFn(const TargetInfoBase &targetInfo); + +} // namespace musa_gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONMUSAGPU_TO_LLVM_ALLOCATION_H diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/CMakeLists.txt b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..07caed7aaf --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonMUSAGPUToLLVM) +add_public_tablegen_target(TritonMUSAGPUConversionPassIncGen) diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Passes.h b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Passes.h new file mode 100644 index 0000000000..81bb6cd483 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Passes.h @@ -0,0 +1,37 @@ +#ifndef TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_PASSES_H +#define TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_PASSES_H + +#include "Dialect/MUSA/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "musa/include/TritonMUSAGPUToLLVM/Passes.h.inc" + +std::unique_ptr> createConvertTritonMUSAGPUToLLVMPass(); +std::unique_ptr> +createConvertTritonMUSAGPUToLLVMPass(int32_t computeCapability); +std::unique_ptr> +createAllocateMUSASharedMemoryPass(int32_t computeCapability); + +#define GEN_PASS_REGISTRATION +#include "musa/include/TritonMUSAGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Passes.td b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Passes.td new file mode 100644 index 0000000000..369faea136 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Passes.td @@ -0,0 +1,37 @@ +#ifndef TRITONMUSAGPU_CONVERSION_PASSES +#define TRITONMUSAGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonMUSAGPUToLLVM : Pass<"convert-triton-musagpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonGPU to LLVM for MUSA"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"31", + "device compute capability">, + ]; +} + +def AllocateMUSASharedMemory : Pass<"allocate-musa-shared-memory", "mlir::ModuleOp"> { + let summary = "Allocate shared memory for MUSA"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"31", + "device compute capability">, + ]; +} + +#endif diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/TargetInfo.h b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/TargetInfo.h new file mode 100644 index 0000000000..5025e539c0 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/TargetInfo.h @@ -0,0 +1,73 @@ +#ifndef TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_TARGETINFO_H +#define TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_TARGETINFO_H + +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include + +namespace mlir::triton::MUSA { + +class TargetInfo : public mlir::triton::TargetInfoBase { +public: + explicit TargetInfo(int computeCapability) + : computeCapability(computeCapability) {} + + bool supportMaximumMinimum() const override; + + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + + Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const override; + + void barrier(Location loc, RewriterBase &rewriter, + triton::gpu::AddrSpace targets) const override; + void warpSync(Location loc, RewriterBase &rewriter) const override; + + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, Value pred, + Operation *localLoadOp = nullptr) const override; + + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const override; + + Value permute(RewriterBase &rewriter, Location loc, Value a, Value b, + Value selector) const override; + + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + ProgramIDDim axis) const override; + + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; + + std::string getMulhiFuncName(Type resultElementTy) const override; + + void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args, + ArrayRef isSigned = {}) const override; + + void printf(RewriterBase &rewriter, StringRef msg, ValueRange args, + ArrayRef isSigned = {}) const override; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; + + int getSharedAddressSpace() const override; + int getAddressSpace(Attribute addressSpace) const override; + bool supportVectorizedAtomics() const override; + +private: + int computeCapability; +}; + +} // namespace mlir::triton::MUSA + +#endif diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Utility.h b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Utility.h new file mode 100644 index 0000000000..7c2bc360e7 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUToLLVM/Utility.h @@ -0,0 +1,83 @@ +#ifndef TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_UTILITY_H +#define TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_UTILITY_H + +#include "Dialect/MTGPU/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { +namespace LLVM { +namespace MUSA { + +inline constexpr char Predicated_Load[] = "__predicated_load"; +inline constexpr char Predicated_InplaceLoad[] = "__predicated_inplace_load"; +inline constexpr char Predicated_Store[] = "__predicated_store"; + +struct SqmmaAccumulatorCarrierInfo { + RankedTensorType tensorType; + unsigned fragmentCount; + unsigned fragmentElems; + Type fragmentType; + Type carrierType; +}; + +FailureOr +getSqmmaAccumulatorCarrierInfo(Type type); + +SmallVector unpackSqmmaAccumulatorCarrier(Location loc, Value carrier, + Type type, + RewriterBase &rewriter); +Value packSqmmaAccumulatorCarrier(Location loc, ValueRange fragments, Type type, + RewriterBase &rewriter); +Value carrierFragmentToMathVec(Location loc, Value fragment, Type type, + RewriterBase &rewriter); +Value mathVecToCarrierFragment(Location loc, Value mathVec, Type type, + RewriterBase &rewriter); +Value packSqmmaAccumulatorCarrierFromTensor(Location loc, Value tensorValue, + RankedTensorType tensorType, + const LLVMTypeConverter *converter, + RewriterBase &rewriter); +Value unpackSqmmaAccumulatorCarrierToTensor(Location loc, Value carrier, + RankedTensorType tensorType, + const LLVMTypeConverter *converter, + RewriterBase &rewriter); + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + unsigned width); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + triton::ProgramIDDim axis); + +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal); + +Value llInplaceLoad(RewriterBase &rewriter, Location loc, Value ptr, + Type elemTy, Value pred, Value falseVal); + +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred); + +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask); + +/// Create a predicate with just single active thread. +Value createElectPredicate(Location loc, PatternRewriter &rewriter); + +LLVM::LLVMFuncOp getLibdeviceFuncCall(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type retType, + ValueRange ins = {}); + +} // namespace MUSA +} // namespace LLVM +} // namespace mlir + +#endif diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt new file mode 100644 index 0000000000..28184e982d --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonMUSAGPU) +add_public_tablegen_target(TritonMUSAGPUTransformsIncGen) diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.h b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.h new file mode 100644 index 0000000000..f71f0fd88d --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.h @@ -0,0 +1,20 @@ +#ifndef TRITON_THIRD_PARTY_MUSA_INCLUDE_TRITONMUSAGPUTRANSFORMS_PASSES_H_ +#define TRITON_THIRD_PARTY_MUSA_INCLUDE_TRITONMUSAGPUTRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +} // namespace mlir + +namespace mlir { +// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "TritonMUSAGPUTransforms/Passes.h.inc" +} // namespace mlir + +#endif // TRITON_THIRD_PARTY_MUSA_INCLUDE_TRITONMUSAGPUTRANSFORMS_PASSES_H_ diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td new file mode 100644 index 0000000000..076ed5dff0 --- /dev/null +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td @@ -0,0 +1,193 @@ +#ifndef TRITONMUSAGPU_PASSES +#define TRITONMUSAGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonMUSAGPUPipeline + : Pass<"tritonmusa-pipeline", "mlir::ModuleOp"> { + let summary = "MUSA backend-local software pipeline"; + let description = [{ + Run the MUSA backend-local software pipeliner: + lower and expand scheduled loops, classify SQMMA accumulator pipelines + under MUSA-specific async/wait semantics, schedule generic async waits, + and pipeline TMA stores without routing through the public NVIDIA-only + post-expand hooks. + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; + let options = [ + Option<"numStages", "num-stages", "int32_t", /*default*/"3", + "number of pipeline stages">, + Option<"dumpIntermediateSteps", "dump-intermediate-steps", + "bool", /*default*/"false", + "dump intermediate pipeline steps"> + ]; +} + +def TritonMUSAGPUAccelerateMatmul + : Pass<"tritonmusa-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + let description = [{ + Optimize the input/output layout of `dot` instructions to make them + compatible with MUSA WMMA/SQMMA instructions. + }]; + let dependentDialects = []; +} + +def TritonMUSAGPUOptimizeDotOperands + : Pass<"tritonmusa-optimize-dot-operands", "mlir::ModuleOp"> { + let summary = "Normalize MUSA dot operands and descriptor-fed landing views"; + let description = [{ + Apply the subset of dot-operand shared-memory canonicalizations that are + valid on MUSA, and normalize descriptor-fed tensor transposes/reshapes into + memdesc view ops so descriptor landing stays canonical before descriptor + encoding and TME/pipeline lowering. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUTMELowering + : Pass<"tritonmusa-tme-lowering", "mlir::ModuleOp"> { + let summary = "Lower descriptor load/store to MUSA TME async copy ops"; + let description = [{ + Rewrite `tt.descriptor_load/store` to MUSA async TME operations plus + barrier/store synchronization ops. This pass only handles descriptor ops + and keeps descriptor semantics explicit in TTGIR. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUIssueBarrierInsertion + : Pass<"tritonmusa-issue-barrier-insertion", "mlir::ModuleOp"> { + let summary = "Insert MUSA TTGIR issue barriers before selected ops"; + let description = [{ + Materialize shared-memory producer-consumer barriers in TTGIR before + MUSA TME store issue and non-dual-TME SQMMA issue sites. The pass keeps + barrier placement explicit so LLVM lowering can remain instruction- + selection focused. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUFinalizeBarriers + : Pass<"tritonmusa-finalize-barriers", "mlir::ModuleOp"> { + let summary = "Finalize MUSA async barrier resources"; + let description = [{ + Canonicalize barrier resource reservations into a single entry block + `ttmg.bar_record` per function, while preserving the highest reserved + barrier id and implicit async-commit-group floor semantics. + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUOptimizeSqmmaAccumulatorLayout + : Pass<"tritonmusa-optimize-sqmma-accumulator-layout", "mlir::ModuleOp"> { + let summary = "Sink post-TME SQMMA accumulator layout conversions"; + let description = [{ + Canonicalize loop-carried SQMMA accumulator convert_layout chains that are + introduced by descriptor/TME lowering so accumulators stay in MMA layout + through the loop when legal. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUCanonicalizeSqmmaResultConversions + : Pass<"tritonmusa-canonicalize-sqmma-result-conversions", "mlir::ModuleOp"> { + let summary = "Canonicalize SQMMA result convert/trunc chains"; + let description = [{ + Push floating-point truncation through convert_layout on SQMMA results so + result-side layout cleanup stays outside of lowering passes. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUConvertSqmmaToMTGPU + : Pass<"tritonmusa-convert-sqmma-to-mtgpu", "mlir::ModuleOp"> { + let summary = "Convert loop-carried SQMMA accumulator chains to MTGPU"; + let description = [{ + Rewrite guarded SQMMA hot loops from tensor-carried ttmg.squad_dot chains + to backend-private mtgpu.sqmma chains so compact accumulator carriers stay + alive across scf.for boundaries without leaking into public interfaces. + }]; + let dependentDialects = [ + "mlir::scf::SCFDialect", + "mlir::triton::musa::MUSADialect", + "mlir::triton::mtgpu::MTGPUDialect" + ]; +} + +def TritonMUSAGPUOptimizeAccumulatorInit + : Pass<"tritonmusa-optimize-accumulator-init", "mlir::ModuleOp"> { + let summary = "Optimize MUSA dot accumulator zero-init to first-use useC"; + let description = [{ + Convert explicit zero-initialized MUSA dot accumulators into first-use + useC flags for loop-carried SQMMA/WMMA patterns when this is legal under + current MUSA lowering semantics. + }]; + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUOptimizeDescriptorEncoding + : Pass<"tritonmusa-optimize-descriptor-encoding", "mlir::ModuleOp"> { + let summary = "Optimize descriptor encodings for MUSA TTGIR pipelines"; + let description = [{ + Canonicalize descriptor encodings before MUSA TME lowering. + This pass is intentionally backend-local so MUSA pipelines do not depend + on NVIDIA transform pass entry points. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::musa::MUSADialect" + ]; +} + +def TritonMUSAGPUMarkInplaceLoads + : Pass<"tritonmusa-mark-inplace-loads", "mlir::ModuleOp"> { + let summary = "Mark TTGIR loads that feed same-address stores"; + let description = [{ + Mark Triton load operations whose pointer expression is structurally + equivalent to a store pointer in the same function. This pass is used by + the PH1 backend-local inplace load optimization and intentionally stays in + the transform pipeline so LLVM lowering remains selection focused. + }]; + let dependentDialects = [ + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + +#endif // TRITONMUSAGPU_PASSES diff --git a/third_party/mthreads/musa/lib/CMakeLists.txt b/third_party/mthreads/musa/lib/CMakeLists.txt new file mode 100644 index 0000000000..fa7e5fb35e --- /dev/null +++ b/third_party/mthreads/musa/lib/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(Dialect) +add_subdirectory(MTGPUToLLVM) +add_subdirectory(TritonMUSAGPUToLLVM) +add_subdirectory(TritonMUSAGPUTransforms) diff --git a/third_party/mthreads/musa/lib/Dialect/CMakeLists.txt b/third_party/mthreads/musa/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..4e7e1ea71d --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(MTGPU) +add_subdirectory(MUSA) diff --git a/third_party/mthreads/musa/lib/Dialect/MTGPU/CMakeLists.txt b/third_party/mthreads/musa/lib/Dialect/MTGPU/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/MTGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/musa/lib/Dialect/MTGPU/IR/CMakeLists.txt b/third_party/mthreads/musa/lib/Dialect/MTGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..f2f4c20b5e --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/MTGPU/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(MTGPUIR + Dialect.cpp + + DEPENDS + MTGPUTableGen + MTGPUTypesIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect +) diff --git a/third_party/mthreads/musa/lib/Dialect/MTGPU/IR/Dialect.cpp b/third_party/mthreads/musa/lib/Dialect/MTGPU/IR/Dialect.cpp new file mode 100644 index 0000000000..334529649e --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/MTGPU/IR/Dialect.cpp @@ -0,0 +1,287 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + +// clang-format off +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MTGPU/IR/Dialect.cpp.inc" +// clang-format on + +#include + +using namespace mlir; +using namespace mlir::triton::mtgpu; +namespace ttg = mlir::triton::gpu; + +void MTGPUDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "Dialect/MTGPU/IR/MTGPUTypes.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/MTGPU/IR/Ops.cpp.inc" + >(); +} + +#define GET_TYPEDEF_CLASSES +#include "Dialect/MTGPU/IR/MTGPUTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "Dialect/MTGPU/IR/Ops.cpp.inc" +#include "Dialect/MTGPU/IR/OpsEnums.cpp.inc" + +namespace mlir::triton::mtgpu { + +LogicalResult +SqmmaAccumulatorType::verify(function_ref emitError, + Type tensorType) { + auto rankedTy = dyn_cast(tensorType); + if (!rankedTy) + return emitError() << "expected ranked tensor accumulator type"; + if (!isa_and_nonnull( + rankedTy.getEncoding())) { + return emitError() << "expected tensor encoded with #ttg.musa_sqmma"; + } + Type elemTy = rankedTy.getElementType(); + if (!elemTy.isF32() && !elemTy.isInteger(32)) { + return emitError() + << "expected f32 or i32 SQMMA accumulator element type, got " + << elemTy; + } + return success(); +} + +static LogicalResult verifyDotShapeContract(Operation *op, + ArrayRef aShape, + ArrayRef bShape, + ArrayRef cShape, + ArrayRef dShape) { + if (aShape.size() != 2 && aShape.size() != 3) + return op->emitError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size() || + cShape.size() != dShape.size()) + return op->emitError( + "expected all operands and result to have the same rank"); + + if (aShape[aShape.size() - 1] != bShape[bShape.size() - 2]) { + return op->emitError("expected the last dimension of the first operand " + "to equal the second-to-last dimension of the " + "second operand"); + } + + if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0] || + cShape[0] != dShape[0])) { + return op->emitError("expected batch dimensions to match"); + } + + if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] || + cShape[cShape.size() - 1] != bShape[bShape.size() - 1]) { + return op->emitError( + "expected accumulator shape to match dot output shape"); + } + if (cShape != dShape) + return op->emitError("expected result shape to match accumulator shape"); + return success(); +} + +static bool isFP8Type(Type type) { + return llvm::isa(type); +} + +LogicalResult SqmmaOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + (void)context; + (void)attributes; + (void)properties; + (void)regions; + + auto accTy = dyn_cast(operands[2].getType()); + if (!accTy) + return failure(); + inferredReturnTypes.push_back(accTy); + + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getAccumulatorType().getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult SqmmaOp::verify() { + auto aTy = cast(getA().getType()); + auto bTy = cast(getB().getType()); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + + auto accTy = dyn_cast(getC().getType()); + auto retTy = dyn_cast(getD().getType()); + if (!accTy || !retTy) + return emitError( + "SQMMA accumulator/result must use !mtgpu.sqmma_accumulator"); + auto accTensorTy = accTy.getAccumulatorType(); + auto retTensorTy = retTy.getAccumulatorType(); + auto retEnc = accTensorTy.getEncoding(); + Dialect &dialect = aEncoding.getDialect(); + auto interface = dyn_cast(&dialect); + if (!interface) + return emitError( + "SQMMA operand encoding dialect does not implement layout inference"); + if (interface->inferDotOpEncoding(aEncoding, 0, retEnc, getLoc()).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEncoding, 1, retEnc, getLoc()).failed()) + return failure(); + if (retTensorTy.getEncoding() != retEnc) + return emitError( + "SQMMA result carrier must use the same encoding as the accumulator"); + if (failed(verifyDotShapeContract(getOperation(), aTy.getShape(), + bTy.getShape(), accTensorTy.getShape(), + retTensorTy.getShape()))) + return failure(); + + auto accMode = getAccMode(); + int64_t maxNumImpreciseAcc = std::max(0, getMaxNumImpreciseAcc()); + bool fp8ToF32 = + isFP8Type(aTy.getElementType()) && retTensorTy.getElementType().isF32(); + if (usesHardwareAccumulator()) { + if (maxNumImpreciseAcc != 0) + return emitError( + "hardware SQMMA accumulation requires maxNumImpreciseAcc == 0"); + } else if (accMode == SQMMAAccumulationMode::partial) { + if (!fp8ToF32) + return emitError("partial SQMMA accumulation currently requires fp8 " + "inputs and f32 accumulators"); + if (maxNumImpreciseAcc <= 0) + return emitError( + "partial SQMMA accumulation requires maxNumImpreciseAcc > 0"); + if (maxNumImpreciseAcc > aTy.getShape().back()) + return emitError( + "partial SQMMA accumulation requires maxNumImpreciseAcc <= K"); + } else if (accMode == SQMMAAccumulationMode::software) { + if (!fp8ToF32) + return emitError("software SQMMA accumulation currently requires fp8 " + "inputs and f32 accumulators"); + if (maxNumImpreciseAcc != 0) + return emitError( + "software SQMMA accumulation requires maxNumImpreciseAcc == 0"); + } + return success(); +} + +void SqmmaOp::getEffects( + SmallVectorImpl> + &effects) { + auto &a = getAMutable(); + auto &b = getBMutable(); + if (isa(a.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &a, + ttg::SharedMemory::get()); + if (isa(b.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &b, + ttg::SharedMemory::get()); +} + +bool SqmmaOp::needsPartialAccumulator() { + return getAccMode() == SQMMAAccumulationMode::partial; +} + +bool SqmmaOp::usesSoftwareAccumulator() { + return getAccMode() == SQMMAAccumulationMode::software; +} + +bool SqmmaOp::usesHardwareAccumulator() { + return getAccMode() == SQMMAAccumulationMode::hardware; +} + +LogicalResult SqmmaWaitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + (void)context; + (void)location; + (void)attributes; + (void)properties; + (void)regions; + for (Value operand : operands) + inferredReturnTypes.push_back(operand.getType()); + return success(); +} + +LogicalResult SqmmaWaitOp::verify() { + if (getInputs().empty()) + return emitOpError("expected to be waiting on at least one dependency"); + return success(); +} + +LogicalResult SqmmaWaitOp::canonicalize(SqmmaWaitOp op, + PatternRewriter &rewriter) { + SmallVector liveInputs; + SmallVector liveResultIdxs; + liveInputs.reserve(op.getNumResults()); + liveResultIdxs.reserve(op.getNumResults()); + + for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { + bool keepForMemDesc = isa(op.getInputs()[idx].getType()); + if (!keepForMemDesc && op.getResult(idx).use_empty()) + continue; + liveInputs.push_back(op.getInputs()[idx]); + liveResultIdxs.push_back(idx); + } + + if (liveResultIdxs.size() == op.getNumResults()) + return failure(); + + if (liveInputs.empty()) { + rewriter.eraseOp(op); + return success(); + } + + rewriter.setInsertionPoint(op); + auto newWait = SqmmaWaitOp::create(rewriter, op.getLoc(), liveInputs); + newWait->setAttrs(op->getAttrs()); + for (unsigned newIdx = 0; newIdx < liveResultIdxs.size(); ++newIdx) + rewriter.replaceAllUsesWith(op.getResult(liveResultIdxs[newIdx]), + newWait.getResult(newIdx)); + rewriter.eraseOp(op); + return success(); +} + +LogicalResult PackSqmmaAccumulatorOp::verify() { + auto carrierTy = dyn_cast(getCarrier().getType()); + if (!carrierTy) + return emitError("result must be !mtgpu.sqmma_accumulator"); + if (carrierTy.getAccumulatorType() != getInput().getType()) + return emitError("carrier tensor type must match input type"); + return success(); +} + +LogicalResult UnpackSqmmaAccumulatorOp::verify() { + auto carrierTy = cast(getCarrier().getType()); + if (carrierTy.getAccumulatorType() != getOutput().getType()) + return emitError("result tensor type must match carrier tensor type"); + return success(); +} + +} // namespace mlir::triton::mtgpu diff --git a/third_party/mthreads/musa/lib/Dialect/MUSA/CMakeLists.txt b/third_party/mthreads/musa/lib/Dialect/MUSA/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/MUSA/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/musa/lib/Dialect/MUSA/IR/CMakeLists.txt b/third_party/mthreads/musa/lib/Dialect/MUSA/IR/CMakeLists.txt new file mode 100644 index 0000000000..97799dcaa5 --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/MUSA/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(MUSAIR + Dialect.cpp + + DEPENDS + MUSATableGen + MUSAAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect +) diff --git a/third_party/mthreads/musa/lib/Dialect/MUSA/IR/Dialect.cpp b/third_party/mthreads/musa/lib/Dialect/MUSA/IR/Dialect.cpp new file mode 100644 index 0000000000..e506246655 --- /dev/null +++ b/third_party/mthreads/musa/lib/Dialect/MUSA/IR/Dialect.cpp @@ -0,0 +1,577 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "TritonMUSACommon/BarrierUtils.h" +#include "TritonMUSACommon/MMAContractUtils.h" +#include "TritonMUSACommon/MMAEncodingUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// clang-format off +#include "Dialect/MUSA/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.cpp.inc" +// clang-format on + +#include + +using namespace mlir; +using namespace mlir::triton::musa; +namespace ttg = mlir::triton::gpu; + +void MUSADialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/MUSA/IR/MUSAAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/MUSA/IR/Ops.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "Dialect/MUSA/IR/Ops.cpp.inc" +#include "Dialect/MUSA/IR/OpsEnums.cpp.inc" + +namespace mlir::triton::musa { + +static LogicalResult verifyAsyncBarrierId(Operation *op, Value barId, + StringRef operandName) { + APInt constant; + if (!matchPattern(barId, m_ConstantInt(&constant))) + return success(); + + int64_t value = constant.getSExtValue(); + if (value == 0) { + return op->emitOpError() + << operandName << " 0 is reserved for CTA barrier on MUSA"; + } + if (value < 0 || value > kMaxBarrierId) { + return op->emitOpError() << operandName << " must be in [1, " + << kMaxBarrierId << "] when constant"; + } + return success(); +} + +static LogicalResult verifyNonNegativeI32Constant(Operation *op, Value value, + StringRef operandName) { + APInt constant; + if (!matchPattern(value, m_ConstantInt(&constant))) + return success(); + if (constant.isNegative()) { + return op->emitOpError() + << operandName << " must be non-negative when constant"; + } + return success(); +} + +static LogicalResult verifyDotShapeContract(Operation *op, + ArrayRef aShape, + ArrayRef bShape, + ArrayRef cShape, + ArrayRef dShape) { + if (aShape.size() != 2 && aShape.size() != 3) + return op->emitError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size() || + cShape.size() != dShape.size()) + return op->emitError( + "expected all operands and result to have the same rank"); + + if (aShape[aShape.size() - 1] != bShape[bShape.size() - 2]) { + return op->emitError("expected the last dimension of the first operand " + "to equal the second-to-last dimension of the " + "second operand"); + } + + if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0] || + cShape[0] != dShape[0])) { + return op->emitError("expected batch dimensions to match"); + } + + if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] || + cShape[cShape.size() - 1] != bShape[bShape.size() - 1]) { + return op->emitError( + "expected accumulator shape to match dot output shape"); + } + if (cShape != dShape) + return op->emitError("expected result shape to match accumulator shape"); + return success(); +} + +static bool isFP8Type(Type type) { + return llvm::isa(type); +} + +static std::optional getWmmaEltTypeForVerify(Type type) { + if (type.isInteger(8)) + return SQMMAEltType::s8; + if (type.isF16()) + return SQMMAEltType::f16; + if (type.isBF16()) + return SQMMAEltType::bf16; + if (type.isF32()) + return SQMMAEltType::tf32; + if (llvm::isa(type)) + return SQMMAEltType::e4m3; + if (llvm::isa(type)) + return SQMMAEltType::e5m2; + return std::nullopt; +} + +static std::optional getSqmmaAccumEltTypeForVerify(Type type) { + if (type.isF32()) + return SQMMAEltType::f32; + if (type.isInteger(32)) + return SQMMAEltType::s32; + if (type.isF16()) + return SQMMAEltType::f16; + return std::nullopt; +} + +static LogicalResult verifyNoAcceleratedFP16Accumulator(Operation *op, + Type accElemTy, + Type retElemTy, + StringRef opName) { + if (!accElemTy.isF16() && !retElemTy.isF16()) + return success(); + return op->emitOpError() + << opName + << " fp16 accumulators/results are not currently supported; use an " + "fp32 carrier and truncate after layout conversion instead"; +} + +LogicalResult SquadDotOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult SquadDotOp::verify() { + auto aTy = cast(getA().getType()); + auto bTy = cast(getB().getType()); + if (aTy.getElementType() != bTy.getElementType()) + return emitError("SQMMA operands A and B must use the same element type"); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + + auto accTy = cast(getC().getType()); + auto retTy = cast(getD().getType()); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + auto mmaEnc = dyn_cast(retEnc); + if (!mmaEnc) + return emitError("SQMMA result layout must be #ttg.musa_sqmma"); + if (!supportsMusaSqmmaEncoding(mmaEnc)) + return emitError( + "SQMMA result encoding uses unsupported MUSA SQMMA version"); + auto instrShape = mmaEnc.getInstrShape(); + if (instrShape.size() != 3) + return emitError("SQMMA result encoding must carry a 3D instrShape"); + if (getM() != static_cast(instrShape[0]) || + getN() != static_cast(instrShape[1]) || + getK() != static_cast(instrShape[2])) + return emitError( + "SQMMA m/n/k attrs must match the #ttg.musa_sqmma instrShape"); + if (retTy.getEncoding() != retEnc) + return emitError( + "SQMMA result shape and accumulator must use the same encoding"); + auto aEltType = getWmmaEltTypeForVerify(aTy.getElementType()); + auto bEltType = getWmmaEltTypeForVerify(bTy.getElementType()); + auto cEltType = getSqmmaAccumEltTypeForVerify(retTy.getElementType()); + if (!aEltType || !bEltType || !cEltType) + return emitError("SQMMA operands/results must use supported " + "int8/f16/bf16/tf32/fp8/f32/s32 element types"); + if (*aEltType != getEltTypeA() || *bEltType != getEltTypeB() || + *cEltType != getEltTypeC()) + return emitError( + "SQMMA eltType attrs must match operand/result element types"); + if ((aTy.getElementType().isF32() || bTy.getElementType().isF32()) && + getInputPrecision() != static_cast(triton::InputPrecision::TF32)) + return emitError("SQMMA f32 operands require TF32 input precision"); + if (!triton::musa::isSupportedSqmma(getEltTypeA(), getEltTypeB(), + getEltTypeC(), getM(), getN(), getK())) + return emitError( + "SQMMA encoding carries an unsupported PH1 shape/type combination"); + Dialect &dialect = aEncoding.getDialect(); + auto interface = dyn_cast(&dialect); + if (!interface) + return emitError( + "SQMMA operand encoding dialect does not implement layout inference"); + if (interface->inferDotOpEncoding(aEncoding, 0, retEnc, getLoc()).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEncoding, 1, retEnc, getLoc()).failed()) + return failure(); + if (failed(verifyDotShapeContract(getOperation(), aTy.getShape(), + bTy.getShape(), accTy.getShape(), + retTy.getShape()))) + return failure(); + auto accMode = getAccMode(); + int64_t maxNumImpreciseAcc = std::max(0, getMaxNumImpreciseAcc()); + bool fp8ToF32 = + isFP8Type(aTy.getElementType()) && retTy.getElementType().isF32(); + if (usesHardwareAccumulator()) { + if (failed(verifyNoAcceleratedFP16Accumulator( + getOperation(), accTy.getElementType(), retTy.getElementType(), + "SQMMA"))) + return failure(); + if (maxNumImpreciseAcc != 0) + return emitError( + "hardware SQMMA accumulation requires maxNumImpreciseAcc == 0"); + } else if (accMode == SQMMAAccumulationMode::partial) { + if (!fp8ToF32) + return emitError("partial SQMMA accumulation currently requires fp8 " + "inputs and f32 accumulators"); + if (maxNumImpreciseAcc <= 0) + return emitError( + "partial SQMMA accumulation requires maxNumImpreciseAcc > 0"); + if (maxNumImpreciseAcc > aTy.getShape().back()) + return emitError( + "partial SQMMA accumulation requires maxNumImpreciseAcc <= K"); + } else if (accMode == SQMMAAccumulationMode::software) { + if (!fp8ToF32) + return emitError("software SQMMA accumulation currently requires fp8 " + "inputs and f32 accumulators"); + if (maxNumImpreciseAcc != 0) + return emitError( + "software SQMMA accumulation requires maxNumImpreciseAcc == 0"); + } + return success(); +} + +bool SquadDotOp::verifyDims() { + auto aShape = cast(getA().getType()).getShape(); + auto bShape = cast(getB().getType()).getShape(); + return aShape[aShape.size() - 1] == bShape[bShape.size() - 2]; +} + +void SquadDotOp::getEffects( + SmallVectorImpl> + &effects) { + auto &a = getAMutable(); + auto &b = getBMutable(); + if (isa(a.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &a, + ttg::SharedMemory::get()); + if (isa(b.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &b, + ttg::SharedMemory::get()); +} + +bool SquadDotOp::needsPartialAccumulator() { + return getAccMode() == SQMMAAccumulationMode::partial; +} + +bool SquadDotOp::usesSoftwareAccumulator() { + return getAccMode() == SQMMAAccumulationMode::software; +} + +bool SquadDotOp::usesHardwareAccumulator() { + return getAccMode() == SQMMAAccumulationMode::hardware; +} + +LogicalResult SquadDotWaitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + (void)context; + (void)location; + (void)attributes; + (void)properties; + (void)regions; + for (Value operand : operands) + inferredReturnTypes.push_back(operand.getType()); + return success(); +} + +LogicalResult SquadDotWaitOp::verify() { + if (getInputs().empty()) + return emitOpError("expected to be waiting on at least one dependency"); + return success(); +} + +LogicalResult SquadDotWaitOp::canonicalize(SquadDotWaitOp op, + PatternRewriter &rewriter) { + SmallVector liveInputs; + SmallVector liveResultIdxs; + liveInputs.reserve(op.getNumResults()); + liveResultIdxs.reserve(op.getNumResults()); + + for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { + bool keepForMemDesc = isa(op.getInputs()[idx].getType()); + if (!keepForMemDesc && op.getResult(idx).use_empty()) + continue; + liveInputs.push_back(op.getInputs()[idx]); + liveResultIdxs.push_back(idx); + } + + if (liveResultIdxs.size() == op.getNumResults()) + return failure(); + + if (liveInputs.empty()) { + rewriter.eraseOp(op); + return success(); + } + + rewriter.setInsertionPoint(op); + auto newWait = SquadDotWaitOp::create(rewriter, op.getLoc(), liveInputs); + newWait->setAttrs(op->getAttrs()); + for (unsigned newIdx = 0; newIdx < liveResultIdxs.size(); ++newIdx) + rewriter.replaceAllUsesWith(op.getResult(liveResultIdxs[newIdx]), + newWait.getResult(newIdx)); + rewriter.eraseOp(op); + return success(); +} + +LogicalResult WmmaDotOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult WmmaDotOp::verify() { + auto aTy = cast(getA().getType()); + auto bTy = cast(getB().getType()); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + if ((aTy.getRank() != 2 && aTy.getRank() != 3) || + aTy.getRank() != bTy.getRank()) + return emitError( + "WMMA operands must be rank-2 or rank-3 tensors with matching rank"); + + auto dTy = cast(getD().getType()); + auto mmaEnc = dyn_cast(dTy.getEncoding()); + if (!mmaEnc) + return emitError("WMMA result layout must be #ttg.musa_wmma"); + if (!supportsMusaWmmaEncoding(mmaEnc)) + return emitError("WMMA result encoding uses unsupported MUSA WMMA version"); + auto instrShape = mmaEnc.getInstrShape(); + if (instrShape.size() != 3) + return emitError("WMMA result encoding must carry a 3D instrShape"); + if (std::max(0, getMaxNumImpreciseAcc()) != 0) + return emitError("WMMA maxNumImpreciseAcc must be 0 until partial " + "accumulation is implemented"); + if (getM() != static_cast(instrShape[0]) || + getN() != static_cast(instrShape[1]) || + getK() != static_cast(instrShape[2])) + return emitError( + "WMMA m/n/k attrs must match the #ttg.musa_wmma instrShape"); + + auto aEltType = getWmmaEltTypeForVerify(aTy.getElementType()); + auto bEltType = getWmmaEltTypeForVerify(bTy.getElementType()); + if (!aEltType || !bEltType) + return emitError("WMMA operands must use supported int8/f16/bf16/tf32/fp8 " + "element types"); + if (*aEltType != getEltTypeA() || *bEltType != getEltTypeB()) + return emitError("WMMA eltType attrs must match operand element types"); + if ((aTy.getElementType().isF32() || bTy.getElementType().isF32()) && + getInputPrecision() != static_cast(triton::InputPrecision::TF32)) + return emitError("WMMA f32 operands require TF32 input precision"); + if (!triton::musa::lookupWmmaIntrinsic(aTy.getElementType(), instrShape)) + return emitError( + "WMMA encoding carries an unsupported shape/type combination"); + + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + auto aDotEncoding = dyn_cast_or_null(aEncoding); + auto bDotEncoding = dyn_cast_or_null(bEncoding); + if (!aDotEncoding || !bDotEncoding) + return emitError("WMMA operands A and B must use DotOperandEncodingAttr"); + if (aDotEncoding.getOpIdx() != 0 || bDotEncoding.getOpIdx() != 1) + return emitError("WMMA operands A/B must use dot operand indices 0/1"); + if (aDotEncoding.getParent() != dTy.getEncoding() || + bDotEncoding.getParent() != dTy.getEncoding()) + return emitError("WMMA operand dot layouts must point to the same " + "#ttg.musa_wmma result encoding"); + + auto cTy = cast(getC().getType()); + if (failed(verifyDotShapeContract(getOperation(), aTy.getShape(), + bTy.getShape(), cTy.getShape(), + dTy.getShape()))) + return failure(); + if (failed(verifyNoAcceleratedFP16Accumulator( + getOperation(), cTy.getElementType(), dTy.getElementType(), "WMMA"))) + return failure(); + + return success(); +} + +bool WmmaDotOp::verifyDims() { + auto aShape = cast(getA().getType()).getShape(); + auto bShape = cast(getB().getType()).getShape(); + return aShape[aShape.size() - 1] == bShape[bShape.size() - 2]; +} + +bool WmmaDotOp::needsPartialAccumulator() { return false; } + +LogicalResult BarRecordOp::verify() { + APInt barId; + if (!matchPattern(getBarId(), m_ConstantInt(&barId))) + return emitOpError("bar_record expects a constant max barrier id"); + if (barId.getSExtValue() <= 0 || barId.getSExtValue() > kMaxBarrierId) { + return emitOpError("bar_record must be in [1, ") << kMaxBarrierId << "]"; + } + return success(); +} + +LogicalResult InitArrivalOp::verify() { + if (failed(verifyAsyncBarrierId(getOperation(), getBarId(), "barId"))) + return failure(); + if (failed(verifyNonNegativeI32Constant(getOperation(), getArriveCount(), + "arriveCount"))) + return failure(); + return verifyNonNegativeI32Constant(getOperation(), getPhaseId(), "phaseId"); +} + +LogicalResult BarrierAddTransOp::verify() { + if (failed(verifyAsyncBarrierId(getOperation(), getBarId(), "barId"))) + return failure(); + return verifyNonNegativeI32Constant(getOperation(), getTransBytes(), + "transBytes"); +} + +LogicalResult ArriveBarrierOp::verify() { + return verifyAsyncBarrierId(getOperation(), getBarId(), "barId"); +} + +LogicalResult ArriveBarrierNoRetOp::verify() { + return verifyAsyncBarrierId(getOperation(), getBarId(), "barId"); +} + +LogicalResult WaitBarrierOp::verify() { + if (failed(verifyAsyncBarrierId(getOperation(), getBarId(), "barId"))) + return failure(); + return verifyNonNegativeI32Constant(getOperation(), getPhaseId(), "phaseId"); +} + +void WaitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getBarIdMutable(), + ttg::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +static LogicalResult verifyTMECopyShapeContract(Operation *op, ValueRange coord, + ArrayRef blockShape) { + if (coord.empty() || coord.size() > 5 || coord.size() != blockShape.size()) + return op->emitOpError( + "expects coord/blockShape rank in [1, 5] and matching"); + for (int32_t dim : blockShape) { + if (dim <= 0) + return op->emitOpError("expects positive blockShape dimensions"); + } + return success(); +} + +template +static LogicalResult verifyTMESwizzleContract(OpTy op) { + auto sgAttr = op->template getAttrOfType( + "swizzleGranularity"); + auto ssAttr = + op->template getAttrOfType("swizzleStride"); + auto slAttr = op->template getAttrOfType("swizzleLine"); + if (!sgAttr || !ssAttr || !slAttr) + return op->emitOpError("requires typed TME swizzle attrs"); + + TMESwizzleGranularity granularity = sgAttr.getValue(); + TMESwizzleStride stride = ssAttr.getValue(); + TMESwizzleLine line = slAttr.getValue(); + if (isValidTMESwizzleConfig(granularity, stride, line)) + return success(); + + return op->emitOpError() + << "expects a valid TME swizzle config with granularity <= stride <= " + "line, but got granularity=" + << getSwizzleGranularityBytes(granularity) + << "B, stride=" << getSwizzleStrideBytes(stride) + << "B, line=" << getSwizzleLineBytes(line) << "B"; +} + +LogicalResult AsyncTMECopyGlobalToLocalOp::verify() { + if (failed(verifyAsyncBarrierId(getOperation(), getBarId(), "barId"))) + return failure(); + if (failed(verifyTMECopyShapeContract(getOperation(), getCoord(), + getBlockShape()))) + return failure(); + return verifyTMESwizzleContract(*this); +} + +LogicalResult AsyncTMECopyLocalToGlobalOp::verify() { + if (failed(verifyTMECopyShapeContract(getOperation(), getCoord(), + getBlockShape()))) + return failure(); + return verifyTMESwizzleContract(*this); +} + +void AsyncTMECopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + auto &desc = getOperation()->getOpOperand(0); + auto &result = getOperation()->getOpOperand(getCoord().size() + 2); + effects.emplace_back(MemoryEffects::Read::get(), &desc, + ::mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &result, + ttg::SharedMemory::get()); +} + +void AsyncTMECopyLocalToGlobalOp::getEffects( + SmallVectorImpl> + &effects) { + auto &desc = getOperation()->getOpOperand(0); + auto &src = getOperation()->getOpOperand(getCoord().size() + 1); + effects.emplace_back(MemoryEffects::Read::get(), &src, + ttg::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &desc, + ::mlir::triton::GlobalMemory::get()); +} + +} // namespace mlir::triton::musa diff --git a/third_party/mthreads/musa/lib/MTGPUToLLVM/CMakeLists.txt b/third_party/mthreads/musa/lib/MTGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..5a2415d0e6 --- /dev/null +++ b/third_party/mthreads/musa/lib/MTGPUToLLVM/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(MTGPUToLLVM + MTGPUToLLVMPass.cpp + + DEPENDS + MTGPUConversionPassIncGen + MTGPUTableGen + MTGPUTypesIncGen + + LINK_LIBS PUBLIC + TritonMUSAGPUToLLVM + MTGPUIR + MLIRControlFlowDialect + MLIRFuncTransforms +) diff --git a/third_party/mthreads/musa/lib/MTGPUToLLVM/MTGPUToLLVMPass.cpp b/third_party/mthreads/musa/lib/MTGPUToLLVM/MTGPUToLLVMPass.cpp new file mode 100644 index 0000000000..40f39e2095 --- /dev/null +++ b/third_party/mthreads/musa/lib/MTGPUToLLVM/MTGPUToLLVMPass.cpp @@ -0,0 +1,163 @@ +#include "MTGPUToLLVM/MTGPUToLLVMPass.h" + +#include "Dialect/MTGPU/IR/Dialect.h" +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "../TritonMUSAGPUToLLVM/DotOpToLLVM/DotOpToLLVM.h" + +#include + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_CONVERTMTGPUTOLLVM +#include "musa/include/MTGPUToLLVM/Passes.h.inc" + +namespace mtgpu { + +namespace { + +struct SqmmaOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::mtgpu::SqmmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value threadId = getThreadId(rewriter, op.getLoc()); + if (failed(mlir::triton::MUSA::convertSQMMADot( + op, adaptor, this->getTypeConverter(), rewriter, threadId))) { + return op.emitError("MUSA SQMMA: mtgpu lowering failed"); + } + return success(); + } +}; + +struct PackSqmmaAccumulatorOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::mtgpu::PackSqmmaAccumulatorOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::mtgpu::PackSqmmaAccumulatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tensorTy = cast(op.getInput().getType()); + Value packed = mlir::LLVM::MUSA::packSqmmaAccumulatorCarrierFromTensor( + op.getLoc(), adaptor.getInput(), tensorTy, this->getTypeConverter(), + rewriter); + rewriter.replaceOp(op, packed); + return success(); + } +}; + +struct UnpackSqmmaAccumulatorOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::mtgpu::UnpackSqmmaAccumulatorOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::mtgpu::UnpackSqmmaAccumulatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tensorTy = cast(op.getOutput().getType()); + Value unpacked = mlir::LLVM::MUSA::unpackSqmmaAccumulatorCarrierToTensor( + op.getLoc(), adaptor.getCarrier(), tensorTy, this->getTypeConverter(), + rewriter); + rewriter.replaceOp(op, unpacked); + return success(); + } +}; + +struct SqmmaWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::mtgpu::SqmmaWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::mtgpu::SqmmaWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::createLLVMIntrinsicCallOp(rewriter, op.getLoc(), + "llvm.musa.sqmma.wait", TypeRange{}, {}); + rewriter.replaceOp(op, adaptor.getInputs()); + return success(); + } +}; + +} // namespace + +void populateMTGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add( + typeConverter, benefit); +} + +} // namespace mtgpu + +namespace { + +class ConvertMTGPUToLLVM + : public impl::ConvertMTGPUToLLVMBase { +public: + using impl::ConvertMTGPUToLLVMBase< + ConvertMTGPUToLLVM>::ConvertMTGPUToLLVMBase; + + ConvertMTGPUToLLVM() = default; + ConvertMTGPUToLLVM(int32_t computeCapability) + : impl::ConvertMTGPUToLLVMBase({computeCapability}) {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mlir::triton::MUSA::TargetInfo targetInfo(computeCapability); + + LowerToLLVMOptions options(context); + options.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(context, options, targetInfo); + typeConverter.addConversion( + [&](triton::mtgpu::SqmmaAccumulatorType type) -> std::optional { + auto info = LLVM::MUSA::getSqmmaAccumulatorCarrierInfo(type); + if (failed(info)) + return std::nullopt; + return info->carrierType; + }); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addLegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + RewritePatternSet patterns(context); + mtgpu::populateMTGPUToLLVMPatterns(typeConverter, patterns, + PatternBenefit(1)); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> createConvertMTGPUToLLVMPass() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertMTGPUToLLVMPass(int32_t computeCapability) { + return std::make_unique(computeCapability); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/AllocateSharedMemory.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 0000000000..829412985b --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,269 @@ +#include "TritonMUSAGPUToLLVM/Allocation.h" +#include "TritonMUSAGPUToLLVM/Passes.h" +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/AllocateSharedMemoryUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LayoutUtils.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace mlir::triton { +#define GEN_PASS_DEF_ALLOCATEMUSASHAREDMEMORY +#include "musa/include/TritonMUSAGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +static bool isMusaSqmmaLike(Attribute layout) { + return isa(layout); +} + +static bool useMusaReplicatedScratch(Attribute srcLayout, Attribute dstLayout) { + return (isMusaSqmmaLike(srcLayout) || isMusaSqmmaLike(dstLayout)) && + isa( + srcLayout) && + isa( + dstLayout); +} + +static bool isSqmmaAccumulatorToBlockedLike(Attribute srcLayout, + Attribute dstLayout) { + return isa(srcLayout) && + isa(dstLayout); +} + +static bool useMusaSqmmaBlockSwizzling(RankedTensorType srcTy, + RankedTensorType dstTy) { + if (!(isMusaSqmmaLike(srcTy.getEncoding()) || + isMusaSqmmaLike(dstTy.getEncoding()))) + return false; + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + auto dims = conversion.getInDimNames(); + + if (llvm::is_contained(dims, kBlock)) + return false; + if (llvm::is_contained(dims, kWarp)) + return true; + if (llvm::is_contained(dims, kLane)) + return !cvtNeedsWarpShuffle(srcTy, dstTy); + return false; +} + +static bool isPlainBlockedLike(Attribute layout) { + return isa(layout); +} + +static bool useConservativeCarrierScratch(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto srcElemTy = srcTy.getElementType(); + auto dstElemTy = dstTy.getElementType(); + + auto needsByteCarrier = [](Type ty) { + return ty.isIntOrFloat() && ty.getIntOrFloatBitWidth() < 8; + }; + bool isPointerCarrier = isa(srcElemTy) && + isa(dstElemTy); + bool isSubByteCarrier = + needsByteCarrier(srcElemTy) && needsByteCarrier(dstElemTy); + if (!isPointerCarrier && !isSubByteCarrier) + return false; + + if (!isa(srcTy.getEncoding()) || + !isa(dstTy.getEncoding())) + return false; + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + StringAttr kBlock = str_attr("block"); + return !llvm::is_contained(conversion.getInDimNames(), kBlock); +} + +static bool useMusaGenericBlockSwizzling(RankedTensorType srcTy, + RankedTensorType dstTy) { + if (isMusaSqmmaLike(srcTy.getEncoding()) || + isMusaSqmmaLike(dstTy.getEncoding())) + return false; + if (!cvtNeedsSharedMemory(srcTy, dstTy)) + return false; + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + StringAttr kBlock = str_attr("block"); + return !llvm::is_contained(conversion.getInDimNames(), kBlock); +} + +static LinearLayout +getMusaSwizzledScratchLayout(RankedTensorType srcTy, RankedTensorType dstTy, + const TargetInfoBase &targetInfo) { + auto srcLayout = toLinearLayout(srcTy); + auto dstLayout = toLinearLayout(dstTy); + srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout); + dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout); + auto bitwidth = getBitwidth(srcTy); + auto [srcTiles, dstTiles] = getSrcDstTiles(targetInfo, bitwidth); + auto [smem, _] = + optimalSwizzling(srcLayout, dstLayout, srcTiles, dstTiles, bitwidth); + return smem; +} + +static bool hasPH1PhysicalSliceRep(const LinearLayout &smem) { + constexpr int32_t kPH1PhysicalSliceRows = 32; + + auto *ctx = smem.getInDimNames().begin()->getContext(); + auto kReps = StringAttr::get(ctx, "reps"); + auto outDims = smem.getOutDims(); + if (outDims.size() < 2) + return false; + + auto isPowerOfTwo = [](int32_t value) { + return value > 0 && (value & (value - 1)) == 0; + }; + + for (const auto &repBasis : smem.getBases().lookup(kReps)) { + for (auto [dim, component] : llvm::enumerate(repBasis)) { + if (component == 0) + continue; + if (dim + 1 >= outDims.size()) + continue; + if (!isPowerOfTwo(component) || component < kPH1PhysicalSliceRows) + continue; + + bool selectsWholeInnerSlice = true; + for (auto inner = dim + 1; inner < repBasis.size(); ++inner) { + if (repBasis[inner] != 0) { + selectsWholeInnerSlice = false; + break; + } + } + if (selectsWholeInnerSlice) + return true; + } + } + return false; +} + +static bool +needsMusaRepDisjointGenericScratchImpl(RankedTensorType srcTy, + RankedTensorType dstTy, + const TargetInfoBase &targetInfo) { + if (!(isPlainBlockedLike(srcTy.getEncoding()) && + isPlainBlockedLike(dstTy.getEncoding()))) + return false; + if (getBitwidth(srcTy) != 64) + return false; + if (!useMusaGenericBlockSwizzling(srcTy, dstTy)) + return false; + return hasPH1PhysicalSliceRep( + getMusaSwizzledScratchLayout(srcTy, dstTy, targetInfo)); +} + +static unsigned getFullLogicalScratchBytes(RankedTensorType ty) { + auto elems = product(getShapePerCTA(ty)); + return elems * getBitwidth(ty) / 8; +} + +static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy, + RankedTensorType dstTy, + const TargetInfoBase &targetInfo, + bool separateRepScratch) { + auto *ctx = srcTy.getContext(); + auto srcLayout = toLinearLayout(srcTy); + auto dstLayout = toLinearLayout(dstTy); + srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout); + dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout); + auto bitwidth = getBitwidth(srcTy); + auto [srcTiles, dstTiles] = getSrcDstTiles(targetInfo, bitwidth); + auto [smem, _] = + optimalSwizzling(srcLayout, dstLayout, srcTiles, dstTiles, bitwidth); + if (separateRepScratch) + return smem.getTotalOutDimSize(); + auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps")); + return smem.getTotalOutDimSize() / reps; +} + +static unsigned getMusaScratchSizeInBytes(Operation *op, + const TargetInfoBase &targetInfo) { + auto cvtOp = dyn_cast(op); + if (!cvtOp) + return defaultAllocationAnalysisScratchSizeFn(op); + + auto srcTy = cvtOp.getSrc().getType(); + auto dstTy = cvtOp.getType(); + if (!cvtNeedsSharedMemory(srcTy, dstTy)) + return 0; + + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (useMusaReplicatedScratch(srcLayout, dstLayout) && + !isSqmmaAccumulatorToBlockedLike(srcLayout, dstLayout)) + return getFullLogicalScratchBytes(srcTy); + + if (useConservativeCarrierScratch(srcTy, dstTy)) + return getFullLogicalScratchBytes(srcTy); + + if (isSqmmaAccumulatorToBlockedLike(srcLayout, dstLayout) || + useMusaSqmmaBlockSwizzling(srcTy, dstTy) || + useMusaGenericBlockSwizzling(srcTy, dstTy)) { + bool separateRepScratch = + mlir::triton::musa_gpu::needsMusaRepDisjointGenericScratch(srcTy, dstTy, + targetInfo); + auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy, targetInfo, + separateRepScratch); + return elems * getBitwidth(srcTy) / 8; + } + + return defaultAllocationAnalysisScratchSizeFn(op); +} + +struct AllocateMUSASharedMemory + : public mlir::triton::impl::AllocateMUSASharedMemoryBase< + AllocateMUSASharedMemory> { + using AllocateMUSASharedMemoryBase::AllocateMUSASharedMemoryBase; + + AllocateMUSASharedMemory(int32_t computeCapability) + : AllocateMUSASharedMemoryBase({computeCapability}) {} + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MUSA::TargetInfo targetInfo(computeCapability); + ModuleAllocation allocation( + mod, mlir::triton::musa_gpu::getMusaAllocationAnalysisScratchSizeFn( + targetInfo)); + mlir::triton::gpu::attachAllocationSizeAndOffsetAttr(mod, allocation); + } +}; + +} // namespace + +namespace mlir::triton { +namespace musa_gpu { +bool needsMusaRepDisjointGenericScratch(RankedTensorType srcTy, + RankedTensorType dstTy, + const TargetInfoBase &targetInfo) { + return needsMusaRepDisjointGenericScratchImpl(srcTy, dstTy, targetInfo); +} + +std::function +getMusaAllocationAnalysisScratchSizeFn(const TargetInfoBase &targetInfo) { + return [&targetInfo](Operation *op) { + return getMusaScratchSizeInBytes(op, targetInfo); + }; +} +} // namespace musa_gpu + +std::unique_ptr> +createAllocateMUSASharedMemoryPass(int32_t computeCapability) { + return std::make_unique(computeCapability); +} +} // namespace mlir::triton diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/BarrierOpToLLVM.cpp new file mode 100644 index 0000000000..d1b9d24a9d --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/BarrierOpToLLVM.cpp @@ -0,0 +1,57 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; + +namespace { + +struct TTGBarrierOpConversion + : public ConvertOpToLLVMPattern { + TTGBarrierOpConversion(LLVMTypeConverter &typeConverter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + targetInfo.barrier(op.getLoc(), rewriter, op.getAddrSpace()); + rewriter.eraseOp(op); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; +}; + +struct GPUBarrierOpConversion + : public ConvertOpToLLVMPattern { + GPUBarrierOpConversion(LLVMTypeConverter &typeConverter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::gpu::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + targetInfo.barrier(op.getLoc(), rewriter, triton::gpu::AddrSpace::Local); + rewriter.eraseOp(op); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::MUSA::populateBarrierOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit, const TargetInfo &targetInfo) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..67ac8cba9d --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt @@ -0,0 +1,33 @@ +add_triton_library(TritonMUSAGPUToLLVM + AllocateSharedMemory.cpp + BarrierOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + DotOpToLLVM.cpp + DotOpToLLVM/SQMMA.cpp + DotOpToLLVM/WMMA.cpp + ElementwiseOpToLLVM.cpp + LoadStoreOpToLLVM.cpp + MUSAOpsToLLVM.cpp + SPMDOpToLLVM.cpp + TensorPtrOpsToLLVM.cpp + TritonGPUToLLVM.cpp + TargetInfo.cpp + ThreadIdOpToLLVM.cpp + Utility.cpp + WarpIdOpToLLVM.cpp + + DEPENDS + TritonMUSAGPUConversionPassIncGen + MUSAAttrDefsIncGen + MTGPUTableGen + MTGPUTypesIncGen + + LINK_LIBS PUBLIC + TritonAnalysis + TritonGPUToLLVM + MLIRReconcileUnrealizedCasts + MLIRUBToLLVM + MTGPUIR + MUSAIR + MLIRGPUToMTGPUTransforms +) diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 0000000000..081691102c --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,551 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonMUSAGPUToLLVM/Allocation.h" +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LayoutUtils.h" + +using ::mlir::LLVM::linearize; + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +static bool isMusaSqmmaLike(Attribute layout) { + return isa(layout); +} + +static bool useMusaReplicatedScratch(Attribute srcLayout, Attribute dstLayout) { + return (isMusaSqmmaLike(srcLayout) || isMusaSqmmaLike(dstLayout)) && + isa( + srcLayout) && + isa( + dstLayout); +} + +static bool isSqmmaAccumulatorToBlockedLike(Attribute srcLayout, + Attribute dstLayout) { + return isa(srcLayout) && + isa(dstLayout); +} + +static bool useMusaSqmmaBlockSwizzling(RankedTensorType srcTy, + RankedTensorType dstTy) { + if (!(isMusaSqmmaLike(srcTy.getEncoding()) || + isMusaSqmmaLike(dstTy.getEncoding()))) + return false; + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + auto dims = conversion.getInDimNames(); + + if (llvm::is_contained(dims, kBlock)) + return false; + if (llvm::is_contained(dims, kWarp)) + return true; + if (llvm::is_contained(dims, kLane)) + return !cvtNeedsWarpShuffle(srcTy, dstTy); + return false; +} + +static bool useConservativeCarrierScratch(RankedTensorType srcTy, + RankedTensorType dstTy) { + MLIRContext *ctx = srcTy.getContext(); + auto srcElemTy = srcTy.getElementType(); + auto dstElemTy = dstTy.getElementType(); + + auto needsByteCarrier = [](Type ty) { + return ty.isIntOrFloat() && ty.getIntOrFloatBitWidth() < 8; + }; + bool isPointerCarrier = isa(srcElemTy) && + isa(dstElemTy); + bool isSubByteCarrier = + needsByteCarrier(srcElemTy) && needsByteCarrier(dstElemTy); + if (!isPointerCarrier && !isSubByteCarrier) + return false; + + if (!isa(srcTy.getEncoding()) || + !isa(dstTy.getEncoding())) + return false; + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + StringAttr kBlock = str_attr("block"); + return !llvm::is_contained(conversion.getInDimNames(), kBlock); +} + +static bool isPlainBlockedLike(Attribute layout) { + return isa(layout); +} + +static bool useMusaGenericBlockSwizzling(RankedTensorType srcTy, + RankedTensorType dstTy) { + if (isMusaSqmmaLike(srcTy.getEncoding()) || + isMusaSqmmaLike(dstTy.getEncoding())) + return false; + if (!cvtNeedsSharedMemory(srcTy, dstTy)) + return false; + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + StringAttr kBlock = str_attr("block"); + return !llvm::is_contained(conversion.getInDimNames(), kBlock); +} + +static bool hasOnlyStorageLikeUsers(Value value) { + SmallVector worklist{value}; + llvm::SmallPtrSet visited; + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + for (Operation *user : current.getUsers()) { + if (!visited.insert(user).second) + continue; + if (isa(user)) + continue; + if (isa(user)) { + for (Value result : user->getResults()) + worklist.push_back(result); + continue; + } + return false; + } + } + return true; +} + +static triton::gpu::LocalAllocOp findRootLocalAlloc(Value memDesc) { + Value cur = memDesc; + while (cur) { + Operation *defOp = cur.getDefiningOp(); + if (!defOp) + break; + if (auto localAllocOp = dyn_cast(defOp)) + return localAllocOp; + if (auto indexOp = dyn_cast(defOp)) { + cur = indexOp.getSrc(); + continue; + } + if (auto subsliceOp = dyn_cast(defOp)) { + cur = subsliceOp.getSrc(); + continue; + } + if (auto reinterpretOp = + dyn_cast(defOp)) { + cur = reinterpretOp.getSrc(); + continue; + } + if (auto transOp = dyn_cast(defOp)) { + cur = transOp.getSrc(); + continue; + } + if (auto reshapeOp = dyn_cast(defOp)) { + cur = reshapeOp.getSrc(); + continue; + } + break; + } + return {}; +} + +static FailureOr getDistributedSharedMemoryBase( + Location loc, ConversionPatternRewriter &rewriter, + const MUSA::TargetInfo &targetInfo, triton::gpu::ConvertLayoutOp op, + Attribute srcLayout) { + if (isMusaSqmmaLike(srcLayout)) + return LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + + auto rootLocalAlloc = findRootLocalAlloc(op.getSrc()); + if (rootLocalAlloc && rootLocalAlloc->hasAttr("allocation.offset")) + return LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + rootLocalAlloc.getOperation()); + + if (op->hasAttr("allocation.offset")) + return LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + + return rewriter.notifyMatchFailure( + op, "expected allocation.offset on convert_layout or root local_alloc " + "for MUSA shared layout conversion"); +} + +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const MUSA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (useMusaReplicatedScratch(srcLayout, dstLayout)) { + if (isSqmmaAccumulatorToBlockedLike(srcLayout, dstLayout)) + return lowerSqmmaBlockSwizzling(op, adaptor, rewriter); + return lowerDistributedToDistributed(op, adaptor, rewriter); + } + + if (useConservativeCarrierScratch(srcTy, dstTy)) + return lowerDistributedToDistributed(op, adaptor, rewriter); + + if (useMusaSqmmaBlockSwizzling(srcTy, dstTy)) + return lowerSqmmaBlockSwizzling(op, adaptor, rewriter); + + if (useMusaGenericBlockSwizzling(srcTy, dstTy)) + return lowerGenericBlockSwizzling(op, adaptor, rewriter); + + return failure(); + } + +private: + SmallVector transferWithinBlockSwizzlingImpl( + Location loc, ConversionPatternRewriter &rewriter, + const LinearLayout &srcLayout, const LinearLayout &dstLayout, + ArrayRef inVals, Type llvmElemTy, Value smemBase, + bool separateRepScratch) const { + auto *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + if (isa(llvmElemTy)) { + auto llvmElemTyPtr = i64_ty; + auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) { + return b.ptrtoint(llvmElemTyPtr, v).getResult(); + })); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, dstLayout, newInVals, llvmElemTyPtr, + smemBase, separateRepScratch); + for (auto &v : outVals) + v = b.inttoptr(llvmElemTy, v); + return outVals; + } + + if (llvmElemTy.getIntOrFloatBitWidth() < 8) { + auto i8ElemTy = i8_ty; + auto newInVals = llvm::to_vector(llvm::map_range( + inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); })); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase, + separateRepScratch); + for (auto &v : outVals) + v = b.trunc(llvmElemTy, v); + return outVals; + } + + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout); + if (!removeBroadcastSrc.isIdentity()) { + auto prmtSrc = removeBroadcastSrc.apply(srcLayout); + auto newInVals = removeBroadcastSrc.apply(inVals); + return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout, + newInVals, llvmElemTy, smemBase, + separateRepScratch); + } + + auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout); + if (!removeBroadcastDst.isIdentity()) { + auto prmtDst = removeBroadcastDst.apply(dstLayout); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase, + separateRepScratch); + return broadcastAs(outVals, dstLayout); + } + + auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto [srcTiles, dstTiles] = getSrcDstTiles(targetInfo, bitwidth); + auto [smem, instr] = + optimalSwizzling(srcLayout, dstLayout, srcTiles, dstTiles, bitwidth); + auto [idxSrc, idxDst] = instr; + assert(idxSrc == 0 && idxDst == 0 && + "MUSA generic swizzling currently supports shared ld/st tiles only"); + + auto kReg = str_attr("register"); + auto kReps = str_attr("reps"); + auto nReps = smem.getInDimSize(kReps); + auto reps = LinearLayout::identity1D(nReps, kReg, kReps); + + auto totalStoreCvt = srcLayout.invertAndCompose(smem); + auto totalLoadCvt = dstLayout.invertAndCompose(smem); + + auto permStore = + regPermForDivide(totalStoreCvt, reps, /*left=*/false).value(); + totalStoreCvt = permStore.apply(totalStoreCvt); + auto permutedInVals = permStore.apply(inVals); + auto permLoad = + regPermForDivide(totalLoadCvt, reps, /*left=*/false).value(); + totalLoadCvt = permLoad.apply(totalLoadCvt); + + auto storeCvt = *divideRight(totalStoreCvt, reps); + auto loadCvt = *divideRight(totalLoadCvt, reps); + auto kOffset = str_attr("offset"); + storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}}); + loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}}); + + auto tileSize = storeCvt.getInDimSize(kReg); + assert(permutedInVals.size() == tileSize * nReps); + + SmallVector outVals; + auto maskSpanAffineOffset = 0; + + bool isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout); + for (int i = 0; i < nReps; ++i) { + if (i > 0) { + if (isWarpSync) + targetInfo.warpSync(loc, rewriter); + else + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + } + + auto tileInVals = + ArrayRef(permutedInVals).slice(i * tileSize, tileSize); + Value affineOffset = separateRepScratch + ? b.i32_val(i * storeCvt.getTotalOutDimSize()) + : b.i32_val(0); + lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase, + /*paddingShifts=*/{}, affineOffset, maskSpanAffineOffset, + rewriter, targetInfo); + + if (isWarpSync) + targetInfo.warpSync(loc, rewriter); + else + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + + SmallVector tileOutVals = lowerLdStShared( + loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /*paddingShifts=*/{}, + affineOffset, maskSpanAffineOffset, rewriter, targetInfo); + llvm::append_range(outVals, tileOutVals); + } + + outVals = permLoad.inverse().apply(outVals); + return outVals; + } + + LogicalResult + lowerGenericBlockSwizzling(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + auto srcLayout = toLinearLayout(srcTy); + auto dstLayout = toLinearLayout(dstTy); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + SmallVector inDims{kReg, kLane, kWarp}; + srcLayout = + srcLayout.sublayout(inDims, to_vector(srcLayout.getOutDimNames())); + dstLayout = + dstLayout.sublayout(inDims, to_vector(dstLayout.getOutDimNames())); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto inVals = ::mlir::unpackLLElements(loc, adaptor.getSrc(), rewriter); + bool separateRepScratch = + musa_gpu::needsMusaRepDisjointGenericScratch(srcTy, dstTy, targetInfo); + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase, + separateRepScratch); + Value result = ::mlir::packLLElements(loc, getTypeConverter(), outVals, + rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + LogicalResult + lowerSqmmaBlockSwizzling(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + auto srcLayout = toLinearLayout(srcTy); + auto dstLayout = toLinearLayout(dstTy); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + SmallVector inDims{kReg, kLane, kWarp}; + srcLayout = + srcLayout.sublayout(inDims, to_vector(srcLayout.getOutDimNames())); + dstLayout = + dstLayout.sublayout(inDims, to_vector(dstLayout.getOutDimNames())); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto inVals = ::mlir::unpackLLElements(loc, adaptor.getSrc(), rewriter); + + if (isMusaSqmmaLike(srcTy.getEncoding())) + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + + auto outVals = transferWithinBlockSwizzlingImpl( + loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase, + /*separateRepScratch=*/false); + Value result = ::mlir::packLLElements(loc, getTypeConverter(), outVals, + rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + LogicalResult + lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto typeConverter = getTypeConverter(); + + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (product(convertType(dstTy.getShape())) == 1) { + auto inVals = ::mlir::unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(getTotalElemsPerThread(dstTy), inVals[0]); + Value result = + ::mlir::packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + auto shapePerCTA = convertType(getShapePerCTA(srcTy)); + unsigned rank = dstTy.getRank(); + + SmallVector repShape(shapePerCTA.begin(), shapePerCTA.end()); + SmallVector numReplicates(rank, 1); + auto order = getOrder(dstTy); + auto dstElemTy = dstTy.getElementType(); + auto llvmElemTy = typeConverter->convertType(dstElemTy); + bool isPtr = isa(dstElemTy); + bool useByteCarrier = + dstElemTy.isIntOrFloat() && dstElemTy.getIntOrFloatBitWidth() < 8; + Type llvmElemStorageTy = llvmElemTy; + if (isPtr) + llvmElemStorageTy = i64_ty; + else if (useByteCarrier) + llvmElemStorageTy = i8_ty; + auto elemPtrTy = ptr_ty(ctx, 3); + + auto srcIndices = emitIndices(loc, rewriter, targetInfo, srcLayout, srcTy, + /*withCTAOffset=*/false); + auto inVals = ::mlir::unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (isPtr) { + for (Value &inVal : inVals) + inVal = b.ptrtoint(i64_ty, inVal); + } else if (useByteCarrier) { + for (Value &inVal : inVals) + inVal = b.zext(i8_ty, inVal); + } + assert(srcIndices.size() == inVals.size() && + "unexpected source index/value mismatch"); + + auto smemBaseOr = getDistributedSharedMemoryBase(loc, rewriter, targetInfo, + op, srcLayout); + if (failed(smemBaseOr)) + return failure(); + Value smemBase = *smemBaseOr; + Value typedSmemBase = b.bitcast(smemBase, elemPtrTy); + + auto dstIndices = emitIndices(loc, rewriter, targetInfo, dstLayout, dstTy, + /*withCTAOffset=*/false); + SmallVector outVals( + dstIndices.size(), + LLVM::UndefOp::create(rewriter, loc, llvmElemStorageTy)); + + unsigned numTotalReps = product(numReplicates); + if (numTotalReps != 0 && isMusaSqmmaLike(srcLayout)) + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + for (unsigned repId = 0; repId < numTotalReps; ++repId) { + if (repId != 0) + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + + auto multiDimRepId = delinearize(repId, numReplicates, order); + SmallVector repBase(rank); + SmallVector repLimit(rank); + for (unsigned d = 0; d < rank; ++d) { + repBase[d] = b.i32_val(multiDimRepId[d] * repShape[d]); + repLimit[d] = b.i32_val(multiDimRepId[d] * repShape[d] + repShape[d]); + } + + for (unsigned i = 0; i < srcIndices.size(); ++i) { + Value inRep = b.true_val(); + SmallVector localCoord(rank); + for (unsigned d = 0; d < rank; ++d) { + Value ge = b.icmp_sge(srcIndices[i][d], repBase[d]); + Value lt = b.icmp_slt(srcIndices[i][d], repLimit[d]); + inRep = b.and_(inRep, b.and_(ge, lt)); + localCoord[d] = b.sub(srcIndices[i][d], repBase[d]); + } + + Value offset = linearize(rewriter, loc, localCoord, repShape, order); + Value ptr = b.gep(elemPtrTy, llvmElemStorageTy, typedSmemBase, offset); + LLVM::MUSA::llStore(rewriter, loc, ptr, inVals[i], inRep); + } + + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + + for (unsigned i = 0; i < dstIndices.size(); ++i) { + Value inRep = b.true_val(); + SmallVector localCoord(rank); + for (unsigned d = 0; d < rank; ++d) { + Value ge = b.icmp_sge(dstIndices[i][d], repBase[d]); + Value lt = b.icmp_slt(dstIndices[i][d], repLimit[d]); + inRep = b.and_(inRep, b.and_(ge, lt)); + localCoord[d] = b.sub(dstIndices[i][d], repBase[d]); + } + + Value offset = linearize(rewriter, loc, localCoord, repShape, order); + Value ptr = b.gep(elemPtrTy, llvmElemStorageTy, typedSmemBase, offset); + outVals[i] = LLVM::MUSA::llLoad(rewriter, loc, ptr, llvmElemStorageTy, + inRep, outVals[i]); + } + } + + if (isPtr) { + for (Value &outVal : outVals) + outVal = b.inttoptr(llvmElemTy, outVal); + } else if (useByteCarrier) { + for (Value &outVal : outVals) + outVal = b.trunc(llvmElemTy, outVal); + } + + Value result = + ::mlir::packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const MUSA::TargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::MUSA::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, PatternBenefit(benefit.getBenefit() + 1)); + mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM.cpp new file mode 100644 index 0000000000..96d3233c7e --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM.cpp @@ -0,0 +1,48 @@ +#include "DotOpToLLVM/DotOpToLLVM.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +namespace { + +struct DotOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = cast(op.getType()); + auto lhsTy = cast(op.getA().getType()); + auto rhsTy = cast(op.getB().getType()); + if ((op.getInputPrecision() == triton::InputPrecision::BF16x3 || + op.getInputPrecision() == triton::InputPrecision::BF16x6) && + lhsTy.getElementType().isF32() && rhsTy.getElementType().isF32()) { + return op.emitError( + "bf16x3/bf16x6 tt.dot must be rewritten by TritonGPUF32DotTC " + "before MUSA LLVM lowering"); + } + if (isa( + resultTy.getEncoding())) + return op.emitError("MUSA matmul with mma encoding must be rewritten to " + "ttmg.wmma_dot/ttmg.squad_dot before LLVM lowering"); + if (isa(resultTy.getEncoding())) + return convertFMADot(op, adaptor, getTypeConverter(), rewriter); + + llvm::report_fatal_error( + "Unsupported MUSA DotOp encoding in DotOp lowering."); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateDotOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/DotOpToLLVM.h b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/DotOpToLLVM.h new file mode 100644 index 0000000000..5fcd92820d --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/DotOpToLLVM.h @@ -0,0 +1,35 @@ +#ifndef TRITONMUSAGPU_CONVERSION_DOTOP_TO_LLVM_H +#define TRITONMUSAGPU_CONVERSION_DOTOP_TO_LLVM_H + +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace MUSA { + +LogicalResult convertWMMADot(triton::musa::WmmaDotOp op, + triton::musa::WmmaDotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertSQMMADot(triton::musa::SquadDotOp op, + triton::musa::SquadDotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value threadId); + +LogicalResult convertSQMMADot(triton::mtgpu::SqmmaOp op, + triton::mtgpu::SqmmaOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value threadId); + +} // namespace MUSA +} // namespace triton +} // namespace mlir + +#endif // TRITONMUSAGPU_CONVERSION_DOTOP_TO_LLVM_H diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/SQMMA.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/SQMMA.cpp new file mode 100644 index 0000000000..b08b903122 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/SQMMA.cpp @@ -0,0 +1,889 @@ +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.h" +#include "DotOpToLLVM.h" +#include "TritonMUSACommon/MMAContractUtils.h" +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +triton::musa::SQMMAEltType getMmaRetType(Value d) { + auto dTy = cast(d.getType()).getElementType(); + if (dTy.isF32()) + return triton::musa::SQMMAEltType::f32; + if (dTy.isInteger(32)) + return triton::musa::SQMMAEltType::s32; + if (dTy.isF16()) + return triton::musa::SQMMAEltType::f16; + llvm::report_fatal_error("MUSA SQMMA: unsupported result type"); +} + +triton::musa::SQMMAEltType getMmaOperandType(Value a, bool allowTF32) { + auto aTy = cast(a.getType()).getElementType(); + if (aTy.isF16()) + return triton::musa::SQMMAEltType::f16; + if (aTy.isBF16()) + return triton::musa::SQMMAEltType::bf16; + if (aTy.isF32() && allowTF32) + return triton::musa::SQMMAEltType::tf32; + if (aTy.isInteger(8)) + return triton::musa::SQMMAEltType::s8; + if (llvm::isa(aTy)) + return triton::musa::SQMMAEltType::e5m2; + if (llvm::isa(aTy)) + return triton::musa::SQMMAEltType::e4m3; + llvm::report_fatal_error("MUSA SQMMA: unsupported operand type"); +} + +bool isFP8(triton::musa::SQMMAEltType type) { + return type == triton::musa::SQMMAEltType::e4m3 || + type == triton::musa::SQMMAEltType::e5m2; +} + +struct SqmmaDescriptorIntrinsicSpec { + llvm::StringRef intrinsicName; + Type loadType; + Type dataType; + Value leadingStrideBytes; + Value swizzleGranularity; + Value swizzleStride; +}; + +static FailureOr +buildSqmmaDescriptorIntrinsicSpec(Location loc, MemDescType memDescTy, + ArrayRef physicalShape, + triton::musa::SQMMAEltType eltType, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto order = + triton::musa::getSharedOrder(memDescTy.getEncoding(), physicalShape); + if (physicalShape.size() != 2 || order.size() < 2) + return failure(); + + int elemBits = memDescTy.getElementTypeBitWidth(); + if (elemBits <= 0 || (elemBits % 8) != 0) + return failure(); + int64_t elemBytes = elemBits / 8; + auto grouping = triton::musa::getPH1TMELeadingDimGrouping(physicalShape, + order, elemBytes); + if (failed(grouping)) + return failure(); + int64_t groupedLeadingWidthBytes = + grouping->elemsPerGroupInLeadingDim * elemBytes; + if (groupedLeadingWidthBytes <= 0 || + !llvm::isPowerOf2_64(static_cast(groupedLeadingWidthBytes))) + return failure(); + + auto swizzle = triton::musa::resolveTMESwizzleConfigFromEncoding(memDescTy); + if (failed(swizzle)) + return failure(); + + Type loadType = typeConverter->convertType(memDescTy.getElementType()); + if (!loadType) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + llvm::StringRef intrinsicName; + Type dataType; + switch (eltType) { + case triton::musa::SQMMAEltType::f16: + case triton::musa::SQMMAEltType::bf16: + intrinsicName = "llvm.musa.sqmma.desc.half"; + dataType = rewriter.getI16Type(); + break; + case triton::musa::SQMMAEltType::tf32: + intrinsicName = "llvm.musa.sqmma.desc.fp32"; + dataType = rewriter.getF32Type(); + break; + case triton::musa::SQMMAEltType::s8: + case triton::musa::SQMMAEltType::e4m3: + case triton::musa::SQMMAEltType::e5m2: + intrinsicName = "llvm.musa.sqmma.desc.i8"; + dataType = IntegerType::get(ctx, 8); + break; + default: + return failure(); + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + return SqmmaDescriptorIntrinsicSpec{ + intrinsicName, + loadType, + dataType, + b.i32_val(static_cast(groupedLeadingWidthBytes)), + b.i32_val(static_cast(swizzle->swizzleGranularity)), + b.i32_val(static_cast(swizzle->swizzleStride)), + }; +} + +class SqmmaSmemLoader { +public: + SqmmaSmemLoader() = default; + SqmmaSmemLoader(Value tensor, Value affineBase, + ArrayRef semanticShape, + ArrayRef physicalShape, Value warpId, + unsigned dimWpt, bool trans, unsigned nonKTile, + unsigned kTile, SqmmaDescriptorIntrinsicSpec descriptorSpec, + ConversionPatternRewriter &rewriter, Location loc) + : base(affineBase), + semanticShape(semanticShape.begin(), semanticShape.end()), + physicalShape(physicalShape.begin(), physicalShape.end()), + warpId(warpId), dimWpt(dimWpt), trans(trans), nonKTile(nonKTile), + kTile(kTile), descriptorSpec(descriptorSpec) { + auto ty = cast(tensor.getType()); + ord = triton::musa::getSharedOrder(ty.getEncoding(), this->physicalShape); + elemBytes = ty.getElementTypeBitWidth() / 8; + uint32_t widthInByte = this->physicalShape[ord[0]] * elemBytes; + elemsPerSwizzlingRow = + widthInByte > 256 ? 256 / elemBytes : this->physicalShape[ord[0]]; + elemsPerSwizzlingRowVal = + TritonLLVMOpBuilder(loc, rewriter).i32_val(elemsPerSwizzlingRow); + } + + Value smemDesc(unsigned tileNonKIdx, unsigned tileKIdx, + ConversionPatternRewriter &rewriter, Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value k = b.i32_val(tileKIdx * kTile); + Value nonK = b.add(b.i32_val(tileNonKIdx * dimWpt * nonKTile), + b.mul(warpId, b.i32_val(nonKTile))); + if (trans) + std::swap(k, nonK); + + Value leadingOffset = + b.mul(b.udiv(k, elemsPerSwizzlingRowVal), + b.i32_val(physicalShape[ord[1]] * elemsPerSwizzlingRow)); + Value strideOffset = b.mul(nonK, elemsPerSwizzlingRowVal); + Value offset = b.add(b.add(leadingOffset, strideOffset), + b.urem(k, elemsPerSwizzlingRowVal)); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value elemBase = b.bitcast(base, elemPtrTy); + Value elemPtr = b.gep(elemPtrTy, descriptorSpec.loadType, elemBase, offset); + Value data = + LLVM::LoadOp::create(rewriter, loc, descriptorSpec.loadType, elemPtr); + if (data.getType() != descriptorSpec.dataType) + data = b.bitcast(data, descriptorSpec.dataType); + auto desc = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, descriptorSpec.intrinsicName, TypeRange{i32_ty}, + ValueRange{data, descriptorSpec.leadingStrideBytes, + descriptorSpec.swizzleGranularity, + descriptorSpec.swizzleStride}); + return desc.getResult(0); + } + +private: + Value base; + SmallVector semanticShape; + SmallVector physicalShape; + Value warpId; + unsigned dimWpt = 1; + bool trans = false; + unsigned nonKTile = 0; + unsigned kTile = 0; + SmallVector ord; + int elemsPerSwizzlingRow = 0; + int elemBytes = 0; + Value elemsPerSwizzlingRowVal; + SqmmaDescriptorIntrinsicSpec descriptorSpec; +}; + +SmallVector loadAccSlice(ConversionPatternRewriter &rewriter, + Location loc, + const SmallVector &elements, + int startIndex, int numElements, + Operation *insertBefore) { + OpBuilder::InsertionGuard g(rewriter); + if (insertBefore) + rewriter.setInsertionPoint(insertBefore); + + SmallVector slice(numElements); + for (int i = 0; i < numElements; ++i) + slice[i] = elements[startIndex + i]; + return slice; +} + +Value selectAccumulatorVector(Location loc, Value useC, Value acc, Value zero, + ConversionPatternRewriter &rewriter) { + return triton::musa::selectAccumulatorValue(loc, useC, acc, zero, rewriter); +} + +Value addAccumulate(ConversionPatternRewriter &rewriter, Location loc, Value a, + Value b) { + auto vecTy = cast(a.getType()); + auto bld = TritonLLVMOpBuilder(loc, rewriter); + Value res = bld.undef(vecTy); + for (int i = 0; i < vecTy.getNumElements(); ++i) { + Value lhs = bld.extract_element(a, bld.i32_val(i)); + Value rhs = bld.extract_element(b, bld.i32_val(i)); + Value add = vecTy.getElementType().isInteger() + ? LLVM::AddOp::create(rewriter, loc, lhs, rhs).getResult() + : LLVM::FAddOp::create(rewriter, loc, lhs, rhs).getResult(); + res = bld.insert_element(res, add, bld.i32_val(i)); + } + return res; +} + +Value addAccumulatorLane(ConversionPatternRewriter &rewriter, Location loc, + Value a, Value b) { + if (a.getType().isInteger()) + return LLVM::AddOp::create(rewriter, loc, a, b).getResult(); + return LLVM::FAddOp::create(rewriter, loc, a, b).getResult(); +} + +static RankedTensorType getSqmmaAccumulatorTensorType(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy; + if (auto carrierTy = dyn_cast(type)) + return carrierTy.getAccumulatorType(); + return RankedTensorType(); +} + +static bool isSqmmaAccumulatorCarrierType(Type type) { + return isa(type); +} + +} // namespace + +namespace mlir::triton::MUSA { + +static InputPrecision decodeInputPrecisionAttr(int64_t value) { + switch (value) { + case 0: + return InputPrecision::IEEE; + case 1: + return InputPrecision::TF32; + case 2: + return InputPrecision::TF32x3; + case 3: + case 4: + llvm::report_fatal_error( + "BF16x3/BF16x6 must be rewritten by TritonGPUF32DotTC before " + "MUSA SQMMA lowering"); + default: + llvm::report_fatal_error("Unexpected MUSA dot input precision attribute"); + } +} + +static Value getDotOperandA(triton::DotOp op) { return op.getA(); } +static Value getDotOperandA(triton::musa::SquadDotOp op) { return op.getA(); } +static Value getDotOperandA(triton::mtgpu::SqmmaOp op) { return op.getA(); } +static Value getDotOperandB(triton::DotOp op) { return op.getB(); } +static Value getDotOperandB(triton::musa::SquadDotOp op) { return op.getB(); } +static Value getDotOperandB(triton::mtgpu::SqmmaOp op) { return op.getB(); } +static Value getDotOperandC(triton::DotOp op) { return op.getC(); } +static Value getDotOperandC(triton::musa::SquadDotOp op) { return op.getC(); } +static Value getDotOperandC(triton::mtgpu::SqmmaOp op) { return op.getC(); } + +static Value getDotAdaptorA(triton::DotOp::Adaptor adaptor) { + return adaptor.getA(); +} +static Value getDotAdaptorA(triton::musa::SquadDotOp::Adaptor adaptor) { + return adaptor.getA(); +} +static Value getDotAdaptorA(triton::mtgpu::SqmmaOp::Adaptor adaptor) { + return adaptor.getA(); +} +static Value getDotAdaptorB(triton::DotOp::Adaptor adaptor) { + return adaptor.getB(); +} +static Value getDotAdaptorB(triton::musa::SquadDotOp::Adaptor adaptor) { + return adaptor.getB(); +} +static Value getDotAdaptorB(triton::mtgpu::SqmmaOp::Adaptor adaptor) { + return adaptor.getB(); +} +static Value getDotAdaptorC(triton::DotOp::Adaptor adaptor) { + return adaptor.getC(); +} +static Value getDotAdaptorC(triton::musa::SquadDotOp::Adaptor adaptor) { + return adaptor.getC(); +} +static Value getDotAdaptorC(triton::mtgpu::SqmmaOp::Adaptor adaptor) { + return adaptor.getC(); +} + +static Value getDotUseCValue(triton::DotOp, triton::DotOp::Adaptor) { + return Value(); +} +static Value getDotUseCValue(triton::musa::SquadDotOp, + triton::musa::SquadDotOp::Adaptor adaptor) { + return adaptor.getUseC(); +} +static Value getDotUseCValue(triton::mtgpu::SqmmaOp, + triton::mtgpu::SqmmaOp::Adaptor adaptor) { + return adaptor.getUseC(); +} + +static InputPrecision getDotInputPrecision(triton::DotOp op) { + return op.getInputPrecision(); +} +static InputPrecision getDotInputPrecision(triton::musa::SquadDotOp op) { + return decodeInputPrecisionAttr(op.getInputPrecision()); +} +static InputPrecision getDotInputPrecision(triton::mtgpu::SqmmaOp op) { + return decodeInputPrecisionAttr(op.getInputPrecision()); +} + +static bool getDotIsAsync(triton::DotOp op) { return false; } +static bool getDotIsAsync(triton::musa::SquadDotOp op) { + return op.getIsAsync(); +} +static bool getDotIsAsync(triton::mtgpu::SqmmaOp op) { return op.getIsAsync(); } + +static triton::musa::SQMMAEltType getDotEltTypeA(triton::DotOp op) { + bool allowTF32 = getDotInputPrecision(op) == InputPrecision::TF32; + return getMmaOperandType(op.getA(), allowTF32); +} +static triton::musa::SQMMAEltType getDotEltTypeA(triton::musa::SquadDotOp op) { + return op.getEltTypeA(); +} +static triton::musa::SQMMAEltType getDotEltTypeA(triton::mtgpu::SqmmaOp op) { + return static_cast( + static_cast(op.getEltTypeA())); +} + +static triton::musa::SQMMAEltType getDotEltTypeB(triton::DotOp op) { + bool allowTF32 = getDotInputPrecision(op) == InputPrecision::TF32; + return getMmaOperandType(op.getB(), allowTF32); +} +static triton::musa::SQMMAEltType getDotEltTypeB(triton::musa::SquadDotOp op) { + return op.getEltTypeB(); +} +static triton::musa::SQMMAEltType getDotEltTypeB(triton::mtgpu::SqmmaOp op) { + return static_cast( + static_cast(op.getEltTypeB())); +} + +static triton::musa::SQMMAEltType getDotEltTypeC(triton::DotOp op) { + return getMmaRetType(op.getResult()); +} +static triton::musa::SQMMAEltType getDotEltTypeC(triton::musa::SquadDotOp op) { + return op.getEltTypeC(); +} +static triton::musa::SQMMAEltType getDotEltTypeC(triton::mtgpu::SqmmaOp op) { + return static_cast( + static_cast(op.getEltTypeC())); +} + +static uint32_t getDotMaxNumImpreciseAcc(triton::DotOp op) { + return op.getMaxNumImpreciseAcc(); +} +static uint32_t getDotMaxNumImpreciseAcc(triton::musa::SquadDotOp op) { + return static_cast( + std::max(0, op.getMaxNumImpreciseAcc())); +} +static uint32_t getDotMaxNumImpreciseAcc(triton::mtgpu::SqmmaOp op) { + return static_cast( + std::max(0, op.getMaxNumImpreciseAcc())); +} + +static triton::musa::SQMMALayout getDotLayoutA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + if (auto memDescTy = dyn_cast(aTy)) + return triton::musa::inferSharedRowMajor(memDescTy) + ? triton::musa::SQMMALayout::row + : triton::musa::SQMMALayout::col; + auto tensorTy = cast(aTy); + auto order = getOrderForMemory(tensorTy); + bool isRowMajor = !order.empty() && order.front() + 1 == tensorTy.getRank(); + return isRowMajor ? triton::musa::SQMMALayout::row + : triton::musa::SQMMALayout::col; +} +static triton::musa::SQMMALayout getDotLayoutA(triton::musa::SquadDotOp op) { + return op.getLayoutA(); +} +static triton::musa::SQMMALayout getDotLayoutA(triton::mtgpu::SqmmaOp op) { + return static_cast( + static_cast(op.getLayoutA())); +} +static triton::musa::SQMMALayout getDotLayoutB(triton::DotOp op) { + auto bTy = cast(op.getB().getType()); + if (auto memDescTy = dyn_cast(bTy)) + return triton::musa::inferSharedRowMajor(memDescTy) + ? triton::musa::SQMMALayout::row + : triton::musa::SQMMALayout::col; + auto tensorTy = cast(bTy); + auto order = getOrderForMemory(tensorTy); + bool isRowMajor = !order.empty() && order.front() + 1 == tensorTy.getRank(); + return isRowMajor ? triton::musa::SQMMALayout::row + : triton::musa::SQMMALayout::col; +} +static triton::musa::SQMMALayout getDotLayoutB(triton::musa::SquadDotOp op) { + return op.getLayoutB(); +} +static triton::musa::SQMMALayout getDotLayoutB(triton::mtgpu::SqmmaOp op) { + return static_cast( + static_cast(op.getLayoutB())); +} + +static triton::musa::SQMMAAccumulationMode +getDotAccumulationMode(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + bool fp8 = isFP8(getDotEltTypeA(op)); + bool accFP32 = + cast(op.getResult().getType()).getElementType().isF32(); + uint32_t maxNumImpreciseAcc = getDotMaxNumImpreciseAcc(op); + if (fp8 && accFP32 && maxNumImpreciseAcc > 0 && + maxNumImpreciseAcc <= aTy.getShape().back()) + return triton::musa::SQMMAAccumulationMode::partial; + return triton::musa::SQMMAAccumulationMode::hardware; +} +static triton::musa::SQMMAAccumulationMode +getDotAccumulationMode(triton::musa::SquadDotOp op) { + return op.getAccMode(); +} +static triton::musa::SQMMAAccumulationMode +getDotAccumulationMode(triton::mtgpu::SqmmaOp op) { + return static_cast( + static_cast(op.getAccMode())); +} + +template +LogicalResult convertSQMMADotImpl(DotLikeOp op, DotLikeAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value threadId) { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto dTy = getSqmmaAccumulatorTensorType(op.getResult().getType()); + bool carrierMode = isSqmmaAccumulatorCarrierType(op.getResult().getType()); + if (!dTy) + return op.emitError("MUSA SQMMA: expected tensor or SQMMA carrier result"); + auto mmaEnc = dyn_cast(dTy.getEncoding()); + if (!mmaEnc) + return op.emitError("MUSA SQMMA: expected #ttg.musa_sqmma result encoding"); + if (!mmaEnc.isPH1()) + return op.emitError("MUSA SQMMA: unsupported version"); + + if (dTy.getRank() != 2) + return op.emitError("MUSA SQMMA supports rank-2 tensors only"); + + auto instrShape = mmaEnc.getInstrShape(); + if (instrShape.size() != 3) + return op.emitError("MUSA SQMMA expects 3D instrShape"); + + const unsigned instM = instrShape[0]; + const unsigned instN = instrShape[1]; + const unsigned instK = instrShape[2]; + + Value opA = getDotOperandA(op); + Value opB = getDotOperandB(op); + Value opC = getDotOperandC(op); + Value useCFlag = triton::musa::materializeUseCFlag( + loc, getDotUseCValue(op, adaptor), rewriter); + std::optional useCConst = getBoolFromConstant(useCFlag); + + Value adaptorA = getDotAdaptorA(adaptor); + Value adaptorB = getDotAdaptorB(adaptor); + Value adaptorC = getDotAdaptorC(adaptor); + auto aOperand = triton::musa::resolveSharedOperandWithAffineBase( + opA, adaptorA, loc, typeConverter, rewriter); + auto bOperand = triton::musa::resolveSharedOperandWithAffineBase( + opB, adaptorB, loc, typeConverter, rewriter); + if (failed(aOperand)) + return op.emitError( + "MUSA SQMMA requires operand A from ttg.local_load(shared memdesc)"); + if (failed(bOperand)) + return op.emitError( + "MUSA SQMMA requires operand B from ttg.local_load(shared memdesc)"); + + auto aMemTy = aOperand->memDescTy; + auto bMemTy = bOperand->memDescTy; + + if (!aMemTy || !triton::musa::isSharedEncoding(aMemTy.getEncoding())) + return op.emitError("MUSA SQMMA requires operand A in shared memory"); + if (!bMemTy || !triton::musa::isSharedEncoding(bMemTy.getEncoding())) + return op.emitError("MUSA SQMMA requires operand B in shared memory"); + if (aMemTy.getRank() != 2 || bMemTy.getRank() != 2) + return op.emitError("MUSA SQMMA supports rank-2 operands only"); + + auto eltTypeC = getDotEltTypeC(op); + auto eltTypeA = getDotEltTypeA(op); + auto eltTypeB = getDotEltTypeB(op); + + if (!isSupportedSqmma(eltTypeA, eltTypeB, eltTypeC, instM, instN, instK)) + return op.emitError("MUSA SQMMA: unsupported shape or element type"); + std::string sqmmaIntrinsic = + triton::musa::lookupSqmmaIntrinsic(eltTypeA, instM, instN, instK); + if (sqmmaIntrinsic.empty()) + return op.emitError("MUSA SQMMA: unsupported operand type"); + + auto warpsPerCTA = mmaEnc.getWarpsPerCTA(); + if (warpsPerCTA.size() < 2 || warpsPerCTA[0] % 4 != 0) + return op.emitError("MUSA SQMMA: invalid warpsPerCTA"); + + auto dShapePerCTA = getShapePerCTA(dTy); + unsigned blockM = dShapePerCTA[0]; + unsigned blockN = dShapePerCTA[1]; + unsigned squadsM = warpsPerCTA[0] / 4; + unsigned squadsN = warpsPerCTA[1]; + unsigned tileM = instM * squadsM; + unsigned tileN = instN * squadsN; + + auto ceilDiv = [](unsigned x, unsigned y) { return (x + y - 1) / y; }; + unsigned numRepM = ceilDiv(blockM, tileM); + unsigned numRepN = ceilDiv(blockN, tileN); + int64_t kDimVal = aMemTy.getShape().back(); + if (kDimVal <= 0) + return op.emitError("MUSA SQMMA requires static positive K dimension"); + unsigned kDim = static_cast(kDimVal); + unsigned numRepK = std::max(1u, ceilDiv(kDim, instK)); + + unsigned repCount = numRepM * numRepN; + unsigned totalAccElems = mmaEnc.getTotalElemsPerThread(dTy.getShape()); + if (repCount == 0 || totalAccElems == 0 || (totalAccElems % repCount) != 0) + return op.emitError("MUSA SQMMA: invalid accumulator partitioning"); + unsigned accElemsPerThread = totalAccElems / repCount; + constexpr unsigned warpSize = 32; + + bool zeroAcc = isZeroConst(opC); + SmallVector fc; + SmallVector fcFragments; + if (zeroAcc) { + if (carrierMode) { + fcFragments.resize(repCount); + } else { + Type accElemTy = typeConverter->convertType(dTy.getElementType()); + Value zero = LLVM::ZeroOp::create(rewriter, loc, accElemTy); + size_t totalAcc = + static_cast(accElemsPerThread) * numRepM * numRepN; + fc.assign(totalAcc, zero); + } + } else { + if (carrierMode) { + fcFragments = mlir::LLVM::MUSA::unpackSqmmaAccumulatorCarrier( + loc, adaptorC, dTy, rewriter); + } else { + fc = ::mlir::unpackLLElements(loc, adaptorC, rewriter); + } + } + + size_t expectedAccElems = + static_cast(accElemsPerThread) * numRepM * numRepN; + if (!carrierMode && fc.size() != expectedAccElems) { + if (fc.empty() || (fc.size() % expectedAccElems) != 0) + return op.emitError("MUSA SQMMA: accumulator size mismatch"); + size_t dupFactor = fc.size() / expectedAccElems; + SmallVector compact; + compact.reserve(expectedAccElems); + for (size_t i = 0; i < expectedAccElems; ++i) + compact.push_back(fc[i * dupFactor]); + fc.swap(compact); + } + + Value warp = b.udiv(threadId, b.i32_val(warpSize)); + Value warpGroup = b.and_(warp, b.i32_val(0xFFFFFFFC)); + Value squad = b.udiv(warpGroup, b.i32_val(4)); + Value squadM = b.urem(squad, b.i32_val(squadsM)); + Value squadMN = b.udiv(squad, b.i32_val(squadsM)); + Value squadN = b.urem(squadMN, b.i32_val(squadsN)); + + auto layoutA = getDotLayoutA(op); + auto layoutB = getDotLayoutB(op); + bool loaderTransA = layoutA == triton::musa::SQMMALayout::col; + bool loaderTransB = layoutB == triton::musa::SQMMALayout::row; + int32_t intrinsicTransA = (layoutA == triton::musa::SQMMALayout::col) ? 1 : 0; + int32_t intrinsicTransB = (layoutB == triton::musa::SQMMALayout::col) ? 1 : 0; + + auto aDescSpec = buildSqmmaDescriptorIntrinsicSpec( + loc, aMemTy, aOperand->physicalShape, eltTypeA, typeConverter, rewriter); + if (failed(aDescSpec)) + return op.emitError("MUSA SQMMA failed to derive descriptor contract for " + "operand A"); + auto bDescSpec = buildSqmmaDescriptorIntrinsicSpec( + loc, bMemTy, bOperand->physicalShape, eltTypeB, typeConverter, rewriter); + if (failed(bDescSpec)) + return op.emitError("MUSA SQMMA failed to derive descriptor contract for " + "operand B"); + + SqmmaSmemLoader aLoader(aOperand->memDesc, aOperand->affineBase, + aMemTy.getShape(), aOperand->physicalShape, squadM, + squadsM, loaderTransA, instM, instK, *aDescSpec, + rewriter, loc); + SqmmaSmemLoader bLoader(bOperand->memDesc, bOperand->affineBase, + bMemTy.getShape(), bOperand->physicalShape, squadN, + squadsN, loaderTransB, instN, instK, *bDescSpec, + rewriter, loc); + + auto accumulationMode = getDotAccumulationMode(op); + uint32_t maxNumImpreciseAcc = getDotMaxNumImpreciseAcc(op); + bool usesSoftwareAccumulator = + accumulationMode == triton::musa::SQMMAAccumulationMode::software; + bool needsPartialAccumulator = + accumulationMode == triton::musa::SQMMAAccumulationMode::partial; + Type accLaneTy = typeConverter->convertType(dTy.getElementType()); + Type accVecTy = vec_ty(accLaneTy, accElemsPerThread); + + SmallVector mmaResults; + SmallVector mmaCarrierFragments; + for (unsigned mRep = 0; mRep < numRepM; ++mRep) { + for (unsigned nRep = 0; nRep < numRepN; ++nRep) { + size_t accBase = (mRep * numRepN + nRep) * accElemsPerThread; + unsigned fragmentIdx = mRep * numRepN + nRep; + auto ivecTy = vec_ty(IntegerType::get(rewriter.getContext(), 32), + accElemsPerThread); + if (usesSoftwareAccumulator) { + SmallVector accumElems; + if (zeroAcc) { + if (carrierMode) { + Value zeroVec = LLVM::ZeroOp::create(rewriter, loc, accVecTy); + accumElems = unpackLLVector(loc, zeroVec, rewriter); + } else { + accumElems.append(fc.begin() + accBase, + fc.begin() + accBase + accElemsPerThread); + } + } else { + Value accSliceVec; + if (carrierMode) { + accSliceVec = mlir::LLVM::MUSA::carrierFragmentToMathVec( + loc, fcFragments[fragmentIdx], dTy, rewriter); + } else { + auto accSlice = loadAccSlice(rewriter, loc, fc, accBase, + accElemsPerThread, nullptr); + accSliceVec = packLLVector(loc, accSlice, rewriter); + } + // In the software-accumulation family, useC=false keeps the 3.2 + // contract of "hardware sees zero C, outer fadd still preserves the + // running accumulator". Only dynamic first-use flags introduced by + // accumulator-init rewriting should zero the carried accumulator + // here. + if (!useCConst) { + Value zeroSliceVec = + LLVM::ZeroOp::create(rewriter, loc, accSliceVec.getType()); + accSliceVec = selectAccumulatorVector(loc, useCFlag, accSliceVec, + zeroSliceVec, rewriter); + } + accumElems = unpackLLVector(loc, accSliceVec, rewriter); + } + for (unsigned kRep = 0; kRep < numRepK; ++kRep) { + Value opA = aLoader.smemDesc(mRep, kRep, rewriter, loc); + Value opB = bLoader.smemDesc(nRep, kRep, rewriter, loc); + + SmallVector args = { + opA, + opB, + LLVM::ZeroOp::create(rewriter, loc, ivecTy), + b.i32_val(intrinsicTransA), + b.i32_val(intrinsicTransB), + b.i32_val(0), // aNeg + b.i32_val(0), // bNeg + b.i32_val(1), // scale_out + b.i32_val(0), // sat + b.i32_val(0), // stepA + b.i32_val(0), // stepB + b.i32_val(0) // stepC + }; + auto call = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, sqmmaIntrinsic, TypeRange{ivecTy}, args); + Value newAcc = call.getResult(0); + Value newMathVec = newAcc.getType() == accVecTy + ? newAcc + : b.bitcast(newAcc, accVecTy); + SmallVector newElems = + unpackLLVector(loc, newMathVec, rewriter); + if (accumElems.empty()) { + accumElems = std::move(newElems); + } else { + for (unsigned i = 0; i < newElems.size(); ++i) { + accumElems[i] = + addAccumulatorLane(rewriter, loc, accumElems[i], newElems[i]); + } + } + } + if (carrierMode) { + Value accumVec = packLLVector(loc, accumElems, rewriter); + fcFragments[fragmentIdx] = mlir::LLVM::MUSA::mathVecToCarrierFragment( + loc, accumVec, dTy, rewriter); + } else { + for (unsigned i = 0; i < accumElems.size(); ++i) + fc[accBase + i] = accumElems[i]; + } + continue; + } + + Type accElemTy = accLaneTy; + Value accVec; + if (carrierMode) { + accVec = zeroAcc ? LLVM::ZeroOp::create(rewriter, loc, ivecTy) + : fcFragments[fragmentIdx]; + if (!zeroAcc && (!useCConst || !*useCConst)) { + Value zeroVec = LLVM::ZeroOp::create(rewriter, loc, accVec.getType()); + accVec = + selectAccumulatorVector(loc, useCFlag, accVec, zeroVec, rewriter); + } + } else if (zeroAcc) { + accVec = LLVM::ZeroOp::create(rewriter, loc, + vec_ty(accElemTy, accElemsPerThread)); + } else { + auto accSlice = loadAccSlice(rewriter, loc, fc, accBase, + accElemsPerThread, nullptr); + if (!accSlice.empty()) + accElemTy = accSlice.front().getType(); + accVec = packLLVector(loc, accSlice, rewriter); + if (!useCConst || !*useCConst) { + Value zeroVec = LLVM::ZeroOp::create(rewriter, loc, accVec.getType()); + accVec = + selectAccumulatorVector(loc, useCFlag, accVec, zeroVec, rewriter); + } + } + auto currentAccVecTy = accVec.getType(); + Value vecAcc = + accVec.getType() == ivecTy ? accVec : b.bitcast(accVec, ivecTy); + + Value dAcc = zeroAcc ? Value() : vecAcc; + + uint32_t numLowPrecAcc = 0; + Value partialAcc; + for (unsigned kRep = 0; kRep < numRepK; ++kRep) { + Value opA = aLoader.smemDesc(mRep, kRep, rewriter, loc); + Value opB = bLoader.smemDesc(nRep, kRep, rewriter, loc); + numLowPrecAcc += instK; + bool requireAdd = + needsPartialAccumulator && + (numLowPrecAcc >= maxNumImpreciseAcc || kRep == numRepK - 1); + Value mmaAcc = needsPartialAccumulator ? partialAcc : dAcc; + Value inputAcc = + mmaAcc ? mmaAcc : LLVM::ZeroOp::create(rewriter, loc, ivecTy); + + SmallVector args = { + opA, + opB, + inputAcc, + b.i32_val(intrinsicTransA), + b.i32_val(intrinsicTransB), + b.i32_val(0), // aNeg + b.i32_val(0), // bNeg + b.i32_val(1), // scale_out + b.i32_val(0), // sat + b.i32_val(0), // stepA + b.i32_val(0), // stepB + b.i32_val(0) // stepC + }; + auto call = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, sqmmaIntrinsic, TypeRange{ivecTy}, args); + Value newAcc = call.getResult(0); + if (needsPartialAccumulator) { + partialAcc = newAcc; + } else { + dAcc = newAcc; + } + + if (requireAdd) { + Value partialFloat = partialAcc.getType() == accVecTy + ? partialAcc + : b.bitcast(partialAcc, accVecTy); + if (dAcc) { + Value dFloat = + dAcc.getType() == accVecTy ? dAcc : b.bitcast(dAcc, accVecTy); + dFloat = addAccumulate(rewriter, loc, dFloat, partialFloat); + dAcc = + dFloat.getType() == ivecTy ? dFloat : b.bitcast(dFloat, ivecTy); + } else { + dAcc = partialAcc; + } + numLowPrecAcc = 0; + partialAcc = Value(); + } + } + + if (needsPartialAccumulator && partialAcc) { + Value partialFloat = partialAcc.getType() == accVecTy + ? partialAcc + : b.bitcast(partialAcc, accVecTy); + if (dAcc) { + Value dFloat = + dAcc.getType() == accVecTy ? dAcc : b.bitcast(dAcc, accVecTy); + dFloat = addAccumulate(rewriter, loc, dFloat, partialFloat); + dAcc = + dFloat.getType() == ivecTy ? dFloat : b.bitcast(dFloat, ivecTy); + } else { + dAcc = partialAcc; + } + } + + Value finalReg = dAcc ? dAcc : vecAcc; + if (carrierMode) { + Value finalMathVec = finalReg.getType() == accVecTy + ? finalReg + : b.bitcast(finalReg, accVecTy); + Value carrierFragment = mlir::LLVM::MUSA::mathVecToCarrierFragment( + loc, finalMathVec, dTy, rewriter); + mmaCarrierFragments.push_back(carrierFragment); + } else { + Value finalAcc = finalReg.getType() == currentAccVecTy + ? finalReg + : b.bitcast(finalReg, currentAccVecTy); + SmallVector accElems = unpackLLVector(loc, finalAcc, rewriter); + for (unsigned i = 0; i < accElems.size(); ++i) + mmaResults.push_back(accElems[i]); + } + } + } + + if (carrierMode) { + if (!getDotIsAsync(op)) { + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.sqmma.wait", + TypeRange{}, {}); + } + Value res = mlir::LLVM::MUSA::packSqmmaAccumulatorCarrier( + loc, + usesSoftwareAccumulator ? ValueRange(fcFragments) + : ValueRange(mmaCarrierFragments), + dTy, rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + if (usesSoftwareAccumulator) + mmaResults.assign(fc.begin(), fc.end()); + + unsigned encodedAccElems = mmaEnc.getTotalElemsPerThread(dTy.getShape()); + if (mmaResults.size() != encodedAccElems) + return op.emitError("MUSA SQMMA: result accumulator size mismatch"); + + if (!getDotIsAsync(op)) { + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.sqmma.wait", + TypeRange{}, {}); + } + + Value res = + ::mlir::packLLElements(loc, typeConverter, mmaResults, rewriter, dTy); + + rewriter.replaceOp(op, res); + return success(); +} + +LogicalResult convertSQMMADot(triton::musa::SquadDotOp op, + triton::musa::SquadDotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value threadId) { + return convertSQMMADotImpl(op, adaptor, typeConverter, rewriter, threadId); +} + +LogicalResult convertSQMMADot(triton::mtgpu::SqmmaOp op, + triton::mtgpu::SqmmaOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value threadId) { + return convertSQMMADotImpl(op, adaptor, typeConverter, rewriter, threadId); +} + +} // namespace mlir::triton::MUSA diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/WMMA.cpp new file mode 100644 index 0000000000..16fc80e54c --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -0,0 +1,638 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "DotOpToLLVM.h" +#include "TritonMUSACommon/MMAContractUtils.h" +#include "TritonMUSACommon/MMAEncodingUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +static SmallVector +packElemsToI32(TritonLLVMOpBuilder &b, Location loc, ArrayRef elems, + Type elemTy, const LLVMTypeConverter *typeConverter) { + Type packElemTy = elemTy; + if (!elems.empty()) + packElemTy = elems.front().getType(); + else if (typeConverter) { + Type converted = typeConverter->convertType(elemTy); + if (converted) + packElemTy = converted; + } + + const unsigned bitwidth = packElemTy.getIntOrFloatBitWidth(); + Type i32Ty = b.builder->getI32Type(); + SmallVector packed; + if (bitwidth == 32) { + packed.reserve(elems.size()); + for (Value v : elems) { + Value normalized = + (v.getType() == packElemTy) ? v : b.bitcast(v, packElemTy); + Value i32Val = normalized.getType().isInteger(32) + ? normalized + : b.bitcast(normalized, i32Ty); + packed.push_back(i32Val); + } + return packed; + } + + if (bitwidth == 16) { + if (elems.size() % 2 != 0) + llvm::report_fatal_error("WMMA: expected even number of 16-bit elements"); + Type vecTy = vec_ty(packElemTy, 2); + for (size_t i = 0; i < elems.size(); i += 2) { + Value pack = b.undef(vecTy); + Value e0 = (elems[i].getType() == packElemTy) + ? elems[i] + : b.bitcast(elems[i], packElemTy); + Value e1 = (elems[i + 1].getType() == packElemTy) + ? elems[i + 1] + : b.bitcast(elems[i + 1], packElemTy); + pack = b.insert_element(vecTy, pack, e0, b.i32_val(0)); + pack = b.insert_element(vecTy, pack, e1, b.i32_val(1)); + packed.push_back(b.bitcast(pack, i32Ty)); + } + return packed; + } + + if (bitwidth == 8) { + if (elems.size() % 4 != 0) + llvm::report_fatal_error("WMMA: expected 4x8-bit packing"); + Type vecTy = vec_ty(packElemTy, 4); + for (size_t i = 0; i < elems.size(); i += 4) { + Value pack = b.undef(vecTy); + Value e0 = (elems[i].getType() == packElemTy) + ? elems[i] + : b.bitcast(elems[i], packElemTy); + Value e1 = (elems[i + 1].getType() == packElemTy) + ? elems[i + 1] + : b.bitcast(elems[i + 1], packElemTy); + Value e2 = (elems[i + 2].getType() == packElemTy) + ? elems[i + 2] + : b.bitcast(elems[i + 2], packElemTy); + Value e3 = (elems[i + 3].getType() == packElemTy) + ? elems[i + 3] + : b.bitcast(elems[i + 3], packElemTy); + pack = b.insert_element(vecTy, pack, e0, b.i32_val(0)); + pack = b.insert_element(vecTy, pack, e1, b.i32_val(1)); + pack = b.insert_element(vecTy, pack, e2, b.i32_val(2)); + pack = b.insert_element(vecTy, pack, e3, b.i32_val(3)); + packed.push_back(b.bitcast(pack, i32Ty)); + } + return packed; + } + + llvm::report_fatal_error("WMMA: unsupported element width for packing"); +} + +static Value packToVector(TritonLLVMOpBuilder &b, Location loc, Type elemTy, + ArrayRef elems) { + if (elems.size() == 1) + return elems.front(); + Type vecTy = vec_ty(elemTy, elems.size()); + Value vec = b.undef(vecTy); + for (auto [i, v] : llvm::enumerate(elems)) + vec = b.insert_element(vecTy, vec, v, b.i32_val(i)); + return vec; +} + +static SmallVector unpackI32ToElems(TritonLLVMOpBuilder &b, Location loc, + ArrayRef packed, + Type dstElemTy) { + const unsigned bitwidth = dstElemTy.getIntOrFloatBitWidth(); + SmallVector elems; + if (bitwidth == 32) { + elems.reserve(packed.size()); + for (Value v : packed) { + elems.push_back(v.getType().isInteger(32) ? b.bitcast(v, dstElemTy) + : b.bitcast(v, dstElemTy)); + } + return elems; + } + if (bitwidth == 16) { + Type vecTy = vec_ty(dstElemTy, 2); + elems.reserve(packed.size() * 2); + for (Value v : packed) { + Value vec = b.bitcast(v, vecTy); + elems.push_back(b.extract_element(dstElemTy, vec, b.i32_val(0))); + elems.push_back(b.extract_element(dstElemTy, vec, b.i32_val(1))); + } + return elems; + } + if (bitwidth == 8) { + Type vecTy = vec_ty(dstElemTy, 4); + elems.reserve(packed.size() * 4); + for (Value v : packed) { + Value vec = b.bitcast(v, vecTy); + elems.push_back(b.extract_element(dstElemTy, vec, b.i32_val(0))); + elems.push_back(b.extract_element(dstElemTy, vec, b.i32_val(1))); + elems.push_back(b.extract_element(dstElemTy, vec, b.i32_val(2))); + elems.push_back(b.extract_element(dstElemTy, vec, b.i32_val(3))); + } + return elems; + } + llvm::report_fatal_error("WMMA: unsupported element width for unpacking"); +} + +static LinearLayout buildPH1WMMATileLayout(MLIRContext *ctx, unsigned rank, + unsigned instM, unsigned instN) { + auto outDimNames = standardOutDimNames(ctx, rank); + bool hasBatch = rank == 3; + StringAttr dimM = outDimNames[hasBatch ? 1 : 0]; + StringAttr dimN = outDimNames[hasBatch ? 2 : 1]; + + LinearLayout tileLayout( + {{str_attr("register"), {}}, + {str_attr("lane"), {{0, 1}, {0, 2}, {0, 4}, {1, 0}, {2, 0}}}, + {str_attr("warp"), {}}, + {str_attr("block"), {}}}, + {dimM, dimN}); + + tileLayout *= LinearLayout::identity1D(instN / 8, str_attr("register"), dimN); + tileLayout *= LinearLayout::identity1D(instM / 4, str_attr("register"), dimM); + + if (hasBatch) { + tileLayout *= + LinearLayout::identity1D(1, str_attr("register"), outDimNames[0]); + tileLayout *= LinearLayout::identity1D(1, str_attr("lane"), outDimNames[0]); + } + + return tileLayout; +} + +} // namespace + +namespace mlir::triton::MUSA { + +static Value getWmmaOperandA(triton::DotOp op) { return op.getA(); } +static Value getWmmaOperandA(triton::musa::WmmaDotOp op) { return op.getA(); } +static Value getWmmaOperandB(triton::DotOp op) { return op.getB(); } +static Value getWmmaOperandB(triton::musa::WmmaDotOp op) { return op.getB(); } + +static Value getWmmaAdaptorA(triton::DotOp::Adaptor adaptor) { + return adaptor.getA(); +} +static Value getWmmaAdaptorA(triton::musa::WmmaDotOp::Adaptor adaptor) { + return adaptor.getA(); +} +static Value getWmmaAdaptorB(triton::DotOp::Adaptor adaptor) { + return adaptor.getB(); +} +static Value getWmmaAdaptorB(triton::musa::WmmaDotOp::Adaptor adaptor) { + return adaptor.getB(); +} +static Value getWmmaAdaptorC(triton::DotOp::Adaptor adaptor) { + return adaptor.getC(); +} +static Value getWmmaAdaptorC(triton::musa::WmmaDotOp::Adaptor adaptor) { + return adaptor.getC(); +} + +static Value getWmmaUseCValue(triton::DotOp, triton::DotOp::Adaptor) { + return Value(); +} +static Value getWmmaUseCValue(triton::musa::WmmaDotOp, + triton::musa::WmmaDotOp::Adaptor adaptor) { + return adaptor.getUseC(); +} + +static triton::musa::SQMMALayout getWmmaLayoutA(triton::DotOp op) { + return triton::musa::inferWmmaFragmentLayout(op.getA(), 0); +} +static triton::musa::SQMMALayout getWmmaLayoutA(triton::musa::WmmaDotOp op) { + return op.getLayoutA(); +} +static triton::musa::SQMMALayout getWmmaLayoutB(triton::DotOp op) { + return triton::musa::inferWmmaFragmentLayout(op.getB(), 1); +} +static triton::musa::SQMMALayout getWmmaLayoutB(triton::musa::WmmaDotOp op) { + return op.getLayoutB(); +} + +static Value materializeUseCFlag(Location loc, Value useC, + ConversionPatternRewriter &rewriter) { + return triton::musa::materializeUseCFlag(loc, useC, rewriter); +} + +static Value selectAccumulatorVector(Location loc, Value useC, Value acc, + Value zero, + ConversionPatternRewriter &rewriter) { + return triton::musa::selectAccumulatorValue(loc, useC, acc, zero, rewriter); +} + +struct WmmaTileState { + unsigned regBase; + int batch; + int mBase; + int nBase; +}; + +struct WmmaTilePlan { + SmallVector tiles; + unsigned numRepK; + unsigned numRepN; +}; + +struct ContiguousRank2WmmaFastPathState { + WmmaTilePlan tilePlan; + SmallVector aElemsAll; + SmallVector bElemsAll; +}; + +static std::optional +buildContiguousRank2WmmaFastPathState( + Location loc, Value adaptorA, Value adaptorB, RankedTensorType aTy, + RankedTensorType dTy, const triton::musa::WmmaDotOperandContract &aContract, + const triton::musa::WmmaDotOperandContract &bContract, + triton::musa::SQMMALayout layoutA, triton::musa::SQMMALayout layoutB, + unsigned instM, unsigned instN, unsigned instK, unsigned warpSize, + ConversionPatternRewriter &rewriter) { + if (dTy.getRank() != 2) + return std::nullopt; + if (layoutA != triton::musa::SQMMALayout::row || + layoutB != triton::musa::SQMMALayout::col) + return std::nullopt; + if (aContract.rank != 2 || bContract.rank != 2) + return std::nullopt; + if (aContract.dotEncoding.getParent() != dTy.getEncoding() || + bContract.dotEncoding.getParent() != dTy.getEncoding()) + return std::nullopt; + + const unsigned aElemsPerThread = (instM * instK) / warpSize; + const unsigned bElemsPerThread = (instN * instK) / warpSize; + if (aElemsPerThread == 0 || bElemsPerThread == 0) + return std::nullopt; + + ContiguousRank2WmmaFastPathState state; + state.tilePlan.numRepK = std::max( + 1u, ceil(static_cast(aTy.getShape().back()), instK)); + state.aElemsAll = ::mlir::unpackLLElements(loc, adaptorA, rewriter); + state.bElemsAll = ::mlir::unpackLLElements(loc, adaptorB, rewriter); + if (state.aElemsAll.size() % aElemsPerThread != 0 || + state.bElemsAll.size() % bElemsPerThread != 0) + return std::nullopt; + unsigned aChunks = state.aElemsAll.size() / aElemsPerThread; + unsigned bChunks = state.bElemsAll.size() / bElemsPerThread; + if (aChunks % state.tilePlan.numRepK != 0 || + bChunks % state.tilePlan.numRepK != 0) + return std::nullopt; + unsigned numRepM = aChunks / state.tilePlan.numRepK; + state.tilePlan.numRepN = bChunks / state.tilePlan.numRepK; + if (numRepM == 0 || state.tilePlan.numRepN == 0) + return std::nullopt; + state.tilePlan.tiles.reserve(numRepM * state.tilePlan.numRepN); + for (unsigned mTile = 0; mTile < numRepM; ++mTile) { + for (unsigned nTile = 0; nTile < state.tilePlan.numRepN; ++nTile) { + unsigned cTileIdx = mTile * state.tilePlan.numRepN + nTile; + state.tilePlan.tiles.push_back({cTileIdx, 0, + static_cast(mTile * instM), + static_cast(nTile * instN)}); + } + } + + return state; +} + +static FailureOr +buildGenericWmmaTilePlan(MLIRContext *ctx, const LinearLayout &cLinearLayout, + unsigned rank, unsigned instM, unsigned instN, + unsigned cElemsPerThread, unsigned fcSize) { + auto tileLayout = buildPH1WMMATileLayout(ctx, rank, instM, instN); + auto quot = divideLeft(cLinearLayout, tileLayout); + if (!quot) + return failure(); + auto repLayout = zerosLike(tileLayout) * *quot; + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + unsigned repRegs = repLayout.getInDimSize(kRegister); + if (repRegs != fcSize) + return failure(); + + WmmaTilePlan plan; + plan.numRepK = 0; + plan.numRepN = 0; + plan.tiles.reserve(repRegs / cElemsPerThread); + for (unsigned regBase = 0; regBase < repRegs; regBase += cElemsPerThread) { + SmallVector, 4> repCoords = { + {kRegister, static_cast(regBase)}, + {kLane, 0}, + {kWarp, 0}, + {kBlock, 0}}; + auto coords = repLayout.apply(repCoords); + int batch = 0; + int mBase = 0; + int nBase = 0; + if (rank == 3) { + batch = coords[0].second; + mBase = coords[1].second; + nBase = coords[2].second; + } else { + mBase = coords[0].second; + nBase = coords[1].second; + } + plan.tiles.push_back({regBase, batch, mBase, nBase}); + } + return plan; +} + +static SmallVector extractContiguousRank2WmmaOperandChunk( + Location loc, ArrayRef elemsAll, unsigned chunkIdx, + unsigned elemsPerThread, unsigned validElems, Type elemTy, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + SmallVector elems; + elems.reserve(elemsPerThread); + Type llvmElemTy = typeConverter->convertType(elemTy); + Value zero = LLVM::ZeroOp::create(rewriter, loc, llvmElemTy); + unsigned base = chunkIdx * elemsPerThread; + for (unsigned i = 0; i < elemsPerThread; ++i) { + if (i < validElems && base + i < elemsAll.size()) + elems.push_back(elemsAll[base + i]); + else + elems.push_back(zero); + } + return elems; +} + +static std::optional buildContiguousRank2PackedWmmaOperandA( + Location loc, const ContiguousRank2WmmaFastPathState &state, unsigned mBase, + unsigned kTile, unsigned instM, unsigned validK, unsigned warpSize, + unsigned elemsPerThread, Type elemTy, TritonLLVMOpBuilder &b, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (mBase % instM != 0) + return std::nullopt; + unsigned mTile = mBase / instM; + unsigned numRepM = state.tilePlan.tiles.size() / state.tilePlan.numRepN; + if (mTile >= numRepM || kTile >= state.tilePlan.numRepK) + return std::nullopt; + unsigned chunkIdx = mTile * state.tilePlan.numRepK + kTile; + unsigned validElems = (instM * validK) / warpSize; + auto elems = extractContiguousRank2WmmaOperandChunk( + loc, state.aElemsAll, chunkIdx, elemsPerThread, validElems, elemTy, + typeConverter, rewriter); + auto packed = packElemsToI32(b, loc, elems, elemTy, typeConverter); + return packToVector(b, loc, rewriter.getI32Type(), packed); +} + +static std::optional buildContiguousRank2PackedWmmaOperandB( + Location loc, const ContiguousRank2WmmaFastPathState &state, unsigned nBase, + unsigned kTile, unsigned instN, unsigned validK, unsigned warpSize, + unsigned elemsPerThread, Type elemTy, TritonLLVMOpBuilder &b, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (nBase % instN != 0) + return std::nullopt; + unsigned nTile = nBase / instN; + if (nTile >= state.tilePlan.numRepN || kTile >= state.tilePlan.numRepK) + return std::nullopt; + unsigned chunkIdx = kTile * state.tilePlan.numRepN + nTile; + unsigned validElems = (instN * validK) / warpSize; + auto elems = extractContiguousRank2WmmaOperandChunk( + loc, state.bElemsAll, chunkIdx, elemsPerThread, validElems, elemTy, + typeConverter, rewriter); + auto packed = packElemsToI32(b, loc, elems, elemTy, typeConverter); + return packToVector(b, loc, rewriter.getI32Type(), packed); +} + +static std::optional buildPackedWmmaOperandAForTile( + Location loc, + const std::optional + &contiguousRank2FastPath, + Value adaptorA, const triton::musa::WmmaDotOperandContract &aContract, + int batch, int mBase, unsigned kTile, unsigned instM, unsigned instK, + unsigned warpSize, unsigned elemsPerThread, unsigned kPadding, + unsigned validK, Type elemTy, Type packedElemTy, TritonLLVMOpBuilder &b, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (contiguousRank2FastPath) + return buildContiguousRank2PackedWmmaOperandA( + loc, *contiguousRank2FastPath, mBase, kTile, instM, validK, warpSize, + elemsPerThread, elemTy, b, typeConverter, rewriter); + + Value aVec = triton::musa::extractWmmaOperandVectorFromContract( + loc, adaptorA, aContract, typeConverter, rewriter, batch, mBase, kTile, + instK, elemsPerThread, kPadding, elemTy); + auto aElems = unpackLLVector(loc, aVec, rewriter); + auto aPacked = + packElemsToI32(b, loc, aElems, aElems.front().getType(), typeConverter); + return packToVector(b, loc, packedElemTy, aPacked); +} + +static std::optional buildPackedWmmaOperandBForTile( + Location loc, + const std::optional + &contiguousRank2FastPath, + Value adaptorB, const triton::musa::WmmaDotOperandContract &bContract, + int batch, int nBase, unsigned kTile, unsigned instN, unsigned instK, + unsigned warpSize, unsigned elemsPerThread, unsigned kPadding, + unsigned validK, Type elemTy, Type packedElemTy, TritonLLVMOpBuilder &b, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (contiguousRank2FastPath) + return buildContiguousRank2PackedWmmaOperandB( + loc, *contiguousRank2FastPath, nBase, kTile, instN, validK, warpSize, + elemsPerThread, elemTy, b, typeConverter, rewriter); + + Value bVec = triton::musa::extractWmmaOperandVectorFromContract( + loc, adaptorB, bContract, typeConverter, rewriter, batch, nBase, kTile, + instK, elemsPerThread, kPadding, elemTy); + auto bElems = unpackLLVector(loc, bVec, rewriter); + auto bPacked = + packElemsToI32(b, loc, bElems, bElems.front().getType(), typeConverter); + return packToVector(b, loc, packedElemTy, bPacked); +} + +template +LogicalResult convertWMMADotImpl(DotLikeOp op, DotLikeAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + auto dTy = cast(op.getResult().getType()); + auto mmaEnc = dyn_cast(dTy.getEncoding()); + if (!mmaEnc) + return op.emitError("MUSA WMMA: expected #ttg.musa_wmma result encoding"); + if (!triton::musa::supportsMusaWmmaEncoding(mmaEnc)) + return op.emitError("MUSA WMMA: unsupported result encoding version"); + + Value opA = getWmmaOperandA(op); + Value opB = getWmmaOperandB(op); + Value adaptorA = getWmmaAdaptorA(adaptor); + Value adaptorB = getWmmaAdaptorB(adaptor); + Value adaptorC = getWmmaAdaptorC(adaptor); + auto aTy = cast(opA.getType()); + auto bTy = cast(opB.getType()); + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + auto dElemTy = dTy.getElementType(); + auto aContract = triton::musa::resolveWmmaDotOperandContract(aTy, 0); + auto bContract = triton::musa::resolveWmmaDotOperandContract(bTy, 1); + if (failed(aContract) || failed(bContract)) + return op.emitError("MUSA WMMA operands must use DotOperandEncodingAttr"); + if (aContract->dotEncoding.getParent() != dTy.getEncoding() || + bContract->dotEncoding.getParent() != dTy.getEncoding()) + return op.emitError("MUSA WMMA operands must point to the same " + "#ttg.musa_wmma result encoding"); + Value useCFlag = + materializeUseCFlag(loc, getWmmaUseCValue(op, adaptor), rewriter); + std::optional useCConst = getBoolFromConstant(useCFlag); + if (aElemTy != bElemTy) + return op.emitError("MUSA WMMA requires A/B element types to match"); + + auto instrShape = mmaEnc.getInstrShape(); + if (instrShape.size() != 3) + return op.emitError("MUSA WMMA expects 3D instrShape"); + + auto signature = triton::musa::lookupWmmaIntrinsic(aElemTy, instrShape); + if (!signature) + return op.emitError("MUSA WMMA: unsupported instrShape or element type"); + + const unsigned instM = instrShape[0]; + const unsigned instN = instrShape[1]; + const unsigned instK = instrShape[2]; + auto layoutA = getWmmaLayoutA(op); + auto layoutB = getWmmaLayoutB(op); + + const unsigned warpSize = gpu::lookupThreadsPerWarp(rewriter); + auto contiguousRank2FastPath = buildContiguousRank2WmmaFastPathState( + loc, adaptorA, adaptorB, aTy, dTy, *aContract, *bContract, layoutA, + layoutB, instM, instN, instK, warpSize, rewriter); + + auto cLinearLayout = mmaEnc.toLinearLayout(dTy.getShape()); + + auto fc = ::mlir::unpackLLElements(loc, adaptorC, rewriter); + + const unsigned aElemsPerThread = (instM * instK) / warpSize; + const unsigned bElemsPerThread = (instN * instK) / warpSize; + const unsigned cElemsPerThread = (instM * instN) / warpSize; + + if (aElemsPerThread == 0 || bElemsPerThread == 0 || cElemsPerThread == 0) + return op.emitError("MUSA WMMA: invalid per-thread element size"); + + unsigned kDim = static_cast(aTy.getShape().back()); + unsigned numRepK = std::max(1u, ceil(kDim, instK)); + + unsigned accPackedLen = (cElemsPerThread * dElemTy.getIntOrFloatBitWidth()); + if (accPackedLen % 32 != 0) + return op.emitError("MUSA WMMA: accumulator packing misaligned"); + accPackedLen /= 32; + if (accPackedLen == 0) + return op.emitError("MUSA WMMA: invalid accumulator packing size"); + + Type accPackedElemTy = rewriter.getI32Type(); + Type accVecTy = accPackedLen == 1 ? accPackedElemTy + : vec_ty(accPackedElemTy, accPackedLen); + + WmmaTilePlan tilePlan; + tilePlan.numRepK = numRepK; + tilePlan.numRepN = 0; + if (contiguousRank2FastPath) { + tilePlan = contiguousRank2FastPath->tilePlan; + for (WmmaTileState &tile : tilePlan.tiles) + tile.regBase *= cElemsPerThread; + if (fc.size() != tilePlan.tiles.size() * cElemsPerThread) + return op.emitError("MUSA WMMA: accumulator register size mismatch"); + } else { + auto genericTilePlan = + buildGenericWmmaTilePlan(ctx, cLinearLayout, dTy.getRank(), instM, + instN, cElemsPerThread, fc.size()); + if (failed(genericTilePlan)) + return op.emitError("MUSA WMMA: failed to derive repetition layout"); + tilePlan = *genericTilePlan; + tilePlan.numRepK = numRepK; + } + + unsigned kPaddingA = 0; + unsigned kPaddingB = 0; + if (instK > kDim) { + unsigned paddingFactor = instK / kDim; + kPaddingA = aElemsPerThread - (aElemsPerThread / paddingFactor); + kPaddingB = bElemsPerThread - (bElemsPerThread / paddingFactor); + } + + for (const WmmaTileState &tileState : tilePlan.tiles) { + SmallVector accElems(fc.begin() + tileState.regBase, + fc.begin() + tileState.regBase + + cElemsPerThread); + auto packedAccElems = packElemsToI32( + b, loc, accElems, accElems.front().getType(), typeConverter); + if (packedAccElems.size() != accPackedLen) + return op.emitError("MUSA WMMA: accumulator pack size mismatch"); + Value accVec = packToVector(b, loc, accPackedElemTy, packedAccElems); + if (!useCConst || !*useCConst) { + Value zeroAcc = LLVM::ZeroOp::create(rewriter, loc, accVecTy); + accVec = + selectAccumulatorVector(loc, useCFlag, accVec, zeroAcc, rewriter); + } + + for (unsigned kTile = 0; kTile < tilePlan.numRepK; ++kTile) { + unsigned kOffset = kTile * instK; + unsigned validK = (kOffset < kDim) ? std::min(instK, kDim - kOffset) : 0; + auto aOp = buildPackedWmmaOperandAForTile( + loc, contiguousRank2FastPath, adaptorA, *aContract, tileState.batch, + tileState.mBase, kTile, instM, instK, warpSize, aElemsPerThread, + kPaddingA, validK, aElemTy, accPackedElemTy, b, typeConverter, + rewriter); + if (!aOp) + return op.emitError("MUSA WMMA: failed to materialize A fragment"); + auto bOp = buildPackedWmmaOperandBForTile( + loc, contiguousRank2FastPath, adaptorB, *bContract, tileState.batch, + tileState.nBase, kTile, instN, instK, warpSize, bElemsPerThread, + kPaddingB, validK, bElemTy, accPackedElemTy, b, typeConverter, + rewriter); + if (!bOp) + return op.emitError("MUSA WMMA: failed to materialize B fragment"); + SmallVector args = triton::musa::buildWmmaIntrinsicArgs( + loc, *aOp, *bOp, accVec, layoutA, layoutB, *signature, rewriter); + auto call = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, signature->name, TypeRange{accVecTy}, args); + accVec = call.getResult(0); + } + + SmallVector accPackedElems; + if (auto vecTy = dyn_cast(accVec.getType())) { + for (unsigned i = 0; i < vecTy.getNumElements(); ++i) + accPackedElems.push_back( + b.extract_element(accPackedElemTy, accVec, b.i32_val(i))); + } else { + accPackedElems.push_back(accVec); + } + + auto accUnpacked = unpackI32ToElems(b, loc, accPackedElems, + typeConverter->convertType(dElemTy)); + if (accUnpacked.size() != cElemsPerThread) + return op.emitError("MUSA WMMA: accumulator unpack size mismatch"); + for (unsigned i = 0; i < accUnpacked.size(); ++i) + fc[tileState.regBase + i] = accUnpacked[i]; + } + + Value res = ::mlir::packLLElements(loc, typeConverter, fc, rewriter, dTy); + rewriter.replaceOp(op, res); + return success(); +} + +LogicalResult convertWMMADot(triton::musa::WmmaDotOp op, + triton::musa::WmmaDotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + return convertWMMADotImpl(op, adaptor, typeConverter, rewriter); +} + +} // namespace mlir::triton::MUSA diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 0000000000..a447c06aec --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,891 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +static Type getConvertedElementType(Type type, + const TypeConverter *typeConverter) { + Type convertedType = typeConverter->convertType(type); + if (auto vecType = dyn_cast(convertedType)) + return vecType.getElementType(); + if (auto structType = dyn_cast(convertedType)) { + if (!structType.getBody().empty()) + return structType.getBody().front(); + } + return convertedType; +} + +static Value maybeBitcastSameWidth(TritonLLVMOpBuilder &b, Value value, + Type targetTy) { + if (value.getType() == targetTy) + return value; + Type srcTy = value.getType(); + if (srcTy.isIntOrFloat() && targetTy.isIntOrFloat() && + srcTy.getIntOrFloatBitWidth() == targetTy.getIntOrFloatBitWidth()) + return b.bitcast(value, targetTy); + return value; +} + +static Value createI32SignedDivCall(Operation *op, Location loc, + ConversionPatternRewriter &rewriter, + Value lhs, Value rhs) { + Type funcType = LLVM::LLVMFunctionType::get(i32_ty, {i32_ty, i32_ty}); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, "__mt_sdiv_i32", funcType); + return LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange{lhs, rhs}) + .getResult(); +} + +struct DivSIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::DivSIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + Type logicalElemTy = + this->typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = + getConvertedElementType(op.getType(), this->typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + + auto intTy = dyn_cast(logicalElemTy); + auto elemWidth = intTy.getWidth(); + Value out; + if (!intTy || (elemWidth != 16 && elemWidth != 32)) { + out = LLVM::SDivOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } + + Type widenedTy; + if (elemWidth == 16) { + widenedTy = rewriter.getI32Type(); + Value lhsWide = LLVM::SExtOp::create(rewriter, loc, widenedTy, lhs); + Value rhsWide = LLVM::SExtOp::create(rewriter, loc, widenedTy, rhs); + Value outWide = + createI32SignedDivCall(op, loc, rewriter, lhsWide, rhsWide); + out = LLVM::TruncOp::create(rewriter, loc, logicalElemTy, outWide); + } else { + out = createI32SignedDivCall(op, loc, rewriter, lhs, rhs); + } + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +struct RemSIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::RemSIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + Type logicalElemTy = + this->typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = + getConvertedElementType(op.getType(), this->typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + + auto intTy = dyn_cast(logicalElemTy); + auto elemWidth = intTy.getWidth(); + Value out; + if (!intTy || (elemWidth != 16 && elemWidth != 32)) { + out = LLVM::SRemOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } + + if (elemWidth == 16) { + Type widenedTy = rewriter.getI32Type(); + Value lhsWide = LLVM::SExtOp::create(rewriter, loc, widenedTy, lhs); + Value rhsWide = LLVM::SExtOp::create(rewriter, loc, widenedTy, rhs); + Value remWide = + LLVM::SRemOp::create(rewriter, loc, widenedTy, lhsWide, rhsWide); + out = LLVM::TruncOp::create(rewriter, loc, logicalElemTy, remWide); + } else { + Value quot = createI32SignedDivCall(op, loc, rewriter, lhs, rhs); + Value prod = LLVM::MulOp::create(rewriter, loc, logicalElemTy, quot, rhs); + out = LLVM::SubOp::create(rewriter, loc, logicalElemTy, lhs, prod); + } + + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +struct FpToFpOpConversion + : public ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + explicit FpToFpOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit = patternBenefitDefault) + : Base(typeConverter, axisAnalysisPass, benefit) {} + + struct Fp8ConversionDesc { + StringRef funcName; + size_t numElements; + }; + + static bool isFp8Type(Type ty) { + return isa(ty); + } + + static StringRef getFp8ConversionIntrinsic(StringRef funcName) { + if (funcName == "__mt_tt_f16_to_e4m3") + return "llvm.musa.f162e4m3.rn"; + if (funcName == "__mt_tt_f16_to_e5m2") + return "llvm.musa.f162e5m2.rn"; + return {}; + } + + static bool isFp8Burst2Enabled() { + std::string envValue = + mlir::triton::tools::getStrEnv("TRITON_MUSA_ENABLE_FP8_BURST2"); + if (envValue.empty()) + return false; + std::transform(envValue.begin(), envValue.end(), envValue.begin(), + [](unsigned char c) { return std::tolower(c); }); + return envValue == "1" || envValue == "true" || envValue == "on"; + } + + std::pair + getFp8ConversionFunc(Type srcTy, Type dstTy, + std::optional roundingMode, + bool enableFp8Burst2) const { + auto F8E4M3TyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F16TyID = TypeID::get(); + auto BF16TyID = TypeID::get(); + auto F32TyID = TypeID::get(); + auto undefRounding = static_cast(-1); + + static DenseMap, + SmallVector> + conversionTable = { + // F8 -> F32 + {{F8E4M3TyID, F32TyID, undefRounding}, + {{"__mt_tt_v2e4m3_to_v2f32", 2}, {"__mt_tt_e4m3_to_f32", 1}}}, + {{F8E5M2TyID, F32TyID, undefRounding}, + {{"__mt_tt_v2e5m2_to_v2f32", 2}, {"__mt_tt_e5m2_to_f32", 1}}}, + // F8 -> F16 + {{F8E4M3TyID, F16TyID, undefRounding}, + {{"__mt_tt_v2e4m3_to_v2f16", 2}, {"__mt_tt_e4m3_to_f16", 1}}}, + {{F8E5M2TyID, F16TyID, undefRounding}, + {{"__mt_tt_v2e5m2_to_v2f16", 2}, {"__mt_tt_e5m2_to_f16", 1}}}, + // F8 -> BF16 + {{F8E4M3TyID, BF16TyID, undefRounding}, + {{"__mt_tt_v2e4m3_to_v2bf16", 2}, {"__mt_tt_e4m3_to_bf16", 1}}}, + {{F8E5M2TyID, BF16TyID, undefRounding}, + {{"__mt_tt_v2e5m2_to_v2bf16", 2}, {"__mt_tt_e5m2_to_bf16", 1}}}, + // F32 -> F8 + {{F32TyID, F8E4M3TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f32_to_v2e4m3", 2}, {"__mt_tt_f32_to_e4m3", 1}}}, + {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f32_to_v2e5m2", 2}, {"__mt_tt_f32_to_e5m2", 1}}}, + // F16 -> F8 + {{F16TyID, F8E4M3TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f16_to_v2e4m3", 2}, {"__mt_tt_f16_to_e4m3", 1}}}, + {{F16TyID, F8E5M2TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f16_to_v2e5m2", 2}, {"__mt_tt_f16_to_e5m2", 1}}}, + // BF16 -> F8 + {{BF16TyID, F8E4M3TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2bf16_to_v2e4m3", 2}, {"__mt_tt_bf16_to_e4m3", 1}}}, + {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2bf16_to_v2e5m2", 2}, {"__mt_tt_bf16_to_e5m2", 1}}}, + }; + + auto key = std::make_tuple(srcTy.getTypeID(), dstTy.getTypeID(), + roundingMode.value_or(undefRounding)); + auto it = conversionTable.find(key); + if (it == conversionTable.end()) { + llvm::report_fatal_error("Unsupported MUSA fp8 conversion"); + } + const auto &entries = it->second; + const auto &entry = enableFp8Burst2 ? entries.front() : entries.back(); + return {entry.funcName, entry.numElements}; + } + + static SmallVector + convertFp8(const LLVMTypeConverter *typeConverter, FpToFpOp op, Location loc, + ConversionPatternRewriter &rewriter, const SmallVector &v, + Type srcElementType, Type dstElementType, StringRef funcName) { + TritonLLVMOpBuilder b(loc, rewriter); + const size_t numElements = v.size(); + Type inpType; + Type outType; + Value inVals; + + if (numElements == 1) { + inpType = typeConverter->convertType(srcElementType); + outType = typeConverter->convertType(dstElementType); + inVals = v[0]; + } else { + inpType = vec_ty(typeConverter->convertType(srcElementType), numElements); + outType = vec_ty(typeConverter->convertType(dstElementType), numElements); + inVals = b.undef(inpType); + for (size_t i = 0; i < numElements; ++i) + inVals = b.insert_element(inpType, inVals, v[i], b.i32_val(i)); + } + + Value outVals; + if (numElements == 1) { + if (StringRef intrinsicName = getFp8ConversionIntrinsic(funcName); + !intrinsicName.empty()) { + auto intrinsic = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, intrinsicName, TypeRange{outType}, + ValueRange{inVals}); + outVals = intrinsic.getResult(0); + } + } + if (!outVals) { + Type funcType = LLVM::LLVMFunctionType::get(outType, inpType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + outVals = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange{inVals}) + .getResult(); + } + + SmallVector ret; + Type outElemType = typeConverter->convertType(dstElementType); + for (size_t i = 0; i < numElements; ++i) { + ret.push_back(numElements == 1 ? outVals + : b.extract_element(outElemType, outVals, + b.i32_val(i))); + } + return ret; + } + + static Value convertBf16ToFp32(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v) { + SmallVector ops = {v}; + SmallVector resultTypes{f32_ty}; + auto intrinsic = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.musa.bfloat162float", resultTypes, ops); + return intrinsic.getResult(0); + } + + static Value convertFp32ToBf16(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v, RoundingMode rounding) { + if (rounding == RoundingMode::RTZ) { + TritonLLVMOpBuilder b(loc, rewriter); + auto asInt32 = b.bitcast(v, i32_ty); + auto shifted = b.lshr(i32_ty, asInt32, b.i32_val(16)); + auto truncated = b.trunc(i16_ty, shifted); + return b.bitcast(truncated, bf16_ty); + } + SmallVector ops = {v}; + SmallVector resultTypes{bf16_ty}; + auto intrinsic = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.musa.float2bfloat16", resultTypes, ops); + return intrinsic.getResult(0); + } + + static Value convertFp32ToFp16(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v, RoundingMode rounding) { + switch (rounding) { + case RoundingMode::RTNE: + return LLVM::FPTruncOp::create(rewriter, loc, f16_ty, v); + case RoundingMode::RTZ: + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.f2h.rz", + f16_ty, {v}) + .getResult(0); + default: + emitError(loc) << "unsupported rounding mode for f32->f16 conversion: " + << stringifyRoundingMode(rounding); + llvm::report_fatal_error( + "unsupported rounding mode for f32->f16 conversion"); + } + } + + static Value convertSrcFpToFp32(Location loc, + ConversionPatternRewriter &rewriter, Value v, + Type srcTy) { + if (srcTy.isF32()) + return v; + if (srcTy.isF16()) + return LLVM::FPExtOp::create(rewriter, loc, f32_ty, v); + if (srcTy.isBF16()) + return convertBf16ToFp32(loc, rewriter, v); + llvm::report_fatal_error("Unsupported MUSA fp8 RTZ source type"); + } + + static Value clearFp32SignBit(Location loc, + ConversionPatternRewriter &rewriter, Value v) { + TritonLLVMOpBuilder b(loc, rewriter); + Value bits = b.bitcast(v, i32_ty); + Value absBits = b.and_(bits, b.i32_val(0x7fffffff)); + return b.bitcast(absBits, f32_ty); + } + + static Value convertFp8Scalar(const LLVMTypeConverter *typeConverter, + FpToFpOp op, Location loc, + ConversionPatternRewriter &rewriter, Value v, + Type srcElementType, Type dstElementType, + StringRef funcName) { + auto outVals = convertFp8(typeConverter, op, loc, rewriter, {v}, + srcElementType, dstElementType, funcName); + assert(outVals.size() == 1 && "expected scalar fp8 conversion"); + return outVals.front(); + } + + Value convertFp8DowncastRTZ(Location loc, ConversionPatternRewriter &rewriter, + FpToFpOp op, Value src, Type srcElemTy, + Type dstElemTy) const { + TritonLLVMOpBuilder b(loc, rewriter); + Value srcFp32 = convertSrcFpToFp32(loc, rewriter, src, srcElemTy); + + auto [downcastFuncName, _] = + getFp8ConversionFunc(f32_ty, dstElemTy, RoundingMode::RTNE, false); + Value rtneFp8 = + convertFp8Scalar(getTypeConverter(), op, loc, rewriter, srcFp32, f32_ty, + dstElemTy, downcastFuncName); + + auto [upcastFuncName, __] = + getFp8ConversionFunc(dstElemTy, f32_ty, std::nullopt, false); + Value rtneRoundTrip = + convertFp8Scalar(getTypeConverter(), op, loc, rewriter, rtneFp8, + dstElemTy, f32_ty, upcastFuncName); + + Value absSrc = clearFp32SignBit(loc, rewriter, srcFp32); + Value absRtneRoundTrip = clearFp32SignBit(loc, rewriter, rtneRoundTrip); + Value roundedAwayFromZero = b.fcmp_ogt(absRtneRoundTrip, absSrc); + Value rtzFp8 = b.sub(rtneFp8, b.i8_val(1)); + return b.select(roundedAwayFromZero, rtzFp8, rtneFp8); + } + + SmallVector createDestOps(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto srcElemTy = getElementTypeOrSelf(op.getSrc().getType()); + auto dstElemTy = getElementTypeOrSelf(op.getType()); + Type packedDstTy = getConvertedElementType(op.getType(), typeConverter); + Value src = operands[0][0]; + if (isa(srcElemTy)) { + Type logicalSrcTy = typeConverter->convertType(srcElemTy); + src = maybeBitcastSameWidth(b, src, logicalSrcTy); + } + + auto roundingMode = op.getRounding(); + + bool isFp8Conversion = isFp8Type(srcElemTy) || isFp8Type(dstElemTy); + if (isFp8Conversion) { + if (isFp8Type(dstElemTy)) { + if (roundingMode.has_value() && + roundingMode.value() == RoundingMode::RTZ) { + if (!isa(srcElemTy)) { + llvm::report_fatal_error( + "Unsupported MUSA fp8 RTZ downcast source type"); + } + Type logicalSrcTy = typeConverter->convertType(srcElemTy); + SmallVector outVals; + outVals.reserve(operands.size()); + for (unsigned i = 0; i < operands.size(); ++i) { + Value inVal = operands[i][0]; + if (isa(srcElemTy)) + inVal = maybeBitcastSameWidth(b, inVal, logicalSrcTy); + outVals.push_back(convertFp8DowncastRTZ(loc, rewriter, op, inVal, + srcElemTy, dstElemTy)); + } + return outVals; + } + if (!roundingMode.has_value() || + roundingMode.value() != RoundingMode::RTNE) { + llvm::report_fatal_error( + "MUSA fp8 downcast requires RTNE rounding mode"); + } + } + auto [funcName, numElements] = getFp8ConversionFunc( + srcElemTy, dstElemTy, roundingMode, isFp8Burst2Enabled()); + Type logicalSrcTy = typeConverter->convertType(srcElemTy); + SmallVector inVals; + for (unsigned i = 0; i < std::min(numElements, operands.size()); ++i) { + Value inVal = operands[i][0]; + if (isa(srcElemTy)) + inVal = maybeBitcastSameWidth(b, inVal, logicalSrcTy); + inVals.push_back(inVal); + } + inVals.resize(numElements, + b.undef(typeConverter->convertType(srcElemTy))); + auto outVals = convertFp8(getTypeConverter(), op, loc, rewriter, inVals, + srcElemTy, dstElemTy, funcName); + outVals.resize(std::min(numElements, operands.size())); + return outVals; + } + + if (srcElemTy.isBF16() && dstElemTy.isF32()) { + return {maybeBitcastSameWidth(b, convertBf16ToFp32(loc, rewriter, src), + packedDstTy)}; + } + if (srcElemTy.isF32() && dstElemTy.isBF16()) { + auto rounding = op.getRounding().value_or(RoundingMode::RTNE); + return {maybeBitcastSameWidth( + b, convertFp32ToBf16(loc, rewriter, src, rounding), packedDstTy)}; + } + if (srcElemTy.isF16() && dstElemTy.isF32()) { + Value out = LLVM::FPExtOp::create(rewriter, loc, f32_ty, src); + return {maybeBitcastSameWidth(b, out, packedDstTy)}; + } + if (srcElemTy.isF32() && dstElemTy.isF16()) { + auto rounding = op.getRounding().value_or(RoundingMode::RTNE); + Value out = convertFp32ToFp16(loc, rewriter, src, rounding); + return {maybeBitcastSameWidth(b, out, packedDstTy)}; + } + if (srcElemTy.isF16() && dstElemTy.isBF16()) { + Value tmp = LLVM::FPExtOp::create(rewriter, loc, f32_ty, src); + auto rounding = op.getRounding().value_or(RoundingMode::RTNE); + Value out = convertFp32ToBf16(loc, rewriter, tmp, rounding); + return {maybeBitcastSameWidth(b, out, packedDstTy)}; + } + if (srcElemTy.isBF16() && dstElemTy.isF16()) { + Value tmp = convertBf16ToFp32(loc, rewriter, src); + Value out = LLVM::FPTruncOp::create(rewriter, loc, f16_ty, tmp); + return {maybeBitcastSameWidth(b, out, packedDstTy)}; + } + if (srcElemTy == dstElemTy) { + return {src}; + } + + // Fallback to LLVM FP trunc/ext when applicable. + if (isa(srcElemTy) && isa(dstElemTy)) { + if (srcElemTy.getIntOrFloatBitWidth() < dstElemTy.getIntOrFloatBitWidth()) + return {maybeBitcastSameWidth( + b, + LLVM::FPExtOp::create(rewriter, loc, + typeConverter->convertType(dstElemTy), src), + packedDstTy)}; + if (srcElemTy.getIntOrFloatBitWidth() > dstElemTy.getIntOrFloatBitWidth()) + return {maybeBitcastSameWidth( + b, + LLVM::FPTruncOp::create(rewriter, loc, + typeConverter->convertType(dstElemTy), src), + packedDstTy)}; + } + + return {}; + } +}; + +template +Value emitDualBF16ElementwiseOp(Location loc, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands) { + auto lhs = + FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]); + auto rhs = + FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][1]); + auto result = OpType::create(rewriter, loc, f32_ty, lhs, rhs); + return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result, + RoundingMode::RTNE); +} + +struct PreciseSqrtOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit PreciseSqrtOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : Base(typeConverter, axisAnalysisPass, benefit) {} + + SmallVector createDestOps(triton::PreciseSqrtOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + Type packedElemTy = getConvertedElementType(op.getType(), typeConverter); + Type logicalElemTy = + typeConverter->convertType(getElementType(op.getResult())); + Value input = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + + Type f64Ty = rewriter.getF64Type(); + Type f32Ty = rewriter.getF32Type(); + Value inputF64 = LLVM::FPExtOp::create(rewriter, loc, f64Ty, input); + + Type funcType = getFunctionType(f64Ty, ValueRange{inputF64}); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, "__mt_sqrt_f64", funcType); + Value sqrtResultF64 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange{inputF64}) + .getResult(); + Value resultF32 = + LLVM::FPTruncOp::create(rewriter, loc, f32Ty, sqrtResultF64); + return {maybeBitcastSameWidth(b, resultF32, packedElemTy)}; + } +}; + +struct PreciseDivOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit PreciseDivOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : Base(typeConverter, axisAnalysisPass, benefit) {} + + SmallVector createDestOps(triton::PreciseDivFOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + Type packedElemTy = getConvertedElementType(op.getType(), typeConverter); + Type logicalElemTy = + typeConverter->convertType(getElementType(op.getResult())); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + + // precise_divf is defined for f32; compute in f64 and cast back for + // improved precision. + Type f64Ty = rewriter.getF64Type(); + Type f32Ty = rewriter.getF32Type(); + Value lhsF64 = LLVM::FPExtOp::create(rewriter, loc, f64Ty, lhs); + Value rhsF64 = LLVM::FPExtOp::create(rewriter, loc, f64Ty, rhs); + Value divResultF64 = + LLVM::FDivOp::create(rewriter, loc, f64Ty, lhsF64, rhsF64); + Value resultF32 = + LLVM::FPTruncOp::create(rewriter, loc, f32Ty, divResultF64); + return {maybeBitcastSameWidth(b, resultF32, packedElemTy)}; + } +}; + +struct FDivOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::DivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {emitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } + Type logicalElemTy = + typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = getConvertedElementType(op.getType(), typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + Value out = LLVM::FDivOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +struct FMulOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::MulFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {emitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } + Type logicalElemTy = + typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = getConvertedElementType(op.getType(), typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + Value out = LLVM::FMulOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +struct FAddOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::AddFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {emitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } + Type logicalElemTy = + typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = getConvertedElementType(op.getType(), typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + Value out = LLVM::FAddOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +struct FSubOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SubFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {emitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } + Type logicalElemTy = + typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = getConvertedElementType(op.getType(), typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + Value out = LLVM::FSubOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +template +struct FPBinaryBitcastOpConversion + : ElementwiseOpConversionBase> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(SrcOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + Type logicalElemTy = + this->typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = + getConvertedElementType(op.getType(), this->typeConverter); + Value lhs = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value rhs = maybeBitcastSameWidth(b, operands[0][1], logicalElemTy); + Value out = DstOp::create(rewriter, loc, logicalElemTy, lhs, rhs); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +template +struct FPUnaryBitcastOpConversion + : ElementwiseOpConversionBase> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(SrcOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + Type logicalElemTy = + this->typeConverter->convertType(getElementType(op.getResult())); + Type packedElemTy = + getConvertedElementType(op.getType(), this->typeConverter); + Value src = maybeBitcastSameWidth(b, operands[0][0], logicalElemTy); + Value out = DstOp::create(rewriter, loc, logicalElemTy, src, + adaptor.getAttributes().getValue()); + return {maybeBitcastSameWidth(b, out, packedElemTy)}; + } +}; + +struct SIToFPOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (outElemTy.isBF16()) { + Value f32Val = + LLVM::SIToFPOp::create(rewriter, loc, f32_ty, operands[0][0]); + return {FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, f32Val, + RoundingMode::RTNE)}; + } + return {LLVM::SIToFPOp::create(rewriter, loc, elemTy, operands[0][0])}; + } +}; + +struct FPToSIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::FPToSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (inElemTy.isBF16()) { + Value f32Val = + FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]); + return {LLVM::FPToSIOp::create(rewriter, loc, elemTy, f32Val)}; + } + return {LLVM::FPToSIOp::create(rewriter, loc, elemTy, operands[0][0])}; + } +}; + +struct ExtFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::ExtFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto inElemTy = getElementType(op.getIn()); + auto outElemTy = getElementType(op.getOut()); + Type packedOutTy = getConvertedElementType(op.getType(), typeConverter); + Value src = operands[0][0]; + if (isa(inElemTy)) { + Type logicalInTy = typeConverter->convertType(inElemTy); + src = maybeBitcastSameWidth(b, src, logicalInTy); + } + + Value out; + if (inElemTy.isBF16()) { + out = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, src); + return {maybeBitcastSameWidth(b, out, packedOutTy)}; + } + Type logicalOutTy = typeConverter->convertType(outElemTy); + out = LLVM::FPExtOp::create(rewriter, loc, logicalOutTy, src); + return {maybeBitcastSameWidth(b, out, packedOutTy)}; + } +}; + +struct TruncFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::TruncFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + TritonLLVMOpBuilder b(loc, rewriter); + auto inElemTy = getElementType(op.getIn()); + auto outElemTy = getElementType(op.getOut()); + Type packedOutTy = getConvertedElementType(op.getType(), typeConverter); + Value src = operands[0][0]; + if (isa(inElemTy)) { + Type logicalInTy = typeConverter->convertType(inElemTy); + src = maybeBitcastSameWidth(b, src, logicalInTy); + } + + Value out; + if (outElemTy.isBF16()) { + out = FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, src, + RoundingMode::RTNE); + return {maybeBitcastSameWidth(b, out, packedOutTy)}; + } + Type logicalOutTy = typeConverter->convertType(outElemTy); + out = LLVM::FPTruncOp::create(rewriter, loc, logicalOutTy, src); + return {maybeBitcastSameWidth(b, out, packedOutTy)}; + } +}; + +} // namespace + +void mlir::triton::MUSA::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int /*computeCapability*/, + const TargetInfo &targetInfo, PatternBenefit benefit) { + PatternBenefit priorityBenefit(benefit.getBenefit() + 1); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add>( + typeConverter, axisInfoAnalysis, priorityBenefit); + patterns.add>( + typeConverter, axisInfoAnalysis, priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + patterns.add(typeConverter, axisInfoAnalysis, + priorityBenefit); + mlir::triton::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + bool hwNanPropagationSupported = targetInfo.supportMaximumMinimum(); + mlir::triton::populateMinMaxFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, hwNanPropagationSupported, + benefit); + mlir::triton::populateClampFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/LoadStoreOpToLLVM.cpp new file mode 100644 index 0000000000..dbcd484189 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -0,0 +1,1561 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include +#include +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +inline constexpr llvm::StringLiteral kInplaceLoadAttr = + "musa.inplace_load_candidate"; + +static triton::gpu::LocalAllocOp findRootLocalAlloc(Value memDesc) { + Value cur = memDesc; + while (cur) { + Operation *defOp = cur.getDefiningOp(); + if (!defOp) + break; + if (auto localAllocOp = dyn_cast(defOp)) + return localAllocOp; + if (auto indexOp = dyn_cast(defOp)) { + cur = indexOp.getSrc(); + continue; + } + if (auto subsliceOp = dyn_cast(defOp)) { + cur = subsliceOp.getSrc(); + continue; + } + if (auto reinterpretOp = + dyn_cast(defOp)) { + cur = reinterpretOp.getSrc(); + continue; + } + if (auto transOp = dyn_cast(defOp)) { + cur = transOp.getSrc(); + continue; + } + if (auto reshapeOp = dyn_cast(defOp)) { + cur = reshapeOp.getSrc(); + continue; + } + break; + } + return {}; +} + +static bool requiresAbsoluteSwizzledAsyncCopy(MemDescType memDescTy) { + if (memDescTy.getShape().size() != 2) + return false; + auto swizzle = triton::musa::resolveTMESwizzleConfigFromEncoding(memDescTy); + return succeeded(swizzle) && swizzle->swizzleGranularity != + triton::musa::TMESwizzleGranularity::SG_NONE; +} + +struct PH1TMESwizzleValueConfig { + Value granularityBytes; + Value strideBytes; + Value lineBytes; + bool hasSwizzle = false; +}; + +static FailureOr +resolvePH1TMESwizzleValueConfig(Location loc, MemDescType memDescTy, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + PH1TMESwizzleValueConfig config; + + auto emitConfig = [&](triton::musa::TMESwizzleGranularity granularity, + triton::musa::TMESwizzleStride stride, + triton::musa::TMESwizzleLine line) { + config.hasSwizzle = + granularity != triton::musa::TMESwizzleGranularity::SG_NONE; + config.granularityBytes = b.i32_val(static_cast( + triton::musa::getSwizzleGranularityBytes(granularity))); + config.strideBytes = b.i32_val( + static_cast(triton::musa::getSwizzleStrideBytes(stride))); + config.lineBytes = b.i32_val( + static_cast(triton::musa::getSwizzleLineBytes(line))); + }; + + auto swizzle = triton::musa::resolveTMESwizzleConfigFromEncoding(memDescTy); + if (failed(swizzle)) + return failure(); + + emitConfig(swizzle->swizzleGranularity, swizzle->swizzleStride, + swizzle->swizzleLine); + return config; +} + +static Value +applyPH1TMESwizzleToByteAddressValue(TritonLLVMOpBuilder &b, Value addrBytes, + const PH1TMESwizzleValueConfig &config) { + if (!config.hasSwizzle) + return addrBytes; + + Value lineOffset = b.urem(addrBytes, config.lineBytes); + Value lineId = b.udiv(addrBytes, config.lineBytes); + Value swizzleGroup = b.udiv(config.strideBytes, config.granularityBytes); + Value swizzleLineId = b.urem(lineId, swizzleGroup); + Value sectorInLine = b.udiv(lineOffset, config.granularityBytes); + Value offsetInSector = b.urem(lineOffset, config.granularityBytes); + Value targetSectorInLine = b.xor_(sectorInLine, swizzleLineId); + return b.add(b.add(b.mul(lineId, config.lineBytes), + b.mul(targetSectorInLine, config.granularityBytes)), + offsetInSector); +} + +static LogicalResult lowerLocalAllocSrcToShared( + Location loc, MLIRContext *ctx, Value regVal, MemDescType memDescTy, + SharedMemoryObject smemObj, ArrayRef inVals, + const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, + const mlir::triton::MUSA::TargetInfo &targetInfo) { + auto regTy = cast(regVal.getType()); + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kOffset = str_attr("offset"); + auto regLayout = toLinearLayout(regTy); + auto paddedEnc = + dyn_cast(memDescTy.getEncoding()); + LinearLayout cvt = LinearLayout::empty(); + if (paddedEnc) { + const auto &sharedLL = paddedEnc.getLinearComponent(); + cvt = regLayout.invertAndCompose(sharedLL); + } else { + auto sharedLayout = toLinearLayout(memDescTy); + cvt = regLayout.invertAndCompose(sharedLayout); + } + auto kBlock = str_attr("block"); + if (!cvt.isTrivialOver({kBlock})) + return failure(); + cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset}); + lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj, + rewriter, targetInfo); + return success(); +} + +static std::optional +getResidualDotOperandLocalLoadMaxVecElems(MemDescType memDescTy) { + unsigned elemBitWidth = memDescTy.getElementTypeBitWidth(); + if (elemBitWidth == 8) + return 1; + return std::nullopt; +} + +static SmallVector +getBlockedThreadIds(Value threadId, ArrayRef shapePerCTATile, + ArrayRef sizePerThread, ArrayRef order, + ConversionPatternRewriter &rewriter, Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector threadIds(order.size()); + for (unsigned i = 0; i < order.size() - 1; ++i) { + unsigned dim = order[i]; + Value dimSize = b.i32_val( + static_cast(shapePerCTATile[dim] / sizePerThread[dim])); + Value rem = b.urem(threadId, dimSize); + threadId = b.udiv(threadId, dimSize); + threadIds[dim] = rem; + } + unsigned dim = order.back(); + threadIds[dim] = + b.urem(threadId, b.i32_val(static_cast(shapePerCTATile[dim] / + sizePerThread[dim]))); + return threadIds; +} + +static SmallVector getShapePerCTATile(BlockedEncodingAttr layout) { + SmallVector shapePerCTATile; + for (auto [reg, thread, warp] : + llvm::zip(layout.getSizePerThread(), layout.getThreadsPerWarp(), + layout.getWarpsPerCTA())) { + shapePerCTATile.push_back(reg * thread * warp); + } + return shapePerCTATile; +} + +static unsigned getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(layout); + unsigned mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + unsigned nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + return isM ? mShapePerCTATile : nShapePerCTATile; +} + +static unsigned getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto sizePerThread = layout.getSizePerThread(); + unsigned mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + unsigned nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + return isM ? mSizePerThread : nSizePerThread; +} + +static FailureOr> +getStaticStrides(ArrayRef shape, ArrayRef order) { + if (shape.size() != order.size()) + return failure(); + + SmallVector strides(shape.size()); + uint64_t stride = 1; + for (unsigned dim : order) { + if (dim >= shape.size() || shape[dim] <= 0 || + stride > std::numeric_limits::max()) + return failure(); + strides[dim] = static_cast(stride); + stride *= static_cast(shape[dim]); + } + return strides; +} + +static unsigned product(ArrayRef values) { + return std::accumulate(values.begin(), values.end(), 1u, + std::multiplies()); +} + +static FailureOr lowerResidualFMAOperandLoad( + triton::gpu::LocalLoadOp op, Value llSrc, MemDescType memDescTy, + RankedTensorType regTy, DotOperandEncodingAttr dotEnc, + const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, + const mlir::triton::MUSA::TargetInfo &targetInfo) { + if (memDescTy.getShape().size() != 2 || regTy.getShape().size() != 2) + return failure(); + + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto dLayout = dyn_cast(dotEnc.getParent()); + if (!dLayout) + return failure(); + + Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = + LLVM::getSharedMemoryObjectFromStruct(loc, llSrc, llvmElemTy, rewriter); + auto sharedOrder = triton::gpu::getOrder(memDescTy); + if (sharedOrder.size() != 2) + return failure(); + auto strides = getStaticStrides( + memDescTy.getAllocShape().take_back(memDescTy.getRank()), sharedOrder); + if (failed(strides)) + return failure(); + + Value threadId = getThreadId(rewriter, loc); + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = llvm::to_vector(dLayout.getSizePerThread()); + auto order = llvm::to_vector(dLayout.getOrder()); + auto threadIds = getBlockedThreadIds(threadId, shapePerCTATile, sizePerThread, + order, rewriter, loc); + + auto shape = memDescTy.getShape(); + unsigned opIdx = dotEnc.getOpIdx(); + SmallVector vals; + Value smemBase = smemObj.getShmemAffineBase(loc, rewriter, memDescTy); + Type ptrTy = smemBase.getType(); + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); + + if (opIdx == 0) { + Value strideAM = b.i32_val(static_cast((*strides)[0])); + Value strideAK = b.i32_val(static_cast((*strides)[1])); + unsigned kExtent = static_cast(shape[1]); + unsigned mExtent = static_cast(shape[0]); + unsigned mShapePerCTATile = getShapePerCTATileForMN(dLayout, true); + unsigned mSizePerThread = getSizePerThreadForMN(dLayout, true); + unsigned mContig = mSizePerThread; + Value threadIdM = b.urem( + threadIds[0], + b.i32_val(static_cast(llvm::divideCeil(mExtent, mContig)))); + Value threadBaseM = b.mul(threadIdM, b.i32_val(mContig)); + + SmallVector perRepShape = {1, mSizePerThread, kExtent}; + SmallVector repetitions = { + 1, llvm::divideCeil(mExtent, mShapePerCTATile), 1}; + unsigned elemsPerRep = product(perRepShape); + unsigned totalElems = elemsPerRep * product(repetitions); + vals.reserve(totalElems); + for (unsigned idx = 0; idx < totalElems; ++idx) { + auto inRepIdx = + mlir::LLVM::delinearize(idx % elemsPerRep, perRepShape, inRepOrder); + auto repIdx = + mlir::LLVM::delinearize(idx / elemsPerRep, repetitions, repOrder); + Value mCoord = b.add( + b.add(threadBaseM, b.i32_val(static_cast(inRepIdx[1]))), + b.i32_val(static_cast(repIdx[1] * mShapePerCTATile))); + Value kCoord = b.i32_val(static_cast(inRepIdx[2])); + Value offset = b.add(b.mul(mCoord, strideAM), b.mul(kCoord, strideAK)); + Value ptr = b.gep(ptrTy, llvmElemTy, smemBase, offset); + vals.push_back(targetInfo.loadDShared(rewriter, loc, ptr, std::nullopt, + llvmElemTy, b.true_val(), + op.getOperation())); + } + } else if (opIdx == 1) { + Value strideBK = b.i32_val(static_cast((*strides)[0])); + Value strideBN = b.i32_val(static_cast((*strides)[1])); + unsigned kExtent = static_cast(shape[0]); + unsigned nExtent = static_cast(shape[1]); + unsigned nShapePerCTATile = getShapePerCTATileForMN(dLayout, false); + unsigned nSizePerThread = getSizePerThreadForMN(dLayout, false); + unsigned nContig = nSizePerThread; + Value threadIdN = b.urem( + threadIds[1], + b.i32_val(static_cast(llvm::divideCeil(nExtent, nContig)))); + Value threadBaseN = b.mul(threadIdN, b.i32_val(nContig)); + + SmallVector perRepShape = {1, kExtent, nSizePerThread}; + SmallVector repetitions = { + 1, 1, llvm::divideCeil(nExtent, nShapePerCTATile)}; + unsigned elemsPerRep = product(perRepShape); + unsigned totalElems = elemsPerRep * product(repetitions); + vals.reserve(totalElems); + for (unsigned idx = 0; idx < totalElems; ++idx) { + auto inRepIdx = + mlir::LLVM::delinearize(idx % elemsPerRep, perRepShape, inRepOrder); + auto repIdx = + mlir::LLVM::delinearize(idx / elemsPerRep, repetitions, repOrder); + Value kCoord = b.i32_val(static_cast(inRepIdx[1])); + Value nCoord = b.add( + b.add(threadBaseN, b.i32_val(static_cast(inRepIdx[2]))), + b.i32_val(static_cast(repIdx[2] * nShapePerCTATile))); + Value offset = b.add(b.mul(kCoord, strideBK), b.mul(nCoord, strideBN)); + Value ptr = b.gep(ptrTy, llvmElemTy, smemBase, offset); + vals.push_back(targetInfo.loadDShared(rewriter, loc, ptr, std::nullopt, + llvmElemTy, b.true_val(), + op.getOperation())); + } + } else { + return failure(); + } + + return packLLElements(loc, typeConverter, vals, rewriter, regTy); +} + +Value maybeAnd(ConversionPatternRewriter &rewriter, Location loc, Value a, + Value b) { + if (a && b) { + return TritonLLVMOpBuilder(loc, rewriter).and_(a, b); + } + return a ? a : b; +} + +template +void emitIfPredicated(RewriterBase &rewriter, Location loc, Value pred, + Fn &&emitFn) { + if (!pred) { + emitFn(); + return; + } + if (matchPattern(pred, m_One())) { + emitFn(); + return; + } + + Block *curBlock = rewriter.getInsertionBlock(); + Block *afterBlock = + rewriter.splitBlock(curBlock, rewriter.getInsertionPoint()); + Block *thenBlock = rewriter.createBlock(afterBlock); + + rewriter.setInsertionPointToEnd(curBlock); + LLVM::CondBrOp::create(rewriter, loc, pred, thenBlock, afterBlock); + + rewriter.setInsertionPointToEnd(thenBlock); + emitFn(); + LLVM::BrOp::create(rewriter, loc, afterBlock); + + rewriter.setInsertionPointToStart(afterBlock); +} + +Value emitRedundantThreadPredicate( + const llvm::MapVector &freeVarMasks, + ConversionPatternRewriter &rewriter, Location loc, + const mlir::triton::MUSA::TargetInfo &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = rewriter.getContext(); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value zero = b.i32_val(0); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = freeVarMasks.lookup(kBlock) == 0 + ? zero + : targetInfo.getClusterCTAId(rewriter, loc); + + Value pred; + int32_t laneMask = freeVarMasks.lookup(kLane); + if (laneMask != 0) { + Value dimPred = b.icmp_eq(b.and_(laneId, b.i32_val(laneMask)), zero); + pred = maybeAnd(rewriter, loc, pred, dimPred); + } + int32_t warpMask = freeVarMasks.lookup(kWarp); + if (warpMask != 0) { + Value dimPred = b.icmp_eq(b.and_(warpId, b.i32_val(warpMask)), zero); + pred = maybeAnd(rewriter, loc, pred, dimPred); + } + int32_t blockMask = freeVarMasks.lookup(kBlock); + if (blockMask != 0) { + Value dimPred = b.icmp_eq(b.and_(blockId, b.i32_val(blockMask)), zero); + pred = maybeAnd(rewriter, loc, pred, dimPred); + } + + return pred; +} + +unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) { + return index & ~freeVarMask; +} + +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisInfoAnalysis) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + unsigned contiguity = axisInfoAnalysis.getContiguity(ptr); + unsigned pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + if (pointeeBitWidth == 0) + return 1; + // Keep vectorized memory operations within 128-bit lanes. + unsigned maxVec = std::max(1u, 128u / pointeeBitWidth); + return std::min(contiguity, maxVec); +} + +StringRef getMusaMemcpyG2SIntrinsic(Type elemTy) { + switch (getIntOrFloatOrPtrBitWidth(elemTy)) { + case 8: + case 16: + case 32: + case 64: + return "llvm.musa.memcpy.g2s"; + default: + return {}; + } +} + +struct LoadOpConversion : public ConvertOpToLLVMPattern { + LoadOpConversion(LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo), axisInfoAnalysis(axisInfoAnalysis) {} + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + TritonLLVMOpBuilder b(loc, rewriter); + auto *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(op.getType())); + + Value ptr = op.getPtr(); + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + auto ptrElems = ::mlir::unpackLLElements(loc, llPtr, rewriter); + unsigned numElems = ptrElems.size(); + + unsigned vec = getVectorSize(ptr, axisInfoAnalysis); + vec = std::min(vec, numElems); + + bool forceScalarSharedByteLoad = false; + if (!ptrElems.empty()) { + if (auto llPtrTy = + dyn_cast(ptrElems.front().getType())) { + constexpr unsigned kSharedAddrSpace = 3; + forceScalarSharedByteLoad = + llPtrTy.getAddressSpace() == kSharedAddrSpace && + valueElemTy.getIntOrFloatBitWidth() == 8; + } + } + + SmallVector maskElems; + if (llMask) { + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(op.getMask())); + maskElems = ::mlir::unpackLLElements(loc, llMask, rewriter); + assert(maskElems.size() == numElems); + } + + SmallVector otherElems; + if (llOther) { + otherElems = ::mlir::unpackLLElements(loc, llOther, rewriter); + assert(otherElems.size() == numElems); + } + + while (vec > 1 && (numElems % vec != 0)) { + vec /= 2; + } + vec = std::max(1u, vec); + if (forceScalarSharedByteLoad) + vec = 1; + + auto freeVarMasks = getFreeVariableMasks(ptr.getType()); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + bool useInplaceLoad = op->hasAttr(kInplaceLoadAttr); + + SmallVector loaded; + loaded.reserve(numElems); + + if (vec == 1) { + for (unsigned i = 0; i < numElems; ++i) { + if (auto canonical = getCanonicalIndex(i, regMask); i != canonical) { + loaded.push_back(loaded[canonical]); + continue; + } + + Value pred = b.true_val(); + if (!maskElems.empty()) + pred = maybeAnd(rewriter, loc, pred, maskElems[i]); + Value falseVal = + otherElems.empty() + ? LLVM::ConstantOp::create(rewriter, loc, valueElemTy, + rewriter.getZeroAttr(valueElemTy)) + : otherElems[i]; + Value val = useInplaceLoad + ? LLVM::MUSA::llInplaceLoad(rewriter, loc, ptrElems[i], + valueElemTy, pred, falseVal) + : LLVM::MUSA::llLoad(rewriter, loc, ptrElems[i], + valueElemTy, pred, falseVal); + loaded.push_back(val); + } + } else { + auto vecTy = LLVM::getVectorType(valueElemTy, vec); + for (unsigned vecStart = 0; vecStart < numElems; vecStart += vec) { + unsigned canonicalVecStart = getCanonicalIndex(vecStart, regMask); + if (vecStart != canonicalVecStart) { + for (unsigned elemIdx = 0; elemIdx < vec; ++elemIdx) + loaded.push_back(loaded[canonicalVecStart + elemIdx]); + continue; + } + + Value pred = b.true_val(); + if (!maskElems.empty()) + pred = maybeAnd(rewriter, loc, pred, maskElems[vecStart]); + + Value falseVal; + if (otherElems.empty()) { + auto zeroAttr = rewriter.getZeroAttr(valueElemTy); + auto dense = + DenseElementsAttr::get(cast(vecTy), zeroAttr); + falseVal = LLVM::ConstantOp::create(rewriter, loc, vecTy, dense); + } else { + falseVal = LLVM::UndefOp::create(rewriter, loc, vecTy); + for (unsigned elemIdx = 0; elemIdx < vec; ++elemIdx) { + falseVal = b.insert_element(vecTy, falseVal, + otherElems[vecStart + elemIdx], + b.i32_val(elemIdx)); + } + } + + Value vecVal = + useInplaceLoad + ? LLVM::MUSA::llInplaceLoad(rewriter, loc, ptrElems[vecStart], + vecTy, pred, falseVal) + : LLVM::MUSA::llLoad(rewriter, loc, ptrElems[vecStart], vecTy, + pred, falseVal); + for (unsigned elemIdx = 0; elemIdx < vec; ++elemIdx) { + loaded.push_back( + b.extract_element(valueElemTy, vecVal, b.i32_val(elemIdx))); + } + } + } + + Value packed = ::mlir::packLLElements(loc, typeConverter, loaded, rewriter, + op.getType()); + rewriter.replaceOp(op, packed); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisInfoAnalysis; +}; + +struct StoreOpConversion : public ConvertOpToLLVMPattern { + StoreOpConversion(LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo), axisInfoAnalysis(axisInfoAnalysis) {} + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + TritonLLVMOpBuilder b(loc, rewriter); + Value ptr = op.getPtr(); + + Value llPtr = adaptor.getPtr(); + Value llVal = adaptor.getValue(); + Value llMask = adaptor.getMask(); + + auto ptrElems = ::mlir::unpackLLElements(loc, llPtr, rewriter); + auto valElems = ::mlir::unpackLLElements(loc, llVal, rewriter); + unsigned numElems = ptrElems.size(); + assert(numElems == valElems.size()); + + unsigned vec = getVectorSize(ptr, axisInfoAnalysis); + vec = std::min(vec, numElems); + + SmallVector maskElems; + if (llMask) { + maskElems = ::mlir::unpackLLElements(loc, llMask, rewriter); + assert(maskElems.size() == numElems); + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(op.getMask())); + } + + while (vec > 1 && (numElems % vec != 0)) { + vec /= 2; + } + vec = std::max(1u, vec); + + auto *ctx = rewriter.getContext(); + auto freeVarMasks = getFreeVariableMasks(ptr.getType()); + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + + if (vec == 1) { + for (unsigned i = 0; i < numElems; ++i) { + if (!isCanonicalIndex(i, regMask)) + continue; + Value pred = threadPred ? threadPred : b.true_val(); + if (!maskElems.empty()) { + pred = maybeAnd(rewriter, loc, pred, maskElems[i]); + } + LLVM::MUSA::llStore(rewriter, loc, ptrElems[i], valElems[i], pred); + } + rewriter.eraseOp(op); + return success(); + } + + Type valueElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(op.getValue())); + auto vecTy = LLVM::getVectorType(valueElemTy, vec); + for (unsigned vecStart = 0; vecStart < numElems; vecStart += vec) { + if (!isCanonicalIndex(vecStart, regMask)) + continue; + + Value pred = threadPred ? threadPred : b.true_val(); + if (!maskElems.empty()) { + pred = maybeAnd(rewriter, loc, pred, maskElems[vecStart]); + } + + Value storeVal = LLVM::UndefOp::create(rewriter, loc, vecTy); + for (unsigned elemIdx = 0; elemIdx < vec; ++elemIdx) { + storeVal = b.insert_element( + vecTy, storeVal, valElems[vecStart + elemIdx], b.i32_val(elemIdx)); + } + LLVM::MUSA::llStore(rewriter, loc, ptrElems[vecStart], storeVal, pred); + } + + rewriter.eraseOp(op); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisInfoAnalysis; +}; + +struct AtomicCASOpConversion + : public ConvertOpToLLVMPattern { + AtomicCASOpConversion(LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo), axisInfoAnalysis(axisInfoAnalysis) {} + + Value convertToAtomicType(Location loc, Value val, Type elemTy, + ConversionPatternRewriter &rewriter) const { + if (elemTy.isIntOrIndex()) { + return val; + } + auto intTy = rewriter.getIntegerType(elemTy.getIntOrFloatBitWidth()); + return LLVM::BitcastOp::create(rewriter, loc, intTy, val); + } + + Value convertFromAtomicType(Location loc, Value val, Type elemTy, + ConversionPatternRewriter &rewriter) const { + if (elemTy.isIntOrIndex()) { + return val; + } + return LLVM::BitcastOp::create(rewriter, loc, elemTy, val); + } + + LogicalResult + matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = rewriter.getContext(); + + auto ptrElements = + ::mlir::unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto cmpElements = + ::mlir::unpackLLElements(loc, adaptor.getCmp(), rewriter); + auto valElements = + ::mlir::unpackLLElements(loc, adaptor.getVal(), rewriter); + + auto valueTy = op.getResult().getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = + tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) + : valueTy; + auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); + + auto atomicOrdering = + getMemoryOrdering(op.getSem()).value_or(LLVM::AtomicOrdering::acq_rel); + auto successOrdering = atomicOrdering; + auto failureOrdering = LLVM::AtomicOrdering::monotonic; + + auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType()); + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); + Value storePred = threadPred ? threadPred : b.true_val(); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + + auto emitPredicated = [&](Value pred, Type retTy, + auto emitAtomic) -> Value { + if (!pred) { + return emitAtomic(); + } + auto *curBlock = rewriter.getInsertionBlock(); + auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + auto *atomicBlock = rewriter.createBlock( + curBlock->getParent(), std::next(Region::iterator(curBlock))); + endBlock->addArgument(retTy, loc); + + rewriter.setInsertionPointToEnd(curBlock); + Value undefVal = LLVM::UndefOp::create(rewriter, loc, retTy); + LLVM::CondBrOp::create(rewriter, loc, pred, atomicBlock, endBlock, + undefVal); + + rewriter.setInsertionPointToEnd(atomicBlock); + Value atom = emitAtomic(); + LLVM::BrOp::create(rewriter, loc, atom, endBlock); + + rewriter.setInsertionPointToStart(endBlock); + return endBlock->getArgument(0); + }; + + if (!tensorTy) { + Value casPtr = ptrElements.front(); + Value casCmp = cmpElements.front(); + Value casVal = valElements.front(); + auto atomicCmp = convertToAtomicType(loc, casCmp, valueElemTy, rewriter); + auto atomicVal = convertToAtomicType(loc, casVal, valueElemTy, rewriter); + Type atomicTy = atomicCmp.getType(); + + Value retVal = emitPredicated(threadPred, valueElemTy, [&]() { + auto cmpxchg = LLVM::AtomicCmpXchgOp::create( + rewriter, loc, casPtr, atomicCmp, atomicVal, successOrdering, + failureOrdering); + Value old = b.extract_val(atomicTy, cmpxchg, 0); + return convertFromAtomicType(loc, old, valueElemTy, rewriter); + }); + + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + + if (!op->hasAttr("allocation.offset")) { + return rewriter.notifyMatchFailure( + op, "missing allocation.offset for scalar atomic result"); + } + + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); + targetInfo.storeDShared(rewriter, loc, atomPtr, std::nullopt, retVal, + storePred); + b.barrier(triton::gpu::AddrSpace::Local); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + return success(); + } + + SmallVector resultVals(elemsPerThread); + for (size_t i = 0; i < elemsPerThread; ++i) { + if (auto canonical = getCanonicalIndex(i, regMask); canonical != i) { + resultVals[i] = resultVals[canonical]; + continue; + } + Value casPtr = ptrElements[i]; + Value casCmp = cmpElements[i]; + Value casVal = valElements[i]; + auto atomicCmp = convertToAtomicType(loc, casCmp, valueElemTy, rewriter); + auto atomicVal = convertToAtomicType(loc, casVal, valueElemTy, rewriter); + Type atomicTy = atomicCmp.getType(); + + Value oldVal = emitPredicated(threadPred, valueElemTy, [&]() { + auto cmpxchg = LLVM::AtomicCmpXchgOp::create( + rewriter, loc, casPtr, atomicCmp, atomicVal, successOrdering, + failureOrdering); + Value old = b.extract_val(atomicTy, cmpxchg, 0); + return convertFromAtomicType(loc, old, valueElemTy, rewriter); + }); + resultVals[i] = oldVal; + } + + finalizeTensorAtomicResults(op, tensorTy, rewriter, resultVals, valueElemTy, + b, storePred, targetInfo, getTypeConverter()); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisInfoAnalysis; +}; + +struct AtomicRMWOpConversion + : public ConvertOpToLLVMPattern { + AtomicRMWOpConversion(LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo), axisInfoAnalysis(axisInfoAnalysis) {} + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = rewriter.getContext(); + + auto atomicRmwAttr = op.getAtomicRmwOp(); + auto maybeKind = matchAtomicOp(atomicRmwAttr); + if (!maybeKind) { + return rewriter.notifyMatchFailure(op, "unsupported atomic op"); + } + auto atomicOrdering = + getMemoryOrdering(op.getSem()).value_or(LLVM::AtomicOrdering::acq_rel); + + auto ptrElements = + ::mlir::unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto valElements = + ::mlir::unpackLLElements(loc, adaptor.getVal(), rewriter); + SmallVector maskElements; + if (adaptor.getMask()) { + maskElements = ::mlir::unpackLLElements(loc, adaptor.getMask(), rewriter); + } + + auto valueTy = op.getResult().getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = + tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) + : valueTy; + auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); + + auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType()); + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); + Value storePred = threadPred ? threadPred : b.true_val(); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + + auto emitPredicated = [&](Value pred, Type retTy, + auto emitAtomic) -> Value { + if (!pred) { + return emitAtomic(); + } + auto *curBlock = rewriter.getInsertionBlock(); + auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + auto *atomicBlock = rewriter.createBlock( + curBlock->getParent(), std::next(Region::iterator(curBlock))); + endBlock->addArgument(retTy, loc); + + rewriter.setInsertionPointToEnd(curBlock); + Value undefVal = LLVM::UndefOp::create(rewriter, loc, retTy); + LLVM::CondBrOp::create(rewriter, loc, pred, atomicBlock, endBlock, + undefVal); + + rewriter.setInsertionPointToEnd(atomicBlock); + Value atom = emitAtomic(); + LLVM::BrOp::create(rewriter, loc, atom, endBlock); + + rewriter.setInsertionPointToStart(endBlock); + return endBlock->getArgument(0); + }; + + if (!tensorTy) { + Value rmwPtr = ptrElements.front(); + Value rmwVal = valElements.front(); + Value rmwMask = + maskElements.empty() ? b.true_val() : maskElements.front(); + rmwMask = maybeAnd(rewriter, loc, rmwMask, threadPred); + + auto emitAtomic = [&]() -> Value { + if (*maybeKind == LLVM::AtomicBinOp::fadd && + (valueElemTy.isF16() || valueElemTy.isF32() || + valueElemTy.isF64())) { + StringRef funcName; + Type fpType = valueElemTy; + if (valueElemTy.isF16()) { + funcName = "__mt_atomicAdd_f16"; + } else if (valueElemTy.isF32()) { + funcName = "__mt_atomicAdd_f32"; + } else { + funcName = "__mt_atomicAdd_f64"; + } + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto funcType = LLVM::LLVMFunctionType::get(fpType, {ptrTy, fpType}); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + Value addressCast = + LLVM::AddrSpaceCastOp::create(rewriter, loc, ptrTy, rmwPtr); + return LLVM::CallOp::create(rewriter, loc, funcOp, + ValueRange{addressCast, rmwVal}) + .getResult(); + } + return LLVM::AtomicRMWOp::create(rewriter, loc, *maybeKind, rmwPtr, + rmwVal, atomicOrdering) + .getResult(); + }; + + Value retVal = emitPredicated(rmwMask, valueElemTy, emitAtomic); + + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + + if (!op->hasAttr("allocation.offset")) { + return rewriter.notifyMatchFailure( + op, "missing allocation.offset for scalar atomic result"); + } + + auto *ctx = rewriter.getContext(); + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); + targetInfo.storeDShared(rewriter, loc, atomPtr, std::nullopt, retVal, + rmwMask); + b.barrier(triton::gpu::AddrSpace::Local); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + return success(); + } + + SmallVector resultVals(elemsPerThread); + for (size_t i = 0; i < elemsPerThread; ++i) { + if (auto canonical = getCanonicalIndex(i, regMask); canonical != i) { + resultVals[i] = resultVals[canonical]; + continue; + } + Value rmwPtr = ptrElements[i]; + Value rmwVal = valElements[i]; + Value rmwMask = maskElements.empty() ? b.true_val() : maskElements[i]; + rmwMask = maybeAnd(rewriter, loc, rmwMask, threadPred); + + auto emitAtomic = [&]() -> Value { + if (*maybeKind == LLVM::AtomicBinOp::fadd && + (valueElemTy.isF16() || valueElemTy.isF32() || + valueElemTy.isF64())) { + StringRef funcName; + Type fpType = valueElemTy; + if (valueElemTy.isF16()) { + funcName = "__mt_atomicAdd_f16"; + } else if (valueElemTy.isF32()) { + funcName = "__mt_atomicAdd_f32"; + } else { + funcName = "__mt_atomicAdd_f64"; + } + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto funcType = LLVM::LLVMFunctionType::get(fpType, {ptrTy, fpType}); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + Value addressCast = + LLVM::AddrSpaceCastOp::create(rewriter, loc, ptrTy, rmwPtr); + return LLVM::CallOp::create(rewriter, loc, funcOp, + ValueRange{addressCast, rmwVal}) + .getResult(); + } + return LLVM::AtomicRMWOp::create(rewriter, loc, *maybeKind, rmwPtr, + rmwVal, atomicOrdering) + .getResult(); + }; + + Value atom = emitPredicated(rmwMask, valueElemTy, emitAtomic); + resultVals[i] = atom; + } + + finalizeTensorAtomicResults(op, tensorTy, rewriter, resultVals, valueElemTy, + b, storePred, targetInfo, getTypeConverter()); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisInfoAnalysis; +}; + +struct SqmmaLocalAllocOpConversion + : public ConvertOpToLLVMPattern { + SqmmaLocalAllocOpConversion(LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc() || !op.getSrc()) + return failure(); + + auto srcTensorTy = dyn_cast(op.getSrc().getType()); + bool isSqmmaAccumulatorSpill = + srcTensorTy && + isa(srcTensorTy.getEncoding()); + + bool isSqmma = triton::musa::hasSqmmaOpIdxAttr(op.getOperation()); + if (!isSqmma && !isSqmmaAccumulatorSpill) + return failure(); + Location loc = op.getLoc(); + auto memDescTy = cast(op.getType()); + auto llvmElemTy = + getTypeConverter()->convertType(memDescTy.getElementType()); + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, memDescTy.getRank(), + loc, rewriter); + + if (isSqmmaAccumulatorSpill) { + auto *ctx = op.getContext(); + auto inVals = ::mlir::unpackLLElements(loc, adaptor.getSrc(), rewriter); + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + if (failed(lowerLocalAllocSrcToShared(loc, ctx, op.getSrc(), memDescTy, + smemObj, inVals, getTypeConverter(), + rewriter, targetInfo))) { + return failure(); + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = op.getContext(); + auto regTy = cast(op.getSrc().getType()); + + auto sharedEnc = dyn_cast( + memDescTy.getEncoding()); + if (!sharedEnc) + return failure(); + auto order = triton::gpu::getOrder(memDescTy); + if (order.size() != 2) + return failure(); + + auto opIdx = triton::musa::getSqmmaOpIdx(op.getOperation()); + if (!opIdx) + return failure(); + unsigned sqmmaOpIdx = static_cast(*opIdx); + + if (sqmmaOpIdx == 0) + targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local); + + unsigned elemBytes = + std::max(1, memDescTy.getElementTypeBitWidth() / 8); + auto swizzleConfig = + resolvePH1TMESwizzleValueConfig(loc, memDescTy, rewriter); + if (failed(swizzleConfig)) + return failure(); + + auto inVals = ::mlir::unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcIndices = + emitIndices(loc, rewriter, targetInfo, regTy.getEncoding(), regTy, + /*withCTAOffset=*/false); + if (srcIndices.size() != inVals.size()) + return failure(); + auto freeVarMasks = getFreeVariableMasks(regTy); + uint32_t regMask = freeVarMasks.lookup(str_attr("reg")); + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); + + auto shape = regTy.getShape(); + Value elemBytesVal = b.i32_val(static_cast(elemBytes)); + Value smemOffsetBytes = b.i32_val(0); + if (auto offAttr = op->getAttrOfType("allocation.offset")) { + smemOffsetBytes = b.i32_val(static_cast(offAttr.getInt())); + } + + Value smemElemBase = b.bitcast(smemBase, ptr_ty(ctx, 3)); + for (auto [idx, coord] : llvm::enumerate(srcIndices)) { + if (!isCanonicalIndex(static_cast(idx), regMask)) + continue; + auto lmsOffsetInElem = triton::musa::linearizePH1TMELinearCoords( + b, coord, shape, order, elemBytes); + if (failed(lmsOffsetInElem)) + return failure(); + Value lmsAddrInByte = + b.add(b.mul(*lmsOffsetInElem, elemBytesVal), smemOffsetBytes); + Value swizzledAddrInByte = applyPH1TMESwizzleToByteAddressValue( + b, lmsAddrInByte, *swizzleConfig); + Value swizzledAddrRelInByte = b.sub(swizzledAddrInByte, smemOffsetBytes); + Value swizzledElemOffset = b.udiv(swizzledAddrRelInByte, elemBytesVal); + + Value ptr = b.gep(smemElemBase.getType(), llvmElemTy, smemElemBase, + swizzledElemOffset); + targetInfo.storeDShared(rewriter, loc, ptr, std::nullopt, inVals[idx], + threadPred ? threadPred : b.true_val()); + } + + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; +}; + +struct DotOperandLocalLoadOpConversion + : public ConvertOpToLLVMPattern { + DotOperandLocalLoadOpConversion( + LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Keep SQMMA local_load paths on the generic lowering. + if (triton::musa::hasSqmmaOpIdxAttr(op.getOperation())) + return failure(); + + Location loc = op.getLoc(); + auto *ctx = op.getContext(); + auto memDescTy = dyn_cast(op.getSrc().getType()); + if (!memDescTy) + return failure(); + + auto regTy = dyn_cast(op.getResult().getType()); + if (!regTy) + return failure(); + auto dotEnc = dyn_cast(regTy.getEncoding()); + if (!dotEnc || !isa(dotEnc.getParent())) + return failure(); + + auto typeConverter = getTypeConverter(); + auto fmaLoad = lowerResidualFMAOperandLoad(op, adaptor.getSrc(), memDescTy, + regTy, dotEnc, typeConverter, + rewriter, targetInfo); + if (succeeded(fmaLoad)) { + rewriter.replaceOp(op, *fmaLoad); + return success(); + } + + std::optional maxVecElems = + getResidualDotOperandLocalLoadMaxVecElems(memDescTy); + if (!maxVecElems) + return failure(); + + auto sharedEnc = + dyn_cast(memDescTy.getEncoding()); + if (!sharedEnc) + return failure(); + + Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kOffset = str_attr("offset"); + auto kBlock = str_attr("block"); + auto regLayout = toLinearLayout(regTy); + + LinearLayout cvt = LinearLayout::empty(); + if (auto paddedEnc = + dyn_cast(sharedEnc)) { + const auto &sharedLL = paddedEnc.getLinearComponent(); + cvt = regLayout.invertAndCompose(sharedLL); + } else { + auto sharedLayout = toLinearLayout(memDescTy); + cvt = regLayout.invertAndCompose(sharedLayout); + } + if (!cvt.isTrivialOver({kBlock})) + return failure(); + cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset}); + + Value affineOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy); + uint64_t maskSpanAffineOffset = smemObj.getMaskSpanOffsets(memDescTy); + SmallVector> paddingShifts; + if (auto paddedEnc = dyn_cast( + memDescTy.getEncoding())) { + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + paddingShifts = getPaddedSharedShifts(paddedEnc, bitwidth, + /*offsetInBytes=*/true); + } + + SmallVector outVals = lowerLdStShared( + loc, ctx, cvt, /*valsArray=*/{}, llvmElemTy, smemObj.getBase(), + paddingShifts, affineOffset, maskSpanAffineOffset, rewriter, targetInfo, + maxVecElems, op.getOperation()); + Value result = + ::mlir::packLLElements(loc, typeConverter, outVals, rewriter, regTy); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; +}; + +struct AsyncCopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern { + AsyncCopyGlobalToLocalOpConversion( + LLVMTypeConverter &converter, + const mlir::triton::MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo), + axisInfoAnalysis(axisInfoAnalysis) {} + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto dstTy = cast(op.getResult().getType()); + auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + StringRef memcpyIntrinsic = getMusaMemcpyG2SIntrinsic(llvmElemTy); + if (memcpyIntrinsic.empty()) + return op.emitError("async_copy unsupported element bitwidth for " + "llvm.musa.memcpy.g2s.* intrinsic"); + + Value llDst = adaptor.getResult(); + Value llSrc = adaptor.getSrc(); + Value llMask = adaptor.getMask(); + + auto srcElems = ::mlir::unpackLLElements(loc, llSrc, rewriter); + SmallVector maskElems; + if (llMask) + maskElems = ::mlir::unpackLLElements(loc, llMask, rewriter); + + auto ptrTy = srcElems.front().getType(); + auto structTy = + LLVM::LLVMStructType::getLiteral(ctx, ArrayRef{ptrTy, i1_ty}); + + SmallVector vals; + vals.reserve(srcElems.size()); + for (size_t i = 0; i < srcElems.size(); ++i) { + Value packed = LLVM::UndefOp::create(rewriter, loc, structTy); + packed = b.insert_val(packed, srcElems[i], 0); + Value maskElem = llMask ? maskElems[i] : b.true_val(); + packed = b.insert_val(packed, maskElem, 1); + vals.push_back(packed); + } + + auto srcLayout = toLinearLayout(srcTy); + auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout); + srcLayout = removeBroadcastSrc.apply(srcLayout); + vals = removeBroadcastSrc.apply(vals); + + unsigned maxVec = getVectorSize(op.getSrc(), axisInfoAnalysis); + if (op.getMask()) + maxVec = + std::min(maxVec, axisInfoAnalysis.getMaskAlignment(op.getMask())); + maxVec = std::max(maxVec, op.getContiguity()); + maxVec = std::max(1u, maxVec); + int vecBytes = maxVec * llvmElemTy.getIntOrFloatBitWidth() / 8; + if (vecBytes < 4) { + return op.emitError( + "async_copy does not support transfers smaller than 4 bytes") + << "; calculated " << vecBytes << " bytes"; + } + + Value threadPred = b.true_val(); + + bool usePred = false; + if (op.getMask()) { + Operation *maskOp = op.getMask().getDefiningOp(); + usePred = !isa_and_nonnull(maskOp); + } + + if (requiresAbsoluteSwizzledAsyncCopy(dstTy)) { + auto sharedEnc = dyn_cast( + dstTy.getEncoding()); + auto order = triton::gpu::getOrder(dstTy); + if (!sharedEnc || order.size() != 2) + return op.emitError( + "PH1 swizzled async_copy expects 2D swizzled shared destination"); + + unsigned elemBytes = + std::max(1, dstTy.getElementTypeBitWidth() / 8); + auto swizzleConfig = + resolvePH1TMESwizzleValueConfig(loc, dstTy, rewriter); + if (failed(swizzleConfig)) + return failure(); + + auto srcIndices = + emitIndices(loc, rewriter, targetInfo, srcTy.getEncoding(), srcTy, + /*withCTAOffset=*/false); + if (srcIndices.size() != srcElems.size()) + return failure(); + + int64_t strideRow64 = srcTy.getShape()[order[0]]; + if (strideRow64 <= 0 || strideRow64 > std::numeric_limits::max()) + return failure(); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, llDst, llvmElemTy, rewriter); + Value smemObjBase = smemObj.getBase(); + Value smemElemBase = b.bitcast(smemObjBase, ptr_ty(ctx, 3)); + auto rootLocalAlloc = findRootLocalAlloc(op.getResult()); + if (!rootLocalAlloc || !rootLocalAlloc->hasAttr("allocation.offset")) { + return op.emitError("PH1 swizzled async_copy requires root local_alloc " + "with allocation.offset"); + } + Value smemRawBase = LLVM::getSharedMemoryBase( + loc, rewriter, targetInfo, rootLocalAlloc.getOperation()); + auto rootOffsetAttr = + rootLocalAlloc->getAttrOfType("allocation.offset"); + Value rootOffsetBytes = + b.i32_val(static_cast(rootOffsetAttr.getInt())); + Value smemOffsetFromRoot = b.sub(b.ptrtoint(i32_ty, smemObjBase), + b.ptrtoint(i32_ty, smemRawBase)); + Value smemOffsetBytes = b.add(smemOffsetFromRoot, rootOffsetBytes); + + Value elemBytesVal = b.i32_val(static_cast(elemBytes)); + + unsigned inVec = std::max(1, op.getContiguity()); + unsigned outVec = sharedEnc.getVec(); + unsigned minVec = inVec; + if (outVec > 1) + minVec = std::min(outVec, inVec); + unsigned numElems = getTotalElemsPerThread(srcTy); + + auto shape = srcTy.getShape(); + + for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { + auto idx = srcIndices[elemIdx]; + auto lmsOffsetInElem = triton::musa::linearizePH1TMELinearCoords( + b, idx, shape, order, elemBytes); + if (failed(lmsOffsetInElem)) + return failure(); + + Value lmsAddrInByte = + b.add(b.mul(*lmsOffsetInElem, elemBytesVal), smemOffsetBytes); + Value swizzledAddrInByte = applyPH1TMESwizzleToByteAddressValue( + b, lmsAddrInByte, *swizzleConfig); + Value swizzledAddrRelInByte = + b.sub(swizzledAddrInByte, smemOffsetBytes); + Value swizzledElemOffset = b.udiv(swizzledAddrRelInByte, elemBytesVal); + Value basePtr = b.gep(smemElemBase.getType(), llvmElemTy, smemElemBase, + swizzledElemOffset); + + auto maxBitWidth = + std::max(128, llvmElemTy.getIntOrFloatBitWidth()); + auto vecBitWidth = llvmElemTy.getIntOrFloatBitWidth() * minVec; + auto bitWidth = std::min(maxBitWidth, vecBitWidth); + auto numWords = vecBitWidth / bitWidth; + auto numWordElems = bitWidth / llvmElemTy.getIntOrFloatBitWidth(); + auto byteWidth = bitWidth / 8; + auto resByteWidth = llvmElemTy.getIntOrFloatBitWidth() / 8; + + for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) { + unsigned wordElemIdx = wordIdx * numWordElems; + unsigned offset = wordElemIdx * resByteWidth; + Value packedVal = vals[elemIdx + wordElemIdx]; + Value srcPtr = b.extract_val(ptrTy, packedVal, 0); + Value maskElem = b.extract_val(i1_ty, packedVal, 1); + Value dst = b.gep(basePtr.getType(), llvmElemTy, basePtr, + b.i32_val(static_cast(offset))); + Value copyPred = b.true_val(); + if (usePred) { + Value notMask = b.xor_(maskElem, b.true_val()); + Value zeroPred = notMask; + Value zeroElem = LLVM::ConstantOp::create( + rewriter, loc, llvmElemTy, rewriter.getZeroAttr(llvmElemTy)); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value dstBase = b.bitcast(dst, elemPtrTy); + for (unsigned elem = 0; elem < numWordElems; ++elem) { + Value dstPtr = + b.gep(elemPtrTy, llvmElemTy, dstBase, b.i32_val(elem)); + LLVM::MUSA::llStore(rewriter, loc, dstPtr, zeroElem, zeroPred); + } + copyPred = b.and_(copyPred, maskElem); + } + + Value cpSize = b.i32_val(static_cast(byteWidth)); + Value prefetchSize = b.i32_val(0); + emitIfPredicated(rewriter, loc, copyPred, [&]() { + auto funcType = LLVM::LLVMFunctionType::get( + void_ty(ctx), + {dst.getType(), srcPtr.getType(), cpSize.getType(), + prefetchSize.getType()}, + /*isVarArg=*/false); + auto funcOp = appendOrGetExternFuncOp(rewriter, op, memcpyIntrinsic, + funcType); + LLVM::CallOp::create(rewriter, loc, funcOp, + ValueRange{dst, srcPtr, cpSize, prefetchSize}); + }); + } + } + + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } + + auto emitAsyncCopy = [&](RewriterBase &rewriter, Location emitLoc, + ArrayRef values, Value shmemAddr, + int startIdx, + VectorType vecTy) -> SmallVector { + auto tb = TritonLLVMOpBuilder(emitLoc, rewriter); + unsigned elemsPerVec = vecTy.getNumElements(); + unsigned byteWidth = elemsPerVec * llvmElemTy.getIntOrFloatBitWidth() / 8; + + Value packedVal = values[startIdx]; + Value srcPtr = tb.extract_val(ptrTy, packedVal, 0); + Value maskElem = tb.extract_val(i1_ty, packedVal, 1); + Value copyPred = threadPred ? threadPred : tb.true_val(); + + if (usePred) { + Value notMask = tb.xor_(maskElem, tb.true_val()); + Value zeroPred = threadPred ? tb.and_(threadPred, notMask) : notMask; + Value zeroElem = LLVM::ConstantOp::create( + rewriter, emitLoc, llvmElemTy, rewriter.getZeroAttr(llvmElemTy)); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value dstBase = tb.bitcast(shmemAddr, elemPtrTy); + for (unsigned elem = 0; elem < elemsPerVec; ++elem) { + Value dstPtr = + tb.gep(elemPtrTy, llvmElemTy, dstBase, tb.i32_val(elem)); + LLVM::MUSA::llStore(rewriter, emitLoc, dstPtr, zeroElem, zeroPred); + } + copyPred = tb.and_(copyPred, maskElem); + } + + Value cpSize = tb.i32_val(byteWidth); + Value prefetchSize = tb.i32_val(0); + emitIfPredicated(rewriter, emitLoc, copyPred, [&]() { + auto funcType = LLVM::LLVMFunctionType::get( + void_ty(ctx), + {shmemAddr.getType(), srcPtr.getType(), cpSize.getType(), + prefetchSize.getType()}, + /*isVarArg=*/false); + auto funcOp = + appendOrGetExternFuncOp(rewriter, op, memcpyIntrinsic, funcType); + LLVM::CallOp::create( + rewriter, emitLoc, funcOp, + ValueRange{shmemAddr, srcPtr, cpSize, prefetchSize}); + }); + return {}; + }; + + auto smemObj = + LLVM::getSharedMemoryObjectFromStruct(loc, llDst, llvmElemTy, rewriter); + auto smemLayout = toLinearLayout(dstTy); + auto cvt = srcLayout.invertAndCompose(smemLayout); + if (!cvt.isTrivialOver({str_attr("block")})) + return op.emitError( + "async_copy does not support non-trivial block dimension"); + cvt = cvt.sublayout( + {str_attr("register"), str_attr("lane"), str_attr("warp")}, + {str_attr("offset")}); + + Value affineOffset = smemObj.getShmemOffset(loc, rewriter, dstTy); + uint64_t maskSpanAffineOffset = smemObj.getMaskSpanOffsets(dstTy); + std::optional maybeMaxVecElems; + SmallVector> paddingShifts; + if (auto paddedEnc = dyn_cast( + dstTy.getEncoding())) { + maybeMaxVecElems = paddedEnc.getMinInterval(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + paddingShifts = + getPaddedSharedShifts(paddedEnc, bitwidth, /*offsetInBytes=*/true); + } + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + lowerLdSt(loc, ctx, cvt, vals, llvmElemTy, smemObj.getBase(), paddingShifts, + affineOffset, maskSpanAffineOffset, laneId, warpId, rewriter, + targetInfo, maxVec, emitAsyncCopy); + + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } + +private: + const mlir::triton::MUSA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisInfoAnalysis; +}; + +struct AsyncCommitGroupOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::AsyncCommitGroupOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCommitGroupOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + TritonLLVMOpBuilder b(loc, rewriter); + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } +}; + +struct AsyncWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::AsyncWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + TritonLLVMOpBuilder b(loc, rewriter); + auto waitFnTy = LLVM::LLVMFunctionType::get(void_ty(ctx), {}, false); + auto waitFn = appendOrGetExternFuncOp( + rewriter, op, "llvm.musa.memcpy.g2s.wait", waitFnTy); + LLVM::CallOp::create(rewriter, loc, waitFn, ValueRange{}); + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateLoadStoreOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + int /*computeCapability*/, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, PatternBenefit(benefit.getBenefit() + 2)); + patterns.add( + typeConverter, targetInfo, PatternBenefit(benefit.getBenefit() + 3)); + patterns.add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns.add( + typeConverter, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/MUSAOpsToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/MUSAOpsToLLVM.cpp new file mode 100644 index 0000000000..66336c4d7e --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/MUSAOpsToLLVM.cpp @@ -0,0 +1,741 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "DotOpToLLVM/DotOpToLLVM.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +using namespace mlir; + +namespace { + +StringRef getTMELoadIntrinsicName(unsigned rank) { + switch (rank) { + case 1: + return "llvm.musa.tme.ld.tile.1d"; + case 2: + return "llvm.musa.tme.ld.tile.2d"; + case 3: + return "llvm.musa.tme.ld.tile.3d"; + case 4: + return "llvm.musa.tme.ld.tile.4d"; + case 5: + return "llvm.musa.tme.ld.tile.5d"; + default: + return {}; + } +} + +StringRef getTMEStoreIntrinsicName(unsigned rank) { + switch (rank) { + case 1: + return "llvm.musa.tme.st.1d"; + case 2: + return "llvm.musa.tme.st.2d"; + case 3: + return "llvm.musa.tme.st.3d"; + case 4: + return "llvm.musa.tme.st.4d"; + case 5: + return "llvm.musa.tme.st.5d"; + default: + return {}; + } +} + +Value normalizeTMEDescriptorAddr(Value value, Type srcType, Location loc, + ConversionPatternRewriter &rewriter) { + if (srcType.isInteger(64)) + return value; + if (isa(srcType)) { + return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), + value); + } + if (isa(value.getType())) { + return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), + value); + } + return value; +} + +Value normalizeTMESharedPtr(Value value, Type srcType, Type elemType, + Location loc, ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter) { + if (auto memDescTy = dyn_cast(srcType)) { + Type llvmElemTy = typeConverter->convertType(elemType); + auto memObj = + LLVM::getSharedMemoryObjectFromStruct(loc, value, llvmElemTy, rewriter); + return memObj.getShmemAffineBase(loc, rewriter, memDescTy); + } + return value; +} + +template +Value materializeTMEEnumAttr(Location loc, AttrT attr, + ConversionPatternRewriter &rewriter) { + return arith::ConstantIntOp::create( + rewriter, loc, static_cast(attr.getValue()), 32); +} + +Value reverseTMEVector(Value value, unsigned rank, Location loc, + ConversionPatternRewriter &rewriter) { + if (rank <= 1) + return value; + + auto vecTy = dyn_cast(value.getType()); + if (!vecTy || vecTy.getNumElements() != rank || + !vecTy.getElementType().isInteger(32)) + return value; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector elems; + elems.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + elems.push_back(b.extract_element(value, b.i32_val(i))); + + Value reversed = b.undef(vecTy); + for (unsigned i = 0; i < rank; ++i) + reversed = b.insert_element(reversed, elems[rank - i - 1], b.i32_val(i)); + return reversed; +} + +Value materializeTMECoord(Location loc, ValueRange coord, + ConversionPatternRewriter &rewriter) { + if (coord.empty()) + return {}; + if (coord.size() == 1) + return triton::musa::materializeI32Value(coord.front(), loc, rewriter); + + SmallVector elems; + elems.reserve(coord.size()); + for (Value value : coord) { + Value i32Value = triton::musa::materializeI32Value(value, loc, rewriter); + if (!i32Value) + return {}; + elems.push_back(i32Value); + } + + auto vecTy = VectorType::get({static_cast(coord.size())}, + rewriter.getI32Type()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value coordVector = b.undef(vecTy); + for (unsigned i = 0; i < elems.size(); ++i) + coordVector = b.insert_element(vecTy, coordVector, elems[i], b.i32_val(i)); + return reverseTMEVector(coordVector, coord.size(), loc, rewriter); +} + +template +Value materializeTMEBlockShape(Location loc, ArrayRef blockShape, + ConversionPatternRewriter &rewriter) { + if (blockShape.empty()) + return {}; + if (blockShape.size() == 1) + return arith::ConstantIntOp::create( + rewriter, loc, static_cast(blockShape.front()), 32); + + SmallVector elems; + elems.reserve(blockShape.size()); + for (IntT dim : blockShape) + elems.push_back(arith::ConstantIntOp::create( + rewriter, loc, static_cast(dim), 32)); + + auto vecTy = VectorType::get({static_cast(blockShape.size())}, + rewriter.getI32Type()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value blockShapeVector = b.undef(vecTy); + for (unsigned i = 0; i < elems.size(); ++i) + blockShapeVector = + b.insert_element(vecTy, blockShapeVector, elems[i], b.i32_val(i)); + return reverseTMEVector(blockShapeVector, blockShape.size(), loc, rewriter); +} + +std::optional getI32Constant(Value value) { + Attribute attr; + if (!matchPattern(value, m_Constant(&attr))) + return std::nullopt; + + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + + if (auto splatAttr = dyn_cast(attr)) { + auto intAttr = dyn_cast(splatAttr.getSplatValue()); + if (intAttr) + return intAttr.getInt(); + } + + return std::nullopt; +} + +std::optional getPositiveIntAttrFromParents(Operation *op, + StringRef name) { + for (Operation *cur = op; cur; cur = cur->getParentOp()) { + if (auto attr = cur->getAttrOfType(name)) { + if (attr.getInt() > 0) + return attr.getInt(); + } + } + return std::nullopt; +} + +std::optional inferRowMajorFromMemDesc(Type type) { + auto memDescTy = dyn_cast(type); + if (!memDescTy) + return std::nullopt; + auto order = triton::gpu::getOrder(memDescTy); + if (order.empty()) + return std::nullopt; + return static_cast(order.front() + 1) == + static_cast(memDescTy.getShape().size()); +} + +std::optional inferElemBytesFromMemDesc(Type type) { + auto memDescTy = dyn_cast(type); + if (!memDescTy) + return std::nullopt; + int bitWidth = memDescTy.getElementTypeBitWidth(); + if (bitWidth <= 0) + return std::nullopt; + return static_cast((bitWidth + 7) / 8); +} + +Value buildTMEIssuePredicate(Value userPred, Location loc, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadId = getThreadId(rewriter, loc); + Value issuerPred = b.icmp_eq(threadId, b.i32_val(0)); + return b.and_(userPred, issuerPred); +} + +Value buildTMEIssueOnlyPredicate(Location loc, + ConversionPatternRewriter &rewriter) { + Value truePred = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + return buildTMEIssuePredicate(truePred, loc, rewriter); +} + +void emitPredicatedVoidIntrinsic(ConversionPatternRewriter &rewriter, + Location loc, Value pred, StringRef intrinsic, + ArrayRef operands) { + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterCall = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterCall); + rewriter.setInsertionPointToEnd(currentBlock); + LLVM::CondBrOp::create(rewriter, loc, pred, trueBlock, afterCall); + rewriter.setInsertionPointToStart(trueBlock); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, TypeRange{}, + operands); + LLVM::BrOp::create(rewriter, loc, afterCall); + rewriter.setInsertionPointToStart(afterCall); +} + +struct TMELoadSegment { + Value dstAddr; + Value blockDim; + Value blockPos; +}; + +SmallVector +buildTMELoadSegments(triton::musa::AsyncTMECopyGlobalToLocalOp op, + std::optional + recoveredContract, + Value dstAddr, Value blockDim, Value blockPos, + Location loc, ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter) { + SmallVector segments; + auto appendSegment = [&](Value segDstAddr, Value segBlockDim, + Value segBlockPos) { + segments.push_back(TMELoadSegment{segDstAddr, segBlockDim, segBlockPos}); + }; + + if (!recoveredContract) { + appendSegment(dstAddr, blockDim, blockPos); + return segments; + } + + auto contract = *recoveredContract; + auto memDescTy = dyn_cast(op.getResult().getType()); + if (!memDescTy) + return segments; + if (memDescTy.getShape().size() != 2) + return segments; + + auto shape = memDescTy.getShape(); + auto order = triton::musa::getSharedOrder(memDescTy.getEncoding(), + memDescTy.getShape()); + if (order.empty()) + return segments; + + auto maybeElemBytes = triton::musa::inferElemBytesFromMemDesc(memDescTy); + if (!maybeElemBytes || *maybeElemBytes <= 0 || + *maybeElemBytes != contract.elemBytes) + return segments; + + int64_t majorDimIdx = static_cast(order.front()); + int64_t minorDimIdx = majorDimIdx == 0 ? 1 : 0; + int64_t leading = shape[majorDimIdx]; + int64_t leadingWidthBytes = leading * *maybeElemBytes; + if (leadingWidthBytes <= 256) { + appendSegment(dstAddr, blockDim, blockPos); + return segments; + } + + auto vecTy = dyn_cast(blockDim.getType()); + if (!vecTy || vecTy.getNumElements() < 2 || + !vecTy.getElementType().isInteger(32)) + return segments; + + int64_t vectorRank = vecTy.getNumElements(); + int64_t majorVectorIdx = vectorRank - majorDimIdx - 1; + if (majorVectorIdx < 0 || majorVectorIdx >= vectorRank) + return segments; + int64_t maxLeadingElems = 256 / *maybeElemBytes; + if (maxLeadingElems <= 0) { + appendSegment(dstAddr, blockDim, blockPos); + return segments; + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value majorVectorIdxVal = b.i32_val(static_cast(majorVectorIdx)); + Value majorBlockPos = b.extract_element(blockPos, majorVectorIdxVal); + Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + + int64_t leadingOffset = 0; + while (leadingOffset < leading) { + int64_t groupLeading = + std::min(leading - leadingOffset, maxLeadingElems); + Value groupBlockDim = + b.insert_element(blockDim, b.i32_val(groupLeading), majorVectorIdxVal); + Value groupMajorBlockPos = b.add(majorBlockPos, b.i32_val(leadingOffset)); + Value groupBlockPos = + b.insert_element(blockPos, groupMajorBlockPos, majorVectorIdxVal); + + Value groupDstAddr = dstAddr; + if (leadingOffset != 0) { + int64_t tileElemOffset = shape[minorDimIdx] * leadingOffset; + groupDstAddr = + b.gep(elemPtrTy, llvmElemTy, dstAddr, b.i32_val(tileElemOffset)); + } + + appendSegment(groupDstAddr, groupBlockDim, groupBlockPos); + leadingOffset += groupLeading; + } + + return segments; +} + +struct SquadDotOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::SquadDotOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::SquadDotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value threadId = getThreadId(rewriter, loc); + if (failed(mlir::triton::MUSA::convertSQMMADot( + op, adaptor, this->getTypeConverter(), rewriter, threadId))) + return op.emitError("MUSA SQMMA: ttmg direct lowering failed"); + return success(); + } +}; + +struct SquadDotWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::SquadDotWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::SquadDotWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::createLLVMIntrinsicCallOp(rewriter, op.getLoc(), + "llvm.musa.sqmma.wait", TypeRange{}, {}); + rewriter.replaceOp(op, adaptor.getInputs()); + return success(); + } +}; + +struct WmmaDotOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::WmmaDotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(mlir::triton::MUSA::convertWMMADot( + op, adaptor, this->getTypeConverter(), rewriter))) + return op.emitError("MUSA WMMA: ttmg direct lowering failed"); + return success(); + } +}; + +struct WmmaDotWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::WmmaDotWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::WmmaDotWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + rewriter.eraseOp(op); + return success(); + } +}; + +struct BarRecordOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::BarRecordOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::BarRecordOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands = {adaptor.getBarId()}; + LLVM::createLLVMIntrinsicCallOp(rewriter, op.getLoc(), + "llvm.musa.async.bar.record", TypeRange{}, + operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct InitArrivalOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::InitArrivalOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::InitArrivalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value arriveCount = adaptor.getArriveCount(); + if (auto count = getI32Constant(arriveCount); count && *count <= 0) { + auto numWarps = getPositiveIntAttrFromParents( + op.getOperation(), triton::gpu::AttrNumWarpsName); + if (numWarps) + arriveCount = TritonLLVMOpBuilder(loc, rewriter) + .i32_val(static_cast(*numWarps)); + } + + Value launchPred = buildTMEIssueOnlyPredicate(loc, rewriter); + SmallVector operands = {adaptor.getBarId(), arriveCount, + adaptor.getPhaseId()}; + emitPredicatedVoidIntrinsic(rewriter, loc, launchPred, + "llvm.musa.async.init.arrival", operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct BarrierAddTransOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::BarrierAddTransOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::BarrierAddTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value launchPred = buildTMEIssuePredicate(adaptor.getPred(), loc, rewriter); + SmallVector operands = {adaptor.getBarId(), adaptor.getTransBytes()}; + emitPredicatedVoidIntrinsic(rewriter, loc, launchPred, + "llvm.musa.async.add.trans", operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ArriveBarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::ArriveBarrierOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::ArriveBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands = {adaptor.getBarId()}; + auto call = LLVM::createLLVMIntrinsicCallOp( + rewriter, op.getLoc(), "llvm.musa.async.arrive", + TypeRange{op.getResult().getType()}, operands); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +struct ArriveBarrierNoRetOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::ArriveBarrierNoRetOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::ArriveBarrierNoRetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value launchPred = buildTMEIssuePredicate(adaptor.getPred(), loc, rewriter); + SmallVector operands = {adaptor.getBarId()}; + emitPredicatedVoidIntrinsic(rewriter, loc, launchPred, + "llvm.musa.async.arrive.none.phaseid", + operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct WaitBarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::WaitBarrierOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::WaitBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands = {adaptor.getBarId(), adaptor.getPhaseId()}; + LLVM::createLLVMIntrinsicCallOp( + rewriter, op.getLoc(), "llvm.musa.async.wait", TypeRange{}, operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct AsyncTMECopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::AsyncTMECopyGlobalToLocalOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::AsyncTMECopyGlobalToLocalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto coordRank = adaptor.getCoord().size(); + auto blockShape = op.getBlockShape(); + if (coordRank == 0 || coordRank > 5 || blockShape.size() != coordRank) { + return op.emitError( + "MUSA async_tme_copy_global_to_local expects coord/blockShape rank " + "in [1, 5] to match"); + } + StringRef intrinsic = + getTMELoadIntrinsicName(static_cast(coordRank)); + if (intrinsic.empty()) + return op.emitError( + "MUSA async_tme_copy_global_to_local unsupported rank"); + + Value blockDim = materializeTMEBlockShape(loc, blockShape, rewriter); + Value blockPos = materializeTMECoord(loc, adaptor.getCoord(), rewriter); + if (!blockDim || !blockPos) + return op.emitError("unable to materialize TME coord/blockShape"); + + Value dstAddr = adaptor.getResult(); + if (auto memDescTy = + dyn_cast(op.getResult().getType())) { + dstAddr = normalizeTMESharedPtr(adaptor.getResult(), memDescTy, + memDescTy.getElementType(), loc, rewriter, + this->getTypeConverter()); + } + Value descAddr = normalizeTMEDescriptorAddr( + adaptor.getDesc(), op.getDesc().getType(), loc, rewriter); + + auto sgAttr = op->getAttrOfType( + "swizzleGranularity"); + auto ssAttr = + op->getAttrOfType("swizzleStride"); + auto slAttr = + op->getAttrOfType("swizzleLine"); + auto prefetchAttr = + op->getAttrOfType("prefetchSize"); + auto cacheAttr = + op->getAttrOfType("cachePolicy"); + auto innerAttr = + op->getAttrOfType("innerPersistence"); + auto outerAttr = + op->getAttrOfType("outerPersistence"); + if (!sgAttr || !ssAttr || !slAttr || !prefetchAttr || !cacheAttr || + !innerAttr || !outerAttr) + return op.emitError("missing typed TME policy attrs"); + + Value swizzleGranularity = materializeTMEEnumAttr(loc, sgAttr, rewriter); + Value swizzleStride = materializeTMEEnumAttr(loc, ssAttr, rewriter); + Value swizzleLine = materializeTMEEnumAttr(loc, slAttr, rewriter); + Value prefetchSize = materializeTMEEnumAttr(loc, prefetchAttr, rewriter); + Value innerPersistence = materializeTMEEnumAttr(loc, innerAttr, rewriter); + Value outerPersistence = materializeTMEEnumAttr(loc, outerAttr, rewriter); + Value cachePolicy = materializeTMEEnumAttr(loc, cacheAttr, rewriter); + + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.barrier0", + TypeRange{}, {}); + + Value launchPred = buildTMEIssuePredicate(adaptor.getPred(), loc, rewriter); + auto recoveredContract = + triton::musa::recoverAndVerifyGroupedTMELoadConsumerContract(op); + if (failed(recoveredContract)) + return failure(); + auto loadSegments = + buildTMELoadSegments(op, *recoveredContract, dstAddr, blockDim, + blockPos, loc, rewriter, this->getTypeConverter()); + if (loadSegments.empty()) { + return op.emitError("unable to materialize grouped TME load segments"); + } + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterCall = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterCall); + rewriter.setInsertionPointToEnd(currentBlock); + LLVM::CondBrOp::create(rewriter, loc, launchPred, trueBlock, afterCall); + rewriter.setInsertionPointToStart(trueBlock); + for (const auto &segment : loadSegments) { + SmallVector operands = { + adaptor.getBarId(), segment.dstAddr, descAddr, + segment.blockDim, segment.blockPos, swizzleGranularity, + swizzleStride, swizzleLine, prefetchSize, + innerPersistence, outerPersistence, cachePolicy}; + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, TypeRange{}, + operands); + } + LLVM::BrOp::create(rewriter, loc, afterCall); + rewriter.setInsertionPointToStart(afterCall); + rewriter.eraseOp(op); + return success(); + } +}; + +struct AsyncTMECopyLocalToGlobalOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::AsyncTMECopyLocalToGlobalOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::AsyncTMECopyLocalToGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto coordRank = adaptor.getCoord().size(); + auto blockShape = op.getBlockShape(); + if (coordRank == 0 || coordRank > 5 || blockShape.size() != coordRank) { + return op.emitError( + "MUSA async_tme_copy_local_to_global expects coord/blockShape rank " + "in [1, 5] to match"); + } + StringRef intrinsic = + getTMEStoreIntrinsicName(static_cast(coordRank)); + if (intrinsic.empty()) + return op.emitError( + "MUSA async_tme_copy_local_to_global unsupported rank"); + + Value blockDim = materializeTMEBlockShape(loc, blockShape, rewriter); + Value blockPos = materializeTMECoord(loc, adaptor.getCoord(), rewriter); + if (!blockDim || !blockPos) + return op.emitError("unable to materialize TME coord/blockShape"); + + Value srcAddr = adaptor.getSrc(); + if (auto memDescTy = + dyn_cast(op.getSrc().getType())) { + srcAddr = normalizeTMESharedPtr(adaptor.getSrc(), memDescTy, + memDescTy.getElementType(), loc, rewriter, + this->getTypeConverter()); + } + Value descAddr = normalizeTMEDescriptorAddr( + adaptor.getDesc(), op.getDesc().getType(), loc, rewriter); + + auto sgAttr = op->getAttrOfType( + "swizzleGranularity"); + auto ssAttr = + op->getAttrOfType("swizzleStride"); + auto slAttr = + op->getAttrOfType("swizzleLine"); + auto cacheAttr = + op->getAttrOfType("cachePolicy"); + auto innerAttr = + op->getAttrOfType("innerPersistence"); + auto outerAttr = + op->getAttrOfType("outerPersistence"); + if (!sgAttr || !ssAttr || !slAttr || !cacheAttr || !innerAttr || !outerAttr) + return op.emitError("missing typed TME policy attrs"); + + Value swizzleGranularity = materializeTMEEnumAttr(loc, sgAttr, rewriter); + Value swizzleStride = materializeTMEEnumAttr(loc, ssAttr, rewriter); + Value swizzleLine = materializeTMEEnumAttr(loc, slAttr, rewriter); + Value innerPersistence = materializeTMEEnumAttr(loc, innerAttr, rewriter); + Value outerPersistence = materializeTMEEnumAttr(loc, outerAttr, rewriter); + Value cachePolicy = materializeTMEEnumAttr(loc, cacheAttr, rewriter); + Value launchPred = buildTMEIssuePredicate(adaptor.getPred(), loc, rewriter); + SmallVector operands = { + srcAddr, descAddr, blockDim, + blockPos, swizzleGranularity, swizzleStride, + swizzleLine, innerPersistence, outerPersistence, + cachePolicy}; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterCall = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterCall); + rewriter.setInsertionPointToEnd(currentBlock); + LLVM::CondBrOp::create(rewriter, loc, launchPred, trueBlock, afterCall); + rewriter.setInsertionPointToStart(trueBlock); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, TypeRange{}, + operands); + LLVM::BrOp::create(rewriter, loc, afterCall); + rewriter.setInsertionPointToStart(afterCall); + rewriter.eraseOp(op); + return success(); + } +}; + +struct TMEStoreCommitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::TMEStoreCommitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::TMEStoreCommitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + LLVM::createLLVMIntrinsicCallOp( + rewriter, op.getLoc(), "llvm.musa.tme.store.commit", TypeRange{}, {}); + rewriter.eraseOp(op); + return success(); + } +}; + +struct TMEStoreReadWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::musa::TMEStoreReadWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::musa::TMEStoreReadWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + LLVM::createLLVMIntrinsicCallOp(rewriter, op.getLoc(), + "llvm.musa.tme.store.read.wait", + TypeRange{}, {}); + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateMUSAOpsToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns + .add(typeConverter, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 0000000000..f013b00bf1 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,62 @@ +#ifndef TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H +#define TRITONMUSAGPU_CONVERSION_TRITONMUSAGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H + +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +namespace mlir { +namespace triton { +namespace MUSA { + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateBarrierOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit, + const TargetInfo &targetInfo); + +void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMUSAOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + const TargetInfo &targetInfo, PatternBenefit benefit); + +void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + int computeCapability, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateThreadIdOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateWarpIdOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace MUSA +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 0000000000..65513c1c8d --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,49 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct GetNumProgramsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + StringRef intrinsic; + switch (op.getAxis()) { + case ProgramIDDim::X: + intrinsic = "llvm.musa.read.ptx.sreg.nctaid.x"; + break; + case ProgramIDDim::Y: + intrinsic = "llvm.musa.read.ptx.sreg.nctaid.y"; + break; + case ProgramIDDim::Z: + intrinsic = "llvm.musa.read.ptx.sreg.nctaid.z"; + break; + default: + intrinsic = "llvm.musa.read.ptx.sreg.nctaid.x"; + break; + } + + auto call = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i32_ty, {}); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateSPMDOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TargetInfo.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TargetInfo.cpp new file mode 100644 index 0000000000..0ed0253545 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TargetInfo.cpp @@ -0,0 +1,232 @@ +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallString.h" + +using namespace mlir; +using namespace mlir::triton::MUSA; + +namespace { + +Location getCurrentLoc(RewriterBase &rewriter) { + if (Operation *parent = rewriter.getBlock()->getParentOp()) + return parent->getLoc(); + return UnknownLoc::get(rewriter.getContext()); +} + +LLVM::LLVMFuncOp getPrintfDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + constexpr StringRef funcName("printf"); + if (Operation *funcOp = moduleOp.lookupSymbol(funcName)) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, {ptr_ty(context)}, /*isVarArg=*/true); + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + return LLVM::LLVMFuncOp::create(rewriter, UnknownLoc::get(context), funcName, + funcType); +} + +std::pair printfPromoteValue(RewriterBase &rewriter, Value value, + bool isSigned) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + Value newValue = value; + Type newType = type; + auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + newType = i32_ty; + newValue = isSigned ? b.sext(newType, value).getResult() + : b.zext(newType, value).getResult(); + } else if (type.isBF16() || type.isF16() || type.isF32()) { + newType = f64_ty; + newValue = b.fpext(newType, value).getResult(); + } + + return {newType, newValue}; +} + +} // namespace + +bool TargetInfo::supportMaximumMinimum() const { return false; } + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + return arith::ConstantIntOp::create(rewriter, loc, 0, 32); +} + +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { + Value pred = cmp; + if (!cmp.getType().isInteger(1)) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (cmp.getType().isIntOrIndex()) { + pred = + b.icmp_ne(cmp, b.int_val(cmp.getType().getIntOrFloatBitWidth(), 0)); + } else { + Value zero = LLVM::ConstantOp::create( + rewriter, loc, cmp.getType(), rewriter.getZeroAttr(cmp.getType())); + pred = LLVM::FCmpOp::create(rewriter, loc, rewriter.getI1Type(), + LLVM::FCmpPredicate::one, cmp, zero); + } + } + // MTGPU backend lowers the sync vote intrinsics; emitting the legacy + // non-sync ballot leaves an unsupported intrinsic in the DAG legalization + // path and can wedge llc in release builds. + return LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.musa.vote.ballot.sync", type, {pred}) + .getResult(0); +} + +void TargetInfo::barrier(Location loc, RewriterBase &rewriter, + triton::gpu::AddrSpace targets) const { + if (targets == triton::gpu::AddrSpace::Local) { + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.syncthreads.lm", + TypeRange{}, {}); + } + if (targets != triton::gpu::AddrSpace::Local) { + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.barrier0", + TypeRange{}, {}); + } +} + +void TargetInfo::warpSync(Location loc, RewriterBase &rewriter) const { + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.syncwarp", + TypeRange{}, {}); +} + +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { + assert(!ctaId.has_value() && "cross-CTA shared stores are not supported"); + LLVM::MUSA::llStore(rewriter, loc, ptr, val, pred); +} + +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred, Operation * /*localLoadOp*/) const { + assert(!ctaId.has_value() && "cross-CTA shared loads are not supported"); + Value falseVal = LLVM::ConstantOp::create(rewriter, loc, elemTy, + rewriter.getZeroAttr(elemTy)); + return LLVM::MUSA::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); +} + +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::MUSA::shuffleXor(loc, rewriter, val, i, 32); +} + +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::MUSA::shuffleUp(loc, rewriter, val, i, 32); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::MUSA::shuffleIdx(loc, rewriter, val, i, 32); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { + return LLVM::MUSA::shuffleIdx(loc, rewriter, val, i, 32); +} + +Value TargetInfo::permute(RewriterBase &rewriter, Location loc, Value a, + Value b, Value selector) const { + return LLVM::MUSA::permute(loc, rewriter, a, b, selector); +} + +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, + ModuleOp moduleOp, ProgramIDDim axis) const { + return LLVM::MUSA::llGetPid(loc, rewriter, moduleOp, axis); +} + +bool TargetInfo::warpReduce(RewriterBase & /*rewriter*/, Location /*loc*/, + SmallVector & /*acc*/, + triton::ReduceOp /*op*/, + unsigned /*numLaneToReduce*/, + unsigned /*interleave*/) const { + return false; +} + +std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { + (void)resultElementTy; + return ""; +} + +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args, + ArrayRef isSigned) const { + auto funcOp = getPrintfDeclaration(rewriter); + SmallVector operands{formatStrStart}; + for (auto [i, arg] : llvm::enumerate(args)) { + Type newType; + Value newArg; + std::tie(newType, newArg) = printfPromoteValue( + rewriter, arg, isSigned.empty() ? true : isSigned[i]); + (void)newType; + operands.push_back(newArg); + } + LLVM::CallOp::create(rewriter, getCurrentLoc(rewriter), funcOp, operands); +} + +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, ValueRange args, + ArrayRef isSigned) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, isSigned); +} + +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + llvm::SmallString<96> assertFormat( + "[MUSA_KERNEL_ASSERT] %s:%u: %s: Assertion `%s' failed.\n"); + assertFormat.push_back('\0'); + Value formatValue = + LLVM::addStringToModule(loc, rewriter, "assertFormat_", assertFormat); + + llvm::SmallString<64> messageString(message), fileString(file), + funcString(func); + messageString.push_back('\0'); + fileString.push_back('\0'); + funcString.push_back('\0'); + + Value messageValue = + LLVM::addStringToModule(loc, rewriter, "assertMessage_", messageString); + Value fileValue = + LLVM::addStringToModule(loc, rewriter, "assertFile_", fileString); + Value funcValue = + LLVM::addStringToModule(loc, rewriter, "assertFunc_", funcString); + Value lineValue = b.i32_val(line); + + printf(rewriter, formatValue, assertFormat.size_in_bytes(), + {fileValue, lineValue, funcValue, messageValue}, + {false, false, false, false}); + barrier(loc, rewriter, triton::gpu::AddrSpace::All); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.exit", TypeRange{}, + {}); +} + +int TargetInfo::getSharedAddressSpace() const { return 3; } + +int TargetInfo::getAddressSpace(Attribute addressSpace) const { + if (isa(addressSpace)) + return 3; + return 0; +} + +bool TargetInfo::supportVectorizedAtomics() const { return false; } diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TensorPtrOpsToLLVM.cpp new file mode 100644 index 0000000000..2d940034a0 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -0,0 +1,60 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct MakeTensorPtrOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector elems; + elems.append(adaptor.getOffsets().begin(), adaptor.getOffsets().end()); + elems.append(adaptor.getShape().begin(), adaptor.getShape().end()); + elems.append(adaptor.getStrides().begin(), adaptor.getStrides().end()); + elems.push_back(adaptor.getBase()); + + Value packed = ::mlir::packLLElements(op.getLoc(), getTypeConverter(), + elems, rewriter, op.getType()); + rewriter.replaceOp(op, packed); + return success(); + } +}; + +struct AdvanceOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + TritonLLVMOpBuilder b(loc, rewriter); + + auto elems = ::mlir::unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = adaptor.getOffsets(); + + for (auto [i, offset] : llvm::enumerate(offsets)) { + elems[i] = b.add(offset, elems[i]); + } + + Value packed = ::mlir::packLLElements(loc, getTypeConverter(), elems, + rewriter, op.getPtr().getType()); + rewriter.replaceOp(op, packed); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateTensorPtrOpsToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, + benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ThreadIdOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ThreadIdOpToLLVM.cpp new file mode 100644 index 0000000000..d2c9c452d6 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/ThreadIdOpToLLVM.cpp @@ -0,0 +1,47 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class ThreadIdOpPattern : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mlir::gpu::ThreadIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef intrinsic; + switch (op.getDimension()) { + case mlir::gpu::Dimension::x: + intrinsic = "llvm.musa.read.ptx.sreg.tid.x"; + break; + case mlir::gpu::Dimension::y: + intrinsic = "llvm.musa.read.ptx.sreg.tid.y"; + break; + case mlir::gpu::Dimension::z: + intrinsic = "llvm.musa.read.ptx.sreg.tid.z"; + break; + } + + Type ty = getTypeConverter()->convertType(op.getType()); + auto call = LLVM::createLLVMIntrinsicCallOp(rewriter, op.getLoc(), + intrinsic, ty, {}); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateThreadIdOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp new file mode 100644 index 0000000000..c8612ccabd --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp @@ -0,0 +1,504 @@ +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSAGPUToLLVM/Allocation.h" +#include "TritonMUSAGPUToLLVM/Passes.h" +#include "TritonMUSAGPUToLLVM/TargetInfo.h" +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToMTGPU/GPUToMTGPUPass.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include +#include + +#include "PatternTritonGPUOpToLLVM.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONMUSAGPUTOLLVM +#include "TritonMUSAGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using mlir::triton::MUSA::TargetInfo; +namespace { + +enum class InplaceLoadDataKind { + Unsupported, + Integer, + Float, +}; + +static StringRef getInplaceLoadIntrinsicName(InplaceLoadDataKind dataKind) { + switch (dataKind) { + case InplaceLoadDataKind::Integer: + return "llvm.musa.lsu.ld.cache.hint.i"; + case InplaceLoadDataKind::Float: + return "llvm.musa.lsu.ld.cache.hint.f"; + default: + return {}; + } +} + +static SmallVector +buildInplaceLoadCacheHintOperands(Value ptr, Location loc, + PatternRewriter &rewriter) { + Type i32Ty = rewriter.getI32Type(); + Value innerPersist = LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(0)); + Value outerPersist = LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(2)); + Value chrnt = LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(1)); + Value slc = LLVM::ConstantOp::create(rewriter, loc, i32Ty, + rewriter.getI32IntegerAttr(1)); + return {ptr, innerPersist, outerPersist, chrnt, slc}; +} + +static std::optional getTypeBitWidth(Type type) { + if (auto vecTy = dyn_cast(type)) { + Type elemTy = vecTy.getElementType(); + if (elemTy.isIntOrFloat()) + return vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth(); + return std::nullopt; + } + if (type.isIntOrFloat()) + return type.getIntOrFloatBitWidth(); + return std::nullopt; +} + +static InplaceLoadDataKind getInplaceLoadDataKind(Type type) { + Type elemTy = type; + if (auto vecTy = dyn_cast(type)) + elemTy = vecTy.getElementType(); + + if (elemTy.isIntOrIndex()) + return InplaceLoadDataKind::Integer; + if (isa(elemTy)) + return InplaceLoadDataKind::Float; + return InplaceLoadDataKind::Unsupported; +} + +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + + addLegalOp(); + addLegalOp(); + addLegalOp(); + addLegalOp(); + } +}; + +class PredicatedCallOpConversion : public RewritePattern { +public: + explicit PredicatedCallOpConversion(MLIRContext *context, + int32_t computeCapability) + : RewritePattern(LLVM::CallOp::getOperationName(), 1, context), + computeCapability(computeCapability) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto callOp = dyn_cast(op); + if (!callOp || !callOp.getCallee()) + return failure(); + auto callee = callOp.getCallee().value(); + if (callee.contains(LLVM::MUSA::Predicated_Load)) + return rewritePredicatedLoad(callOp, rewriter, /*useCacheHint=*/false); + if (callee.contains(LLVM::MUSA::Predicated_InplaceLoad)) + return rewritePredicatedLoad(callOp, rewriter, + /*useCacheHint=*/computeCapability == 31); + if (callee.contains(LLVM::MUSA::Predicated_Store)) + return rewritePredicatedStore(callOp, rewriter); + return failure(); + } + +private: + static LogicalResult rewritePredicatedStore(LLVM::CallOp callOp, + PatternRewriter &rewriter) { + Location loc = callOp.getLoc(); + auto operands = callOp.getOperands(); + if (operands.size() != 3) + return failure(); + Value ptr = operands[0]; + Value val = operands[1]; + Value pred = operands[2]; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterStore = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterStore); + + rewriter.setInsertionPointToEnd(currentBlock); + LLVM::CondBrOp::create(rewriter, loc, pred, trueBlock, afterStore); + + rewriter.setInsertionPointToStart(trueBlock); + LLVM::StoreOp::create(rewriter, loc, val, ptr); + LLVM::BrOp::create(rewriter, loc, afterStore); + + rewriter.setInsertionPointToStart(afterStore); + rewriter.eraseOp(callOp); + return success(); + } + + static LogicalResult rewritePredicatedLoad(LLVM::CallOp callOp, + PatternRewriter &rewriter, + bool useCacheHint) { + Location loc = callOp.getLoc(); + auto operands = callOp.getOperands(); + if (operands.size() != 3) + return failure(); + Value ptr = operands[0]; + Value pred = operands[1]; + Value falseVal = operands[2]; + + Type elemTy = callOp.getResult().getType(); + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + afterLoad->addArgument(elemTy, loc); + Block *trueBlock = rewriter.createBlock(afterLoad); + Block *falseBlock = + rewriter.splitBlock(trueBlock, rewriter.getInsertionPoint()); + + rewriter.setInsertionPointToEnd(currentBlock); + LLVM::CondBrOp::create(rewriter, loc, pred, trueBlock, falseBlock); + + rewriter.setInsertionPointToStart(trueBlock); + Value loaded; + std::optional typeBits = getTypeBitWidth(elemTy); + InplaceLoadDataKind dataKind = getInplaceLoadDataKind(elemTy); + StringRef intrinsic = getInplaceLoadIntrinsicName(dataKind); + if (useCacheHint && typeBits && *typeBits == 128 && !intrinsic.empty()) { + SmallVector operands = + buildInplaceLoadCacheHintOperands(ptr, loc, rewriter); + loaded = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, + TypeRange{elemTy}, operands) + .getResult(0); + } else { + loaded = LLVM::LoadOp::create(rewriter, loc, elemTy, ptr); + } + LLVM::BrOp::create(rewriter, loc, loaded, afterLoad); + + rewriter.setInsertionPointToStart(falseBlock); + LLVM::BrOp::create(rewriter, loc, falseVal, afterLoad); + + rewriter.setInsertionPointToStart(afterLoad); + rewriter.replaceOp(callOp, afterLoad->getArgument(0)); + return success(); + } + + int32_t computeCapability; +}; + +class CancelRedundantBFloatRoundTripPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallIntrinsicOp call, + PatternRewriter &rewriter) const override { + if (call.getIntrin() != "llvm.musa.bfloat162float") + return failure(); + if (call.getArgs().size() != 1) + return failure(); + + auto producer = call.getArgs()[0].getDefiningOp(); + if (!producer || producer.getIntrin() != "llvm.musa.float2bfloat16") + return failure(); + if (producer.getArgs().size() != 1 || !producer->hasOneUse()) + return failure(); + + rewriter.replaceOp(call, producer.getArgs()[0]); + rewriter.eraseOp(producer); + return success(); + } +}; + +std::optional +inferElemBytesFromMemDesc(triton::gpu::MemDescType type) { + int bitWidth = type.getElementTypeBitWidth(); + if (bitWidth <= 0) + return std::nullopt; + return static_cast((bitWidth + 7) / 8); +} + +bool inferRowMajorFromMemDesc(triton::gpu::MemDescType type) { + auto order = triton::gpu::getOrder(type); + if (order.empty()) + return true; + return order.front() + 1 == type.getShape().size(); +} + +unsigned getSqmmaSwizzleAlignment(ModuleOp mod) { + // Shared memory alignment should satisfy all SQMMA swizzle requirements. + unsigned maxAlignment = 256; + mod.walk([&](triton::gpu::LocalAllocOp localAllocOp) { + auto maybeOpIdx = triton::musa::getSqmmaOpIdx(localAllocOp.getOperation()); + if (!maybeOpIdx) + return; + + auto memDescTy = cast(localAllocOp.getType()); + auto maybeElemBytes = + triton::musa::getSqmmaElemBytes(localAllocOp.getOperation()); + if (!maybeElemBytes) + maybeElemBytes = inferElemBytesFromMemDesc(memDescTy); + if (!maybeElemBytes || *maybeElemBytes <= 0) + return; + + bool isRowMajor = triton::musa::getSqmmaRowMajor( + localAllocOp.getOperation(), inferRowMajorFromMemDesc(memDescTy)); + + int64_t opIdx = *maybeOpIdx; + bool isMNMajor = + ((opIdx == 0) && !isRowMajor) || ((opIdx == 1) && isRowMajor); + unsigned sg = 16; + if (*maybeElemBytes == 2) + sg = isMNMajor ? 32 : 16; + else if (*maybeElemBytes == 4) + sg = isMNMajor ? 64 : 16; + + unsigned alignment = 256 * (256 / sg); + maxAlignment = std::max(maxAlignment, alignment); + }); + return maxAlignment; +} + +struct ConvertTritonMUSAGPUToLLVM + : public triton::impl::ConvertTritonMUSAGPUToLLVMBase< + ConvertTritonMUSAGPUToLLVM> { + using ConvertTritonMUSAGPUToLLVMBase::ConvertTritonMUSAGPUToLLVMBase; + + ConvertTritonMUSAGPUToLLVM() = default; + ConvertTritonMUSAGPUToLLVM(int32_t computeCapability) + : ConvertTritonMUSAGPUToLLVMBase({computeCapability}) {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + TargetInfo targetInfo(computeCapability); + + size_t numOps = 0; + size_t numConvertLayoutOps = 0; + mod.walk([&](Operation *op) { + numOps++; + if (isa(op)) + numConvertLayoutOps++; + }); + bool isLargeKernel = numOps > 500 || numConvertLayoutOps > 10; + + auto groupedTMELoadWalk = mod.walk( + [&](triton::musa::AsyncTMECopyGlobalToLocalOp op) -> WalkResult { + if (failed( + triton::musa::recoverAndVerifyGroupedTMELoadConsumerContract( + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (groupedTMELoadWalk.wasInterrupted()) + return signalPassFailure(); + + ModuleAllocation allocation( + mod, mlir::triton::musa_gpu::getMusaAllocationAnalysisScratchSizeFn( + targetInfo)); + ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); + typeConverter.addConversion( + [&](triton::mtgpu::SqmmaAccumulatorType type) -> std::optional { + auto info = LLVM::MUSA::getSqmmaAccumulatorCarrierInfo(type); + if (failed(info)) + return std::nullopt; + return info->carrierType; + }); + + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::populateFuncOpConversionPattern( + typeConverter, funcPatterns, targetInfo, patternBenefitDefault); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + + initSharedMemory(typeConverter, targetInfo); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + + RewritePatternSet patterns(context); + int benefit = patternBenefitPrioritizeOverLLVMConversions; + mlir::triton::MUSA::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); + mlir::triton::MUSA::populateDotOpToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::MUSA::populateMUSAOpsToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::MUSA::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, computeCapability, + targetInfo, benefit); + mlir::triton::MUSA::populateLoadStoreOpToLLVMPatterns( + typeConverter, targetInfo, computeCapability, patterns, + axisInfoAnalysis, benefit); + mlir::triton::populateReduceOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateGatherOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::MUSA::populateBarrierOpToLLVMPatterns(typeConverter, patterns, + benefit, targetInfo); + mlir::triton::MUSA::populateTensorPtrOpsToLLVMPatterns(typeConverter, + patterns, benefit); + mlir::triton::populateHistogramOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::MUSA::populateSPMDOpToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::MUSA::populateThreadIdOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::MUSA::populateWarpIdOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::arith::populateCeilFloorDivExpandOpsPatterns(patterns); + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + // Native lowering patterns. + mlir::populateGpuToMTGPUConversionPatterns(typeConverter, patterns); + mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); + mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); + mlir::triton::populateInstrumentationToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); + + TritonLLVMConversionTarget convTarget(*context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + + RewritePatternSet cleanupPatterns(context); + cleanupPatterns.add(context); + + GreedyRewriteConfig cleanupConfig; + if (isLargeKernel) { + cleanupConfig.setMaxIterations(3) + .setMaxNumRewrites(1000) + .setStrictness(GreedyRewriteStrictness::ExistingOps) + .setUseTopDownTraversal(true) + .setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); + } + + if (failed(applyPatternsGreedily(mod, std::move(cleanupPatterns), + cleanupConfig))) + return signalPassFailure(); + + RewritePatternSet predicatedPatterns(context); + predicatedPatterns.add(context, + computeCapability); + + GreedyRewriteConfig predicatedConfig; + if (isLargeKernel) { + predicatedConfig.setMaxIterations(3) + .setMaxNumRewrites(1000) + .setStrictness(GreedyRewriteStrictness::ExistingOps) + .setUseTopDownTraversal(true) + .setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); + } + + if (failed(applyPatternsGreedily(mod, std::move(predicatedPatterns), + predicatedConfig))) + return signalPassFailure(); + + TritonLLVMFunctionConversionTarget cfTarget(*context); + cfTarget.markUnknownOpDynamicallyLegal([&](Operation *op) { + return op->getDialect() != + context->getLoadedDialect(); + }); + RewritePatternSet cfPatterns(context); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + cfPatterns); + if (failed(applyPartialConversion(mod, cfTarget, std::move(cfPatterns)))) + return signalPassFailure(); + + fixUpLoopAnnotation(mod); + } + +private: + void initSharedMemory(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); + unsigned alignment = getSqmmaSwizzleAlignment(mod); + LLVM::GlobalOp::create( + b, loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, + "global_smem", /*value=*/Attribute(), /*alignment=*/alignment, + static_cast(targetInfo.getSharedAddressSpace())); + } +}; + +} // namespace + +namespace mlir::triton { + +std::unique_ptr> +createConvertTritonMUSAGPUToLLVMPass() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertTritonMUSAGPUToLLVMPass(int32_t computeCapability) { + return std::make_unique(computeCapability); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/Utility.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/Utility.cpp new file mode 100644 index 0000000000..d0b53a7e35 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/Utility.cpp @@ -0,0 +1,437 @@ +#include "TritonMUSAGPUToLLVM/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +using mlir::triton::gpu::appendOrGetExternFuncOp; +using mlir::triton::gpu::getFunctionType; +using namespace mlir::triton; + +namespace { + +static RankedTensorType getSqmmaCarrierTensorType(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy; + if (auto carrierTy = dyn_cast(type)) + return carrierTy.getAccumulatorType(); + return RankedTensorType(); +} + +static FailureOr +buildSqmmaAccumulatorCarrierInfo(Type type) { + auto tensorTy = getSqmmaCarrierTensorType(type); + auto mmaEnc = + tensorTy + ? dyn_cast(tensorTy.getEncoding()) + : triton::gpu::MUSASqmmaEncodingAttr(); + if (!tensorTy || !mmaEnc || !mmaEnc.isPH1()) + return failure(); + + auto instrShape = mmaEnc.getInstrShape(); + auto warpsPerCTA = mmaEnc.getWarpsPerCTA(); + if (instrShape.size() != 3 || warpsPerCTA.size() < 2 || warpsPerCTA[0] < 4) + return failure(); + + auto shapePerCTA = triton::gpu::getShapePerCTA(tensorTy); + auto ceilDiv = [](unsigned x, unsigned y) { + return y == 0 ? 0 : (x + y - 1) / y; + }; + + unsigned instM = instrShape[0]; + unsigned instN = instrShape[1]; + unsigned squadsM = warpsPerCTA[0] / 4; + unsigned squadsN = warpsPerCTA[1]; + unsigned tileM = instM * squadsM; + unsigned tileN = instN * squadsN; + unsigned numRepM = std::max(1u, ceilDiv(shapePerCTA[0], tileM)); + unsigned numRepN = std::max(1u, ceilDiv(shapePerCTA[1], tileN)); + unsigned fragmentCount = numRepM * numRepN; + unsigned totalAccElems = mmaEnc.getTotalElemsPerThread(tensorTy.getShape()); + if (fragmentCount == 0 || totalAccElems == 0 || + (totalAccElems % fragmentCount) != 0) { + return failure(); + } + + MLIRContext *ctx = type.getContext(); + unsigned fragmentElems = totalAccElems / fragmentCount; + Type fragmentType = VectorType::get({static_cast(fragmentElems)}, + tensorTy.getElementType()); + Type carrierType = fragmentType; + if (fragmentCount > 1) { + SmallVector fields(fragmentCount, fragmentType); + carrierType = LLVM::LLVMStructType::getLiteral(ctx, fields); + } + + return LLVM::MUSA::SqmmaAccumulatorCarrierInfo{ + tensorTy, fragmentCount, fragmentElems, fragmentType, carrierType}; +} + +std::string getTypeString(mlir::Type ty) { + std::string str; + llvm::raw_string_ostream rso(str); + ty.print(rso); + rso.flush(); + return str; +} + +std::string mangleFunc(llvm::StringRef name, mlir::Type type) { + auto funcType = llvm::dyn_cast(type); + assert(funcType && "Expected LLVM function type"); + std::string mangled = name.str(); + mangled.push_back('_'); + mangled += getTypeString(funcType.getReturnType()); + mangled.push_back('_'); + for (auto param : funcType.getParams()) { + mangled += getTypeString(param); + mangled.push_back('_'); + } + return mangled; +} + +llvm::StringRef getShuffleIntrinsicName(llvm::StringRef kind) { + if (kind == "xor") + return "llvm.musa.shfl.xor.sync.i32"; + if (kind == "up") + return "llvm.musa.shfl.up.sync.i32"; + if (kind == "down") + return "llvm.musa.shfl.down.sync.i32"; + return "llvm.musa.shfl.idx.sync.i32"; +} + +mlir::Value shuffleCommon(mlir::Location loc, mlir::RewriterBase &rewriter, + mlir::Value value, mlir::Value offset, + llvm::StringRef kind, int widthInt) { + using namespace mlir; + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto valueTy = value.getType(); + + if (auto ptrTy = dyn_cast(valueTy)) { + Value raw = b.ptrtoint(i64_ty, value); + Value shuffled = shuffleCommon(loc, rewriter, raw, offset, kind, widthInt); + return b.inttoptr(ptrTy, shuffled); + } + + unsigned bits = valueTy.getIntOrFloatBitWidth(); + + if (bits == 64) { + auto i64Ty = rewriter.getI64Type(); + Value raw = valueTy.isInteger(64) ? value : b.bitcast(value, i64Ty); + Value lo = b.trunc(i32_ty, raw); + Value hi = b.trunc(i32_ty, b.lshr(i64Ty, raw, b.int_val(64, 32))); + lo = shuffleCommon(loc, rewriter, lo, offset, kind, widthInt); + hi = shuffleCommon(loc, rewriter, hi, offset, kind, widthInt); + Value packedLo = b.zext(i64Ty, lo); + Value packedHi = b.shl(i64Ty, b.zext(i64Ty, hi), b.int_val(64, 32)); + Value packed = b.or_(packedLo, packedHi); + return valueTy.isInteger(64) ? packed : b.bitcast(packed, valueTy); + } + + Value val = value; + if (!valueTy.isInteger(32)) { + val = b.bitcast(val, int_ty(bits)); + if (bits < 32) + val = b.zext(i32_ty, val); + } + + Value maskAndClamp; + if (kind == "up") { + maskAndClamp = b.i32_val(0); + } else { + Value width = b.i32_val(widthInt); + Value clamp = b.sub(width, b.i32_val(1)); + Value segMask = b.sub(b.i32_val(128), width); + segMask = b.shl(segMask, b.i32_val(7)); + maskAndClamp = b.or_(segMask, clamp); + } + + auto nullPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 5); + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, nullPtrTy); + auto intrinsic = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, getShuffleIntrinsicName(kind), i32_ty, + {val, offset, maskAndClamp, nullPtr}); + Value result = intrinsic.getResult(0); + + if (!valueTy.isInteger(32)) { + if (bits < 32) + result = b.trunc(int_ty(bits), result); + result = b.bitcast(result, valueTy); + } + return result; +} + +} // namespace + +namespace mlir { +namespace LLVM { +namespace MUSA { + +FailureOr +getSqmmaAccumulatorCarrierInfo(Type type) { + return buildSqmmaAccumulatorCarrierInfo(type); +} + +SmallVector unpackSqmmaAccumulatorCarrier(Location loc, Value carrier, + Type type, + RewriterBase &rewriter) { + auto info = buildSqmmaAccumulatorCarrierInfo(type); + assert(succeeded(info) && "expected valid SQMMA accumulator carrier type"); + if (info->fragmentCount == 1) + return {carrier}; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector fragments; + fragments.reserve(info->fragmentCount); + for (unsigned i = 0; i < info->fragmentCount; ++i) { + fragments.push_back(b.extract_val(info->fragmentType, carrier, i)); + } + return fragments; +} + +Value packSqmmaAccumulatorCarrier(Location loc, ValueRange fragments, Type type, + RewriterBase &rewriter) { + auto info = buildSqmmaAccumulatorCarrierInfo(type); + assert(succeeded(info) && "expected valid SQMMA accumulator carrier type"); + assert(fragments.size() == info->fragmentCount && + "fragment count mismatch when packing SQMMA carrier"); + if (info->fragmentCount == 1) + return fragments.front(); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value packed = LLVM::UndefOp::create(rewriter, loc, info->carrierType); + for (unsigned i = 0; i < info->fragmentCount; ++i) { + packed = b.insert_val(info->carrierType, packed, fragments[i], i); + } + return packed; +} + +Value carrierFragmentToMathVec(Location loc, Value fragment, Type type, + RewriterBase &rewriter) { + auto info = buildSqmmaAccumulatorCarrierInfo(type); + assert(succeeded(info) && "expected valid SQMMA accumulator carrier type"); + Type elemTy = info->tensorType.getElementType(); + Type mathVecTy = + VectorType::get({static_cast(info->fragmentElems)}, elemTy); + if (fragment.getType() == mathVecTy) + return fragment; + return TritonLLVMOpBuilder(loc, rewriter).bitcast(fragment, mathVecTy); +} + +Value mathVecToCarrierFragment(Location loc, Value mathVec, Type type, + RewriterBase &rewriter) { + auto info = buildSqmmaAccumulatorCarrierInfo(type); + assert(succeeded(info) && "expected valid SQMMA accumulator carrier type"); + if (mathVec.getType() == info->fragmentType) + return mathVec; + return TritonLLVMOpBuilder(loc, rewriter) + .bitcast(mathVec, info->fragmentType); +} + +Value packSqmmaAccumulatorCarrierFromTensor(Location loc, Value tensorValue, + RankedTensorType tensorType, + const LLVMTypeConverter *converter, + RewriterBase &rewriter) { + (void)converter; + auto info = buildSqmmaAccumulatorCarrierInfo(tensorType); + assert(succeeded(info) && "expected valid SQMMA accumulator tensor"); + SmallVector elements = + ::mlir::unpackLLElements(loc, tensorValue, rewriter); + assert(elements.size() == info->fragmentCount * info->fragmentElems && + "unexpected SQMMA accumulator element count"); + + SmallVector fragments; + fragments.reserve(info->fragmentCount); + for (unsigned i = 0; i < info->fragmentCount; ++i) { + ArrayRef slice(elements); + Value mathVec = packLLVector( + loc, slice.slice(i * info->fragmentElems, info->fragmentElems), + rewriter); + fragments.push_back( + mathVecToCarrierFragment(loc, mathVec, tensorType, rewriter)); + } + return packSqmmaAccumulatorCarrier(loc, fragments, tensorType, rewriter); +} + +Value unpackSqmmaAccumulatorCarrierToTensor(Location loc, Value carrier, + RankedTensorType tensorType, + const LLVMTypeConverter *converter, + RewriterBase &rewriter) { + SmallVector fragments = + unpackSqmmaAccumulatorCarrier(loc, carrier, tensorType, rewriter); + SmallVector elements; + auto info = buildSqmmaAccumulatorCarrierInfo(tensorType); + assert(succeeded(info) && "expected valid SQMMA accumulator tensor"); + elements.reserve(info->fragmentCount * info->fragmentElems); + for (Value fragment : fragments) { + Value mathVec = + carrierFragmentToMathVec(loc, fragment, tensorType, rewriter); + SmallVector mathElems = unpackLLVector(loc, mathVec, rewriter); + elements.append(mathElems.begin(), mathElems.end()); + } + return ::mlir::packLLElements(loc, converter, elements, rewriter, tensorType); +} + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), "xor", width); +} + +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), "up", width); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), "idx", width); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + unsigned width) { + return shuffleCommon(loc, rewriter, val, i, "idx", width); +} + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp /*moduleOp*/, + triton::ProgramIDDim axis) { + StringRef intrinsic; + switch (axis) { + case triton::ProgramIDDim::X: + intrinsic = "llvm.musa.read.ptx.sreg.ctaid.x"; + break; + case triton::ProgramIDDim::Y: + intrinsic = "llvm.musa.read.ptx.sreg.ctaid.y"; + break; + case triton::ProgramIDDim::Z: + intrinsic = "llvm.musa.read.ptx.sreg.ctaid.z"; + break; + default: + intrinsic = "llvm.musa.read.ptx.sreg.ctaid.x"; + break; + } + + auto call = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i32_ty, {}); + return call.getResult(0); +} + +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal) { + Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto funcName = mangleFunc(Predicated_Load, funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + auto loadVal = LLVM::CallOp::create(rewriter, loc, funcOp, + ValueRange({ptr, pred, falseVal})) + .getResult(); + return loadVal; +} + +Value llInplaceLoad(RewriterBase &rewriter, Location loc, Value ptr, + Type elemTy, Value pred, Value falseVal) { + Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto funcName = mangleFunc(Predicated_InplaceLoad, funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + auto loadVal = LLVM::CallOp::create(rewriter, loc, funcOp, + ValueRange({ptr, pred, falseVal})) + .getResult(); + return loadVal; +} + +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) { + Type funcType = getFunctionType(void_ty(rewriter.getContext()), + ValueRange({ptr, val, pred})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto funcName = mangleFunc(Predicated_Store, funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + LLVM::CallOp::create(rewriter, loc, funcOp, ValueRange({ptr, val, pred})); +} + +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask) { + auto bld = TritonLLVMOpBuilder(loc, rewriter); + Type valueTy = a.getType(); + assert(valueTy == b.getType() && "permute operands must have the same type"); + + if (auto ptrTy = dyn_cast(valueTy)) { + Value aI64 = bld.ptrtoint(i64_ty, a); + Value bI64 = bld.ptrtoint(i64_ty, b); + Value resultI64 = permute(loc, rewriter, aI64, bI64, mask); + return bld.inttoptr(ptrTy, resultI64); + } + + unsigned bits = valueTy.getIntOrFloatBitWidth(); + if (bits == 64) { + Value aI64 = valueTy.isInteger(64) ? a : bld.bitcast(a, i64_ty); + Value bI64 = valueTy.isInteger(64) ? b : bld.bitcast(b, i64_ty); + + Value aLo = bld.trunc(i32_ty, aI64); + Value aHi = bld.trunc(i32_ty, bld.lshr(i64_ty, aI64, bld.i64_val(32))); + Value bLo = bld.trunc(i32_ty, bI64); + Value bHi = bld.trunc(i32_ty, bld.lshr(i64_ty, bI64, bld.i64_val(32))); + + Value outLo = permute(loc, rewriter, aLo, bLo, mask); + Value outHi = permute(loc, rewriter, aHi, bHi, mask); + Value packedLo = bld.zext(i64_ty, outLo); + Value packedHi = bld.shl(i64_ty, bld.zext(i64_ty, outHi), bld.i64_val(32)); + Value packed = bld.or_(packedLo, packedHi); + return valueTy.isInteger(64) ? packed : bld.bitcast(packed, valueTy); + } + + Type rawTy = valueTy.isIntOrIndex() ? valueTy : int_ty(bits); + Value aI32 = a; + Value bI32 = b; + if (!valueTy.isIntOrIndex()) { + aI32 = bld.bitcast(a, rawTy); + bI32 = bld.bitcast(b, rawTy); + } + if (bits < 32) { + aI32 = bld.zext(i32_ty, aI32); + bI32 = bld.zext(i32_ty, bI32); + } else if (bits != 32) { + llvm_unreachable("permute only supports scalar values up to 64 bits"); + } else if (!valueTy.isInteger(32)) { + aI32 = bld.bitcast(aI32, i32_ty); + bI32 = bld.bitcast(bI32, i32_ty); + } + + auto call = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.musa.prmt", + i32_ty, {aI32, bI32, mask}); + Value result = call.getResult(0); + if (bits < 32) + result = bld.trunc(rawTy, result); + if (!valueTy.isIntOrIndex()) + result = bld.bitcast(result, valueTy); + return result; +} + +Value createElectPredicate(Location loc, PatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value laneId = + LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.musa.read.ptx.sreg.laneid", i32_ty, {}) + .getResult(0); + Value firstLane = + LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.musa.vote.firstactivelane", i32_ty, {}) + .getResult(0); + return b.icmp_eq(laneId, firstLane); +} + +LLVM::LLVMFuncOp getLibdeviceFuncCall(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type retType, + ValueRange ins) { + Type funcType = getFunctionType(retType, ins); + return appendOrGetExternFuncOp(rewriter, op, funcName, funcType); +} + +} // namespace MUSA +} // namespace LLVM +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/WarpIdOpToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/WarpIdOpToLLVM.cpp new file mode 100644 index 0000000000..e5f7853c88 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/WarpIdOpToLLVM.cpp @@ -0,0 +1,54 @@ +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class WarpIdOpPattern + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + mlir::triton::gpu::WarpIdOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mlir::triton::gpu::WarpIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // This is runtime-constant for a program instance; move it to function + // entry unless we are inside a warp-specialized partition. + std::optional startWarpId = getWarpGroupStartWarpId(op->getBlock()); + if (!startWarpId) { + auto funcOp = op->getParentOfType(); + rewriter.setInsertionPoint( + &funcOp.getFunctionBody().getBlocks().front().front()); + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value tid = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.musa.read.ptx.sreg.tid.x", i32_ty, {}) + .getResult(0); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + Value warpId = b.udiv(tid, b.i32_val(threadsPerWarp)); + + if (startWarpId) + warpId = b.sub(warpId, b.i32_val(*startWarpId)); + + rewriter.replaceOp(op, warpId); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateWarpIdOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/AccelerateMUSAMatmul.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/AccelerateMUSAMatmul.cpp new file mode 100644 index 0000000000..e6b0a356e0 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/AccelerateMUSAMatmul.cpp @@ -0,0 +1,977 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MMAContractUtils.h" +#include "TritonMUSACommon/MemDescUtils.h" +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MathExtras.h" +#include +#include +#include +#include +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +inline constexpr llvm::StringLiteral kDisableGenericDotPipelineAttr = + "tt.disable_generic_dot_pipeline"; + +static int getMusaComputeCapability(ModuleOp mod) { + StringAttr targetAttr = mod->getAttrOfType(ttg::AttrTargetName); + if (!targetAttr) + return -1; + StringRef ref = targetAttr.strref(); + if (!ref.starts_with("musa:")) + return -1; + StringRef arch = ref.drop_front(5); + if (arch.starts_with("ph1")) + return 31; + int computeCapability = -1; + if (arch.getAsInteger(10, computeCapability)) + return -1; + return computeCapability; +} + +static std::optional toSqmmaEltType(Type elemTy) { + if (elemTy.isF16()) + return triton::musa::SQMMAEltType::f16; + if (elemTy.isBF16()) + return triton::musa::SQMMAEltType::bf16; + if (elemTy.isF32()) + return triton::musa::SQMMAEltType::f32; + if (elemTy.isInteger(32)) + return triton::musa::SQMMAEltType::s32; + if (elemTy.isInteger(8)) + return triton::musa::SQMMAEltType::s8; + if (llvm::isa(elemTy)) + return triton::musa::SQMMAEltType::e4m3; + if (llvm::isa(elemTy)) + return triton::musa::SQMMAEltType::e5m2; + return std::nullopt; +} + +static std::optional +toSqmmaOperandEltType(Type elemTy, bool allowTF32) { + if (elemTy.isF32() && allowTF32) + return triton::musa::SQMMAEltType::tf32; + return toSqmmaEltType(elemTy); +} + +static triton::musa::SQMMALayout inferSqmmaLayout(Value v) { + if (auto tensorTy = dyn_cast(v.getType())) { + auto order = ttg::getOrderForMemory(tensorTy); + bool isRowMajor = !order.empty() && order.front() + 1 == tensorTy.getRank(); + return isRowMajor ? triton::musa::SQMMALayout::row + : triton::musa::SQMMALayout::col; + } + if (auto memDescTy = dyn_cast(v.getType())) { + auto order = ttg::getOrder(memDescTy); + bool isRowMajor = + !order.empty() && order.front() + 1 == memDescTy.getRank(); + return isRowMajor ? triton::musa::SQMMALayout::row + : triton::musa::SQMMALayout::col; + } + return triton::musa::SQMMALayout::row; +} + +static bool isSupportedWmmaOperandType(Type elemTy, bool allowTF32) { + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(8) || + tt::type::isFloat8(elemTy)) + return true; + return elemTy.isF32() && allowTF32; +} + +static SmallVector> +getWmmaCandidateInstrShapes(Type elemTy, bool allowTF32) { + if (elemTy.isF32() && allowTF32) + return {{16, 8, 4}, {16, 8, 8}, {16, 16, 16}}; + if (elemTy.isF16() || elemTy.isBF16()) { + return { + {8, 16, 16}, {16, 8, 8}, {16, 8, 16}, {16, 16, 16}, {16, 16, 32}, + }; + } + return { + {8, 16, 16}, {16, 8, 16}, {16, 16, 16}, {16, 16, 32}, {16, 16, 64}, + }; +} + +struct SelectedConfig { + SmallVector instrShape; + SmallVector warpsPerCTA; +}; + +static bool isKnownBrokenSqmmaConfig(Type elemTy, bool allowTF32, + ArrayRef instrShape) { + auto eltTypeA = toSqmmaOperandEltType(elemTy, allowTF32); + if (!eltTypeA || instrShape.size() != 3) + return false; + + triton::musa::SQMMAEltType eltTypeC = elemTy.isInteger(8) + ? triton::musa::SQMMAEltType::s32 + : triton::musa::SQMMAEltType::f32; + return !triton::musa::isSupportedSqmma(*eltTypeA, *eltTypeA, eltTypeC, + instrShape[0], instrShape[1], + instrShape[2]); +} + +static SmallVector +selectWarpsPerCTAForPH1(unsigned m, unsigned n, unsigned numWarps, + ArrayRef instrShape) { + assert(instrShape.size() == 3 && "Unexpected instrShape rank"); + SmallVector ret{1, 1}; + unsigned maxWarpsM = std::max(1u, m / instrShape[0]); + while (ret[0] * ret[1] < numWarps) { + bool growM = + (m / instrShape[0] / ret[0]) >= (n / (instrShape[1] * 2) / ret[1]); + if (growM) { + if (ret[0] < maxWarpsM) + ret[0] *= 2; + else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } + return ret; +} + +static bool shouldUseSqmmaCOperand(Type aElemTy, Type dElemTy, unsigned m, + unsigned n, uint32_t maxNumImpreciseAcc, + const SelectedConfig &config) { + if (!tt::type::isFloat8(aElemTy) || !dElemTy.isF32() || + maxNumImpreciseAcc != 0) + return true; + + unsigned instM = config.instrShape[0]; + unsigned instN = config.instrShape[1]; + unsigned squadsM = std::max(1u, config.warpsPerCTA[0] / 4); + unsigned squadsN = std::max(1u, config.warpsPerCTA[1]); + unsigned tileM = instM * squadsM; + unsigned tileN = instN * squadsN; + auto ceilDiv = [](unsigned x, unsigned y) { return (x + y - 1) / y; }; + unsigned numRepM = ceilDiv(m, tileM); + unsigned numRepN = ceilDiv(n, tileN); + bool keepSoftwareAccumFamily = + numRepM == 1 && m <= 64 && n >= 32 && !(m >= 32 && n >= 256); + return !keepSoftwareAccumFamily; +} + +struct SqmmaAccumulationContract { + bool useCOperand = true; + triton::musa::SQMMAAccumulationMode mode = + triton::musa::SQMMAAccumulationMode::hardware; +}; + +static SqmmaAccumulationContract selectSqmmaAccumulationContract( + Type aElemTy, Type dElemTy, unsigned m, unsigned n, unsigned k, + bool accIsZero, uint32_t maxNumImpreciseAcc, const SelectedConfig &config) { + SqmmaAccumulationContract contract; + contract.useCOperand = + !accIsZero || shouldUseSqmmaCOperand(aElemTy, dElemTy, m, n, + maxNumImpreciseAcc, config); + if (!tt::type::isFloat8(aElemTy) || !dElemTy.isF32()) + return contract; + + unsigned instM = config.instrShape[0]; + unsigned instN = config.instrShape[1]; + unsigned instK = config.instrShape[2]; + unsigned squadsM = std::max(1u, config.warpsPerCTA[0] / 4); + unsigned squadsN = std::max(1u, config.warpsPerCTA[1]); + unsigned tileM = instM * squadsM; + unsigned tileN = instN * squadsN; + auto ceilDiv = [](unsigned x, unsigned y) { return (x + y - 1) / y; }; + unsigned numRepM = ceilDiv(m, tileM); + unsigned numRepK = std::max(1u, ceilDiv(k, instK)); + + bool softwareAccumulate = + !accIsZero && + ((!contract.useCOperand) || + (contract.useCOperand && maxNumImpreciseAcc == 0 && numRepM == 1 && + numRepK == 1 && m <= 64 && n >= 32 && !(m >= 32 && n >= 256))); + if (softwareAccumulate) { + contract.mode = triton::musa::SQMMAAccumulationMode::software; + return contract; + } + + if (maxNumImpreciseAcc > 0 && maxNumImpreciseAcc <= k) { + contract.mode = triton::musa::SQMMAAccumulationMode::partial; + return contract; + } + + return contract; +} + +static std::optional +selectWmmaConfig(unsigned m, unsigned n, unsigned k, unsigned numWarps, + Type elemTy, bool allowTF32) { + if (numWarps == 0 || (numWarps & (numWarps - 1)) != 0) + return std::nullopt; + + auto candidates = getWmmaCandidateInstrShapes(elemTy, allowTF32); + + bool found = false; + SmallVector bestInstrShape = {0, 0, 0}; + unsigned bestInstCount = 0; + + for (unsigned tileM = 1; tileM <= numWarps; tileM *= 2) { + if (numWarps % tileM) + continue; + unsigned tileN = numWarps / tileM; + if (m % tileM != 0 || n % tileN != 0) + continue; + unsigned warpM = m / tileM; + unsigned warpN = n / tileN; + + for (const auto &shape : candidates) { + unsigned instM = shape[0]; + unsigned instN = shape[1]; + unsigned instK = shape[2]; + if (warpM % instM != 0 || warpN % instN != 0 || k % instK != 0) + continue; + unsigned instCount = (warpM / instM) * (warpN / instN) * (k / instK); + if (!found || instCount < bestInstCount) { + bestInstCount = instCount; + bestInstrShape = shape; + found = true; + } + } + } + + if (!found) + return std::nullopt; + + SelectedConfig best; + best.instrShape = bestInstrShape; + best.warpsPerCTA = selectWarpsPerCTAForPH1(m, n, numWarps, best.instrShape); + return best; +} + +static bool isSupportedSqmmaOperandType(Type elemTy, bool allowTF32) { + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(8) || + tt::type::isFloat8(elemTy)) + return true; + return elemTy.isF32() && allowTF32; +} + +static SmallVector getSqmmaCandidateM(Type elemTy, bool allowTF32) { + if (elemTy.isF32() && allowTF32) + return {128, 64, 32, 16}; + return {128, 64, 32, 16}; +} + +static SmallVector getSqmmaCandidateN(Type elemTy, bool allowTF32) { + if (elemTy.isF32() && allowTF32) + return {128, 64, 32, 16}; + return {128, 64, 32, 16}; +} + +static SmallVector getSqmmaCandidateK(Type elemTy, bool allowTF32) { + if (elemTy.isF16() || elemTy.isBF16()) + return {128, 64, 32, 16}; + if (elemTy.isF32() && allowTF32) + return {32, 16, 8}; + if (tt::type::isFloat8(elemTy) || elemTy.isInteger(8)) + return {128, 64, 32}; + return {}; +} + +static bool shouldAllowSqmmaTranspose(Type elemTy) { + return elemTy.isF16() || elemTy.isBF16() || tt::type::isFloat8(elemTy); +} + +enum class SqmmaTransLoadKind { + None, // No transpose exists on the operand load chain. + PlainLoad, // A LSU-fed load chain contains a tt.trans. + Descriptor, // A descriptor-fed load chain contains a tt.trans. +}; + +static SqmmaTransLoadKind classifySqmmaTransLoad(Value v) { + Value cur = v; + while (true) { + if (auto cvtOp = cur.getDefiningOp()) { + cur = cvtOp.getSrc(); + continue; + } + if (auto bitcastOp = cur.getDefiningOp()) { + cur = bitcastOp.getSrc(); + continue; + } + auto transOp = cur.getDefiningOp(); + if (!transOp) + return SqmmaTransLoadKind::None; + + Value transSrc = transOp.getSrc(); + while (auto bitcastOp = transSrc.getDefiningOp()) + transSrc = bitcastOp.getSrc(); + return transSrc.getDefiningOp() + ? SqmmaTransLoadKind::Descriptor + : SqmmaTransLoadKind::PlainLoad; + } +} + +static Value promoteDotOperand(OpBuilder &builder, Location loc, Value operand, + Type promoteElemTy) { + auto tensorTy = dyn_cast(operand.getType()); + if (!tensorTy) + return operand; + Type srcElemTy = tensorTy.getElementType(); + if (srcElemTy == promoteElemTy) + return operand; + + auto dstTy = tensorTy.cloneWith(std::nullopt, promoteElemTy); + if (tt::type::isFloat8(srcElemTy)) + return tt::FpToFpOp::create(builder, loc, dstTy, operand); + + if (isa(srcElemTy) && isa(promoteElemTy)) + return arith::ExtFOp::create(builder, loc, dstTy, operand); + return operand; +} + +static void promoteResidualFp8DotForFma(ModuleOp mod) { + mod.walk([&](tt::DotOp dotOp) { + auto aTy = dyn_cast(dotOp.getA().getType()); + auto bTy = dyn_cast(dotOp.getB().getType()); + auto dTy = dyn_cast(dotOp.getType()); + if (!aTy || !bTy || !dTy) + return; + + Type aElemTy = aTy.getElementType(); + Type bElemTy = bTy.getElementType(); + Type dElemTy = dTy.getElementType(); + if (!tt::type::isFloat8(aElemTy) && !tt::type::isFloat8(bElemTy)) + return; + if (aElemTy == dElemTy && bElemTy == dElemTy) + return; + + OpBuilder builder(dotOp); + Location loc = dotOp.getLoc(); + Value newA = promoteDotOperand(builder, loc, dotOp.getA(), dElemTy); + Value newB = promoteDotOperand(builder, loc, dotOp.getB(), dElemTy); + dotOp.setOperand(0, newA); + dotOp.setOperand(1, newB); + }); +} + +static SmallVector getSqmmaPaddedAllocShape(RankedTensorType argType, + ArrayRef order) { + auto shape = argType.getShape(); + SmallVector allocShape(shape.begin(), shape.end()); + if (allocShape.empty() || order.empty()) + return allocShape; + + unsigned leadingDim = order.front(); + if (leadingDim >= allocShape.size()) + return allocShape; + + int elemBitWidth = argType.getElementType().getIntOrFloatBitWidth(); + int64_t elemBytes = std::max(1, (elemBitWidth + 7) / 8); + int64_t leadingBytes = allocShape[leadingDim] * elemBytes; + if (leadingBytes <= 0) + return allocShape; + + int64_t paddedLeadingBytes = leadingBytes; + if (leadingBytes <= 256) { + if (!llvm::isPowerOf2_64(static_cast(leadingBytes))) + paddedLeadingBytes = static_cast( + llvm::PowerOf2Ceil(static_cast(leadingBytes))); + } else { + paddedLeadingBytes = llvm::alignTo(leadingBytes, int64_t{256}); + } + + if (paddedLeadingBytes > leadingBytes && + (paddedLeadingBytes % elemBytes) == 0) + allocShape[leadingDim] = paddedLeadingBytes / elemBytes; + return allocShape; +} + +static Value getSharedMemorySqmmaOperand(Value v, PatternRewriter &rewriter, + int opIdx, + ttg::MUSASqmmaEncodingAttr mmaEnc, + bool allowTranspose) { + OpBuilder::InsertionGuard g(rewriter); + Value arg = v; + bool forceFreshRestage = false; + while (true) { + if (auto cvtOp = arg.getDefiningOp()) { + auto srcTy = dyn_cast(cvtOp.getSrc().getType()); + auto dstTy = dyn_cast(cvtOp.getType()); + if (srcTy && dstTy && isa(srcTy.getEncoding()) && + !isa(dstTy.getEncoding())) { + forceFreshRestage = true; + break; + } + arg = cvtOp.getSrc(); + continue; + } + if (auto bitcastOp = arg.getDefiningOp()) { + arg = bitcastOp.getSrc(); + continue; + } + if (auto transOp = arg.getDefiningOp()) { + if (allowTranspose || + transOp.getSrc().getDefiningOp()) + break; + arg = transOp.getSrc(); + continue; + } + break; + } + + auto argType = dyn_cast(arg.getType()); + if (!argType || !argType.getEncoding()) + return {}; + if (isa( + argType.getEncoding())) + return {}; + if (argType.getRank() != 2) + return {}; + int elemBitWidth = argType.getElementType().getIntOrFloatBitWidth(); + int elemBytes = std::max(1, (elemBitWidth + 7) / 8); + + Value descSeed = arg; + while (auto bitcastOp = descSeed.getDefiningOp()) + descSeed = bitcastOp.getSrc(); + + tt::DescriptorLoadOp descLoad; + if (auto transOp = descSeed.getDefiningOp()) + descLoad = transOp.getSrc().getDefiningOp(); + else + descLoad = descSeed.getDefiningOp(); + + SmallVector newOrder = ttg::getOrderForMemory(argType); + if (!allowTranspose) { + newOrder = SmallVector{1, 0}; + } + bool isRowMajor = + !newOrder.empty() && (newOrder.front() + 1 == argType.getRank()); + auto setSqmmaAttrs = [&](Operation *targetOp) { + triton::musa::setSqmmaAttrs(targetOp, opIdx, elemBytes, isRowMajor); + }; + auto propagateSqmmaAttrsFromLocalAlloc = [&](ttg::LocalAllocOp localAlloc) { + SmallVector pending{localAlloc.getResult()}; + llvm::SmallPtrSet visited; + while (!pending.empty()) { + Value cur = pending.pop_back_val(); + for (Operation *user : cur.getUsers()) { + if (!visited.insert(user).second) + continue; + if (auto indexOp = dyn_cast(user)) { + setSqmmaAttrs(indexOp.getOperation()); + pending.push_back(indexOp.getResult()); + continue; + } + if (auto subslice = dyn_cast(user)) { + setSqmmaAttrs(subslice.getOperation()); + pending.push_back(subslice.getResult()); + continue; + } + if (auto reinterpretOp = dyn_cast(user)) { + pending.push_back(reinterpretOp.getResult()); + continue; + } + if (auto transOp = dyn_cast(user)) { + pending.push_back(transOp.getResult()); + continue; + } + } + } + }; + auto propagateSqmmaAttrsToMemDescChain = [&](Value memDesc) { + Value cur = memDesc; + while (cur) { + Operation *defOp = cur.getDefiningOp(); + if (!defOp) + break; + if (auto localAlloc = dyn_cast(defOp)) { + setSqmmaAttrs(localAlloc.getOperation()); + propagateSqmmaAttrsFromLocalAlloc(localAlloc); + break; + } + if (auto indexOp = dyn_cast(defOp)) { + setSqmmaAttrs(indexOp.getOperation()); + cur = indexOp.getSrc(); + continue; + } + if (auto subslice = dyn_cast(defOp)) { + setSqmmaAttrs(subslice.getOperation()); + cur = subslice.getSrc(); + continue; + } + if (auto reinterpretOp = dyn_cast(defOp)) { + cur = reinterpretOp.getSrc(); + continue; + } + if (auto transOp = dyn_cast(defOp)) { + cur = transOp.getSrc(); + continue; + } + break; + } + }; + auto cgaLayout = ttg::getCGALayout(argType.getEncoding()); + auto sharedLayout = mmaEnc.composeSharedLayoutForOperand( + cgaLayout, opIdx, argType.getShape(), newOrder, + /*kWidth=*/0, argType.getElementType().getIntOrFloatBitWidth(), + /*needTrans=*/false); + auto allocShape = getSqmmaPaddedAllocShape(argType, newOrder); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(argType.getContext()); + auto memDescTy = + ttg::MemDescType::get(argType.getShape(), argType.getElementType(), + sharedLayout, sharedMemorySpace, + /*mutableMemory=*/true, allocShape); + + if (!forceFreshRestage) { + if (auto localLoad = arg.getDefiningOp()) { + auto srcMemDescTy = + dyn_cast(localLoad.getSrc().getType()); + auto samePhysicalLayout = [&](ttg::MemDescType srcTy) { + return triton::musa::areMemDescTypesLayoutEquivalent(srcTy, memDescTy); + }; + if (srcMemDescTy && samePhysicalLayout(srcMemDescTy)) { + propagateSqmmaAttrsToMemDescChain(localLoad.getSrc()); + if (srcMemDescTy == memDescTy) + return localLoad.getSrc(); + + rewriter.setInsertionPointAfterValue(localLoad.getSrc()); + Value adapted = ttg::MemDescReinterpretOp::create( + rewriter, localLoad.getLoc(), memDescTy, localLoad.getSrc()); + setSqmmaAttrs(adapted.getDefiningOp()); + return adapted; + } + } + } + if (descLoad) { + setSqmmaAttrs(descLoad.getOperation()); + } + + Value reusedMemDesc = + forceFreshRestage + ? Value() + : triton::musa::findReusableLocalAllocForSource(arg, memDescTy); + if (reusedMemDesc) + if (auto localAlloc = reusedMemDesc.getDefiningOp()) + setSqmmaAttrs(localAlloc.getOperation()); + + if (reusedMemDesc) + return reusedMemDesc; + + rewriter.setInsertionPointAfterValue(arg); + auto localAlloc = + ttg::LocalAllocOp::create(rewriter, arg.getLoc(), memDescTy, arg); + setSqmmaAttrs(localAlloc.getOperation()); + return localAlloc.getResult(); +} + +static std::optional +selectSqmmaConfig(unsigned m, unsigned n, unsigned k, unsigned numWarps, + Type elemTy, bool allowTF32) { + if (numWarps < 4 || (numWarps % 4) != 0) + return std::nullopt; + auto sqmmaEltType = toSqmmaOperandEltType(elemTy, allowTF32); + if (!sqmmaEltType) + return std::nullopt; + + auto candidateM = getSqmmaCandidateM(elemTy, allowTF32); + auto candidateN = getSqmmaCandidateN(elemTy, allowTF32); + auto candidateK = getSqmmaCandidateK(elemTy, allowTF32); + if (candidateM.empty() || candidateN.empty() || candidateK.empty()) + return std::nullopt; + + bool found = false; + SelectedConfig best; + unsigned bestInstCount = std::numeric_limits::max(); + unsigned bestVolume = 0; + unsigned bestRepM = std::numeric_limits::max(); + unsigned bestRepN = std::numeric_limits::max(); + + for (unsigned instM : candidateM) { + if (m < instM || (m % instM) != 0) + continue; + for (unsigned instN : candidateN) { + if (n < instN || (n % instN) != 0) + continue; + if (!triton::musa::isSupportedSqmmaInstrMN(*sqmmaEltType, instM, instN)) + continue; + for (unsigned instK : candidateK) { + if (k < instK || (k % instK) != 0) + continue; + if ((instM % 4) != 0) + continue; + if (isKnownBrokenSqmmaConfig(elemTy, allowTF32, {instM, instN, instK})) + continue; + + for (unsigned warpsM = 4; warpsM <= numWarps; warpsM *= 2) { + if (numWarps % warpsM != 0) + continue; + unsigned warpsN = numWarps / warpsM; + + unsigned squadsM = warpsM / 4; + unsigned tileM = instM * squadsM; + unsigned tileN = instN * warpsN; + if ((m % tileM) != 0 || (n % tileN) != 0) + continue; + + unsigned instCount = (m / tileM) * (n / tileN) * (k / instK); + unsigned repM = m / tileM; + unsigned repN = n / tileN; + unsigned volume = instM * instN * instK; + if (!found || instCount < bestInstCount || + (instCount == bestInstCount && + (volume > bestVolume || + (volume == bestVolume && + (repM < bestRepM || + (repM == bestRepM && repN < bestRepN)))))) { + found = true; + bestInstCount = instCount; + bestVolume = volume; + bestRepM = repM; + bestRepN = repN; + best.instrShape = {instM, instN, instK}; + best.warpsPerCTA = {warpsM, warpsN}; + } + } + } + } + } + + if (!found) + return std::nullopt; + return best; +} + +class BlockedToMUSAWmma : public RewritePattern { +public: + explicit BlockedToMUSAWmma(MLIRContext *context, int computeCapability) + : RewritePattern(tt::DotOp::getOperationName(), 2, context), + computeCapability(computeCapability) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (computeCapability != 31) + return failure(); + + auto dotOp = dyn_cast(op); + if (!dotOp) + return failure(); + auto oldRetType = cast(dotOp.getType()); + auto oldEncoding = oldRetType.getEncoding(); + if (!oldEncoding || !isa(oldEncoding)) + return failure(); + if (isa(oldEncoding)) + return failure(); + auto aTy = cast(dotOp.getA().getType()); + auto bTy = cast(dotOp.getB().getType()); + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + bool allowTF32 = dotOp.getInputPrecision() == tt::InputPrecision::TF32; + if (aElemTy != bElemTy) + return failure(); + if (!isSupportedWmmaOperandType(aElemTy, allowTF32)) + return failure(); + + auto shapePerCTA = ttg::getShapePerCTA(oldRetType); + if (shapePerCTA.size() != 2) + return failure(); + if (shapePerCTA[0] <= 0 || shapePerCTA[1] <= 0) + return failure(); + + if (aTy.getRank() < 2) + return failure(); + auto kDim = aTy.getShape().back(); + if (kDim <= 0) + return failure(); + + unsigned m = static_cast(shapePerCTA[0]); + unsigned n = static_cast(shapePerCTA[1]); + unsigned k = static_cast(kDim); + unsigned numWarps = ttg::lookupNumWarps(dotOp); + + auto config = selectWmmaConfig(m, n, k, numWarps, aElemTy, allowTF32); + if (!config) + return failure(); + + auto cgaLayout = ttg::getCGALayout(oldEncoding); + auto mmaEnc = ttg::MUSAWmmaEncodingAttr::get( + oldRetType.getContext(), /*versionMajor=*/3, /*versionMinor=*/1, + config->warpsPerCTA, cgaLayout, config->instrShape); + bool useFp32Carrier = computeCapability == 31 && + oldRetType.getElementType().isF16() && + aElemTy.isF16() && bElemTy.isF16(); + Type carrierElemTy = + useFp32Carrier ? rewriter.getF32Type() : oldRetType.getElementType(); + auto newRetType = + RankedTensorType::get(oldRetType.getShape(), carrierElemTy, mmaEnc); + + auto oldAcc = dotOp.getOperand(2); + Value acc = useFp32Carrier ? promoteDotOperand(rewriter, dotOp.getLoc(), + oldAcc, carrierElemTy) + : oldAcc; + bool accIsZero = isZeroConst(oldAcc); + Value newAcc; + if (accIsZero) { + auto zeroElem = rewriter.getZeroAttr(newRetType.getElementType()); + auto zeroTensor = DenseElementsAttr::get(newRetType, zeroElem); + newAcc = arith::ConstantOp::create(rewriter, oldAcc.getLoc(), newRetType, + zeroTensor); + } else { + newAcc = ttg::ConvertLayoutOp::create(rewriter, oldAcc.getLoc(), + newRetType, acc); + } + + auto newAEncoding = ttg::DotOperandEncodingAttr::get( + aTy.getContext(), 0, newRetType.getEncoding(), aElemTy); + auto newAType = + RankedTensorType::get(aTy.getShape(), aElemTy, newAEncoding); + auto newA = ttg::ConvertLayoutOp::create(rewriter, dotOp.getLoc(), newAType, + dotOp.getA()); + + auto newBEncoding = ttg::DotOperandEncodingAttr::get( + bTy.getContext(), 1, newRetType.getEncoding(), bElemTy); + auto newBType = + RankedTensorType::get(bTy.getShape(), bElemTy, newBEncoding); + auto newB = ttg::ConvertLayoutOp::create(rewriter, dotOp.getLoc(), newBType, + dotOp.getB()); + + auto wmmaEltType = toSqmmaOperandEltType(aElemTy, allowTF32); + if (!wmmaEltType) + return failure(); + Value useC = arith::ConstantIntOp::create(rewriter, dotOp.getLoc(), 1, 1); + auto newDot = triton::musa::WmmaDotOp::create( + rewriter, dotOp.getLoc(), newRetType, newA, newB, newAcc, useC, + static_cast(config->instrShape[0]), + static_cast(config->instrShape[1]), + static_cast(config->instrShape[2]), *wmmaEltType, *wmmaEltType, + triton::musa::inferWmmaFragmentLayout(dotOp.getA(), 0), + triton::musa::inferWmmaFragmentLayout(dotOp.getB(), 1), + static_cast(dotOp.getInputPrecision()), + /*maxNumImpreciseAcc=*/0); + newDot->setAttr(kDisableGenericDotPipelineAttr, rewriter.getBoolAttr(true)); + if (!useFp32Carrier) { + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot.getResult()); + return success(); + } + + auto blockedCarrierTy = oldRetType.cloneWith(std::nullopt, carrierElemTy); + Value blockedCarrier = ttg::ConvertLayoutOp::create( + rewriter, dotOp.getLoc(), blockedCarrierTy, newDot.getResult()); + Value truncated = arith::TruncFOp::create(rewriter, dotOp.getLoc(), + oldRetType, blockedCarrier); + rewriter.replaceOp(dotOp, truncated); + return success(); + } + +private: + int computeCapability; +}; + +class BlockedToMUSASqmma : public RewritePattern { +public: + explicit BlockedToMUSASqmma(MLIRContext *context, int computeCapability) + : RewritePattern(tt::DotOp::getOperationName(), 3, context), + computeCapability(computeCapability) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (computeCapability != 31) + return failure(); + + auto dotOp = dyn_cast(op); + if (!dotOp) + return failure(); + auto oldRetType = cast(dotOp.getType()); + auto oldEncoding = oldRetType.getEncoding(); + if (!oldEncoding || !isa(oldEncoding)) + return failure(); + if (isa(oldEncoding)) + return failure(); + if (oldRetType.getRank() != 2) + return failure(); + auto aTy = dyn_cast(dotOp.getA().getType()); + auto bTy = dyn_cast(dotOp.getB().getType()); + if (!aTy || !bTy) + return failure(); + if (aTy.getRank() != 2 || bTy.getRank() != 2) + return failure(); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + if (aElemTy != bElemTy) + return failure(); + + bool allowTF32 = dotOp.getInputPrecision() == tt::InputPrecision::TF32; + if (aElemTy.isF32()) { + if (!allowTF32) + return failure(); + } + if (!isSupportedSqmmaOperandType(aElemTy, allowTF32)) + return failure(); + + auto shapePerCTA = ttg::getShapePerCTA(oldRetType); + if (shapePerCTA.size() != 2) + return failure(); + if (shapePerCTA[0] <= 0 || shapePerCTA[1] <= 0) + return failure(); + + auto kDim = aTy.getShape().back(); + if (kDim <= 0) + return failure(); + + unsigned m = static_cast(shapePerCTA[0]); + unsigned n = static_cast(shapePerCTA[1]); + unsigned k = static_cast(kDim); + unsigned numWarps = ttg::lookupNumWarps(dotOp); + auto config = selectSqmmaConfig(m, n, k, numWarps, aElemTy, allowTF32); + if (!config) + return failure(); + + bool useFp32Carrier = computeCapability == 31 && + oldRetType.getElementType().isF16() && + aElemTy.isF16() && bElemTy.isF16(); + Type carrierElemTy = + useFp32Carrier ? rewriter.getF32Type() : oldRetType.getElementType(); + auto eltTypeC = toSqmmaEltType(carrierElemTy); + auto eltTypeA = toSqmmaOperandEltType(aElemTy, allowTF32); + auto eltTypeB = toSqmmaOperandEltType(bElemTy, allowTF32); + if (!eltTypeC || !eltTypeA || !eltTypeB) + return failure(); + if (!triton::musa::isSupportedSqmma( + *eltTypeA, *eltTypeB, *eltTypeC, config->instrShape[0], + config->instrShape[1], config->instrShape[2])) + return failure(); + + auto cgaLayout = ttg::getCGALayout(oldEncoding); + auto mmaEnc = ttg::MUSASqmmaEncodingAttr::get( + oldRetType.getContext(), /*versionMajor=*/3, /*versionMinor=*/1, + config->warpsPerCTA, cgaLayout, config->instrShape); + auto newRetType = + RankedTensorType::get(oldRetType.getShape(), carrierElemTy, mmaEnc); + + auto oldAcc = dotOp.getOperand(2); + Value acc = useFp32Carrier ? promoteDotOperand(rewriter, dotOp.getLoc(), + oldAcc, carrierElemTy) + : oldAcc; + bool accIsZero = isZeroConst(oldAcc); + Value newAcc; + if (accIsZero) { + auto zeroElem = rewriter.getZeroAttr(newRetType.getElementType()); + auto zeroTensor = DenseElementsAttr::get(newRetType, zeroElem); + newAcc = arith::ConstantOp::create(rewriter, oldAcc.getLoc(), newRetType, + zeroTensor); + } else { + newAcc = ttg::ConvertLayoutOp::create(rewriter, oldAcc.getLoc(), + newRetType, acc); + } + + SqmmaTransLoadKind transLoadKindA = classifySqmmaTransLoad(dotOp.getA()); + SqmmaTransLoadKind transLoadKindB = classifySqmmaTransLoad(dotOp.getB()); + bool allowTransposeA = transLoadKindA == SqmmaTransLoadKind::Descriptor || + (transLoadKindA == SqmmaTransLoadKind::PlainLoad && + shouldAllowSqmmaTranspose(aElemTy)); + bool allowTransposeB = transLoadKindB == SqmmaTransLoadKind::Descriptor || + (transLoadKindB == SqmmaTransLoadKind::PlainLoad && + shouldAllowSqmmaTranspose(bElemTy)); + Value newA = getSharedMemorySqmmaOperand(dotOp.getA(), rewriter, 0, mmaEnc, + allowTransposeA); + Value newB = getSharedMemorySqmmaOperand(dotOp.getB(), rewriter, 1, mmaEnc, + allowTransposeB); + if (!newA || !newB) + return failure(); + + auto accumulationContract = selectSqmmaAccumulationContract( + aElemTy, newRetType.getElementType(), m, n, k, accIsZero, + static_cast(dotOp.getMaxNumImpreciseAcc()), *config); + Value useC = arith::ConstantIntOp::create( + rewriter, dotOp.getLoc(), accumulationContract.useCOperand, 1); + auto newDot = triton::musa::SquadDotOp::create( + rewriter, dotOp.getLoc(), newRetType, newA, newB, newAcc, useC, + static_cast(config->instrShape[0]), + static_cast(config->instrShape[1]), + static_cast(config->instrShape[2]), *eltTypeC, *eltTypeA, + *eltTypeB, inferSqmmaLayout(newA), inferSqmmaLayout(newB), false, + accumulationContract.mode, + static_cast(dotOp.getInputPrecision()), + accumulationContract.mode == + triton::musa::SQMMAAccumulationMode::partial + ? static_cast(dotOp.getMaxNumImpreciseAcc()) + : 0); + newDot->setAttr(kDisableGenericDotPipelineAttr, rewriter.getBoolAttr(true)); + newDot->setAttr("isAsync", rewriter.getBoolAttr(false)); + if (!useFp32Carrier) { + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot.getResult()); + return success(); + } + + auto blockedCarrierTy = oldRetType.cloneWith(std::nullopt, carrierElemTy); + Value blockedCarrier = ttg::ConvertLayoutOp::create( + rewriter, dotOp.getLoc(), blockedCarrierTy, newDot.getResult()); + Value truncated = arith::TruncFOp::create(rewriter, dotOp.getLoc(), + oldRetType, blockedCarrier); + rewriter.replaceOp(dotOp, truncated); + return success(); + } + +private: + int computeCapability; +}; + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUACCELERATEMATMUL +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUAccelerateMatmulPass + : impl::TritonMUSAGPUAccelerateMatmulBase< + TritonMUSAGPUAccelerateMatmulPass> { + using Base::Base; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + int computeCapability = getMusaComputeCapability(mod); + if (computeCapability < 0) + return; + + bool disableSqmma = ::triton::tools::getBoolEnv("DISABLE_SQMMA"); + bool disableWmma = ::triton::tools::getBoolEnv("DISABLE_WMMA"); + + bool sqmmaCandidate = computeCapability >= 31 && !disableSqmma; + // Preserve the 3.6 fallback behavior: descriptor/TME modules may still + // fall back to WMMA when SQMMA predicate matching rejects a dot. + bool wmmaCandidate = computeCapability == 31 && !disableWmma; + + MLIRContext *context = &getContext(); + if (sqmmaCandidate || wmmaCandidate) { + RewritePatternSet patterns(context); + // Keep 3.2-aligned rewrite precedence: SQMMA first, then WMMA. + if (sqmmaCandidate) + patterns.add(context, computeCapability); + if (wmmaCandidate) + patterns.add(context, computeCapability); + + if (applyPatternsGreedily(mod, std::move(patterns)).failed()) + signalPassFailure(); + } + + promoteResidualFp8DotForFma(mod); + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt new file mode 100644 index 0000000000..07636b167b --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt @@ -0,0 +1,33 @@ +add_triton_library(TritonMUSAGPUTransforms + AccelerateMUSAMatmul.cpp + CanonicalizeSqmmaResultConversions.cpp + ConvertSqmmaToMTGPU.cpp + FinalizeBarriers.cpp + IssueBarrierInsertion.cpp + OptimizeDotOperands.cpp + MarkInplaceLoads.cpp + Pipeline.cpp + OptimizeAccumulatorInit.cpp + OptimizeDescriptorEncoding.cpp + OptimizeSqmmaAccumulatorLayout.cpp + SqmmaPipelineUtils.cpp + TMEPipelineUtils.cpp + TMELowering.cpp + + DEPENDS + TritonMUSAGPUTransformsIncGen + MTGPUTableGen + MTGPUTypesIncGen + TritonGPUIR + + LINK_LIBS PUBLIC + MTGPUIR + MUSAIR + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUIR +) + +target_include_directories(TritonMUSAGPUTransforms PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include) +target_include_directories(TritonMUSAGPUTransforms PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/../../include) diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CanonicalizeSqmmaResultConversions.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CanonicalizeSqmmaResultConversions.cpp new file mode 100644 index 0000000000..eb4d6b5b7a --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CanonicalizeSqmmaResultConversions.cpp @@ -0,0 +1,104 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +static bool preservesSqmmaOperandBoundary(arith::TruncFOp trunc) { + auto contract = triton::musa::recoverUniqueSqmmaConsumerContractFromTensor( + trunc.getResult()); + if (failed(contract)) + return true; + return contract->has_value(); +} + +static bool sinkTruncAfterMmaConvert(ttg::ConvertLayoutOp cvt, + RewriterBase &rewriter) { + auto srcTy = dyn_cast(cvt.getSrc().getType()); + auto dstTy = dyn_cast(cvt.getType()); + if (!srcTy || !dstTy) + return false; + if (!isa_and_nonnull(srcTy.getEncoding())) + return false; + if (!isa_and_nonnull(dstTy.getEncoding())) + return false; + if (!isa(srcTy.getElementType()) || + !isa(dstTy.getElementType())) + return false; + + SmallVector truncUsers; + for (Operation *user : cvt->getUsers()) { + auto trunc = dyn_cast(user); + if (!trunc) + return false; + truncUsers.push_back(trunc); + } + if (truncUsers.empty()) + return false; + + bool changed = false; + for (arith::TruncFOp trunc : truncUsers) { + if (preservesSqmmaOperandBoundary(trunc)) + continue; + auto truncDstTy = dyn_cast(trunc.getType()); + if (!truncDstTy) + continue; + if (!isa(truncDstTy.getElementType())) + continue; + auto mmaTruncTy = RankedTensorType::get( + srcTy.getShape(), truncDstTy.getElementType(), srcTy.getEncoding()); + rewriter.setInsertionPoint(trunc); + Value mmaTrunc = arith::TruncFOp::create(rewriter, trunc.getLoc(), + mmaTruncTy, cvt.getSrc()); + Value cvtAfterTrunc = ttg::ConvertLayoutOp::create(rewriter, trunc.getLoc(), + truncDstTy, mmaTrunc); + rewriter.replaceOp(trunc, cvtAfterTrunc); + changed = true; + } + + if (changed && cvt->use_empty()) + rewriter.eraseOp(cvt); + return changed; +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUCANONICALIZESQMMARESULTCONVERSIONS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUCanonicalizeSqmmaResultConversionsPass + : impl::TritonMUSAGPUCanonicalizeSqmmaResultConversionsBase< + TritonMUSAGPUCanonicalizeSqmmaResultConversionsPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + IRRewriter rewriter(&getContext()); + + for (tt::FuncOp func : mod.getOps()) { + bool changed = true; + while (changed) { + changed = false; + SmallVector cvtOps; + func.walk([&](ttg::ConvertLayoutOp op) { cvtOps.push_back(op); }); + for (ttg::ConvertLayoutOp cvt : cvtOps) { + if (!cvt->getBlock()) + continue; + if (sinkTruncAfterMmaConvert(cvt, rewriter)) + changed = true; + } + } + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/ConvertSqmmaToMTGPU.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/ConvertSqmmaToMTGPU.cpp new file mode 100644 index 0000000000..5be9f54656 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/ConvertSqmmaToMTGPU.cpp @@ -0,0 +1,339 @@ +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +namespace tt = mlir::triton; + +namespace { + +static RankedTensorType getSqmmaAccumulatorTensorType(Type type) { + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return RankedTensorType(); + return isa_and_nonnull(tensorTy.getEncoding()) + ? tensorTy + : RankedTensorType(); +} + +struct Candidate { + unsigned iterArgIdx; + triton::musa::SquadDotOp sqmma; + triton::mtgpu::SqmmaAccumulatorType carrierType; +}; + +static triton::musa::SquadDotWaitOp getCanonicalExternalFinalWait( + scf::ForOp forOp, const llvm::SmallDenseSet &candidateIdxs) { + auto wait = + dyn_cast_or_null(forOp->getNextNode()); + if (!wait || wait->getBlock() != forOp->getBlock()) + return {}; + + bool dependsOnCandidate = llvm::any_of(wait.getInputs(), [&](Value input) { + auto result = dyn_cast(input); + return result && result.getOwner() == forOp.getOperation() && + candidateIdxs.contains(result.getResultNumber()); + }); + return dependsOnCandidate ? wait : triton::musa::SquadDotWaitOp(); +} + +static Value unwrapYieldedSqmmaValue(Value value) { + auto result = dyn_cast(value); + if (!result) + return value; + auto wait = dyn_cast(result.getOwner()); + if (!wait) + return value; + unsigned idx = result.getResultNumber(); + return idx < wait.getInputs().size() ? wait.getInputs()[idx] : value; +} + +static triton::mtgpu::SQMMAEltType +convertEltType(triton::musa::SQMMAEltType type) { + return static_cast(static_cast(type)); +} + +static triton::mtgpu::SQMMALayout +convertLayout(triton::musa::SQMMALayout layout) { + return static_cast(static_cast(layout)); +} + +static triton::mtgpu::SQMMAAccumulationMode +convertAccumulationMode(triton::musa::SQMMAAccumulationMode mode) { + return static_cast( + static_cast(mode)); +} + +static std::optional getCandidateForIterArg(scf::ForOp forOp, + unsigned iterArgIdx) { + Value iterArg = forOp.getRegionIterArg(iterArgIdx); + auto tensorTy = getSqmmaAccumulatorTensorType(iterArg.getType()); + if (!tensorTy) + return std::nullopt; + + auto yieldOp = dyn_cast(forOp.getBody()->getTerminator()); + if (!yieldOp || iterArgIdx >= yieldOp.getNumOperands()) + return std::nullopt; + + Value yieldedValue = unwrapYieldedSqmmaValue(yieldOp.getOperand(iterArgIdx)); + auto sqmma = yieldedValue.getDefiningOp(); + if (!sqmma || sqmma.getC() != iterArg) + return std::nullopt; + + return Candidate{ + iterArgIdx, sqmma, + triton::mtgpu::SqmmaAccumulatorType::get(forOp.getContext(), tensorTy)}; +} + +static Value materializeTensorAccumulatorForUse( + Value original, Location loc, IRMapping &mapping, + DenseMap &tensorMaterializations, RewriterBase &rewriter) { + Value mapped = mapping.lookupOrDefault(original); + if (!mapped || mapped.getType() == original.getType()) + return mapped; + + auto originalTensorTy = getSqmmaAccumulatorTensorType(original.getType()); + if (!originalTensorTy || + !isa(mapped.getType())) + return mapped; + + auto it = tensorMaterializations.find(original); + if (it != tensorMaterializations.end()) + return it->second; + + Value unpacked = triton::mtgpu::UnpackSqmmaAccumulatorOp::create( + rewriter, loc, originalTensorTy, mapped); + tensorMaterializations[original] = unpacked; + return unpacked; +} + +static Operation *cloneSqmmaOp(triton::musa::SquadDotOp op, IRMapping &mapping, + DenseMap &tensorMaterializations, + RewriterBase &rewriter) { + auto lookupOrDefault = [&](Value value) -> Value { + return value ? mapping.lookupOrDefault(value) : Value(); + }; + auto materializeTensorOperand = [&](Value value) -> Value { + return materializeTensorAccumulatorForUse(value, op.getLoc(), mapping, + tensorMaterializations, rewriter); + }; + Value mappedA = materializeTensorOperand(op.getA()); + Value mappedB = materializeTensorOperand(op.getB()); + Value mappedC = lookupOrDefault(op.getC()); + Value mappedUseC = lookupOrDefault(op.getUseC()); + if (auto carrierTy = + dyn_cast(mappedC.getType())) { + auto newOp = triton::mtgpu::SqmmaOp::create( + rewriter, op.getLoc(), carrierTy, mappedA, mappedB, mappedC, mappedUseC, + op.getM(), op.getN(), op.getK(), convertEltType(op.getEltTypeC()), + convertEltType(op.getEltTypeA()), convertEltType(op.getEltTypeB()), + convertLayout(op.getLayoutA()), convertLayout(op.getLayoutB()), + op.getIsAsync(), convertAccumulationMode(op.getAccMode()), + op.getInputPrecision(), op.getMaxNumImpreciseAcc()); + newOp->setAttrs(op->getAttrs()); + return newOp; + } + + auto newOp = triton::musa::SquadDotOp::create( + rewriter, op.getLoc(), op.getResult().getType(), mappedA, mappedB, + mappedC, mappedUseC, op.getM(), op.getN(), op.getK(), op.getEltTypeC(), + op.getEltTypeA(), op.getEltTypeB(), op.getLayoutA(), op.getLayoutB(), + op.getIsAsync(), op.getAccMode(), op.getInputPrecision(), + op.getMaxNumImpreciseAcc()); + newOp->setAttrs(op->getAttrs()); + return newOp; +} + +static Operation *cloneSqmmaWaitOp(triton::musa::SquadDotWaitOp op, + IRMapping &mapping, RewriterBase &rewriter) { + SmallVector newInputs; + newInputs.reserve(op.getInputs().size()); + for (Value input : op.getInputs()) + newInputs.push_back(mapping.lookupOrDefault(input)); + auto newOp = + triton::mtgpu::SqmmaWaitOp::create(rewriter, op.getLoc(), newInputs); + newOp->setAttrs(op->getAttrs()); + return newOp; +} + +static bool convertLoopCarriedSqmmaAccumulator(scf::ForOp forOp, + RewriterBase &rewriter) { + SmallVector candidates; + for (unsigned idx = 0; idx < forOp.getNumRegionIterArgs(); ++idx) { + if (auto candidate = getCandidateForIterArg(forOp, idx)) + candidates.push_back(*candidate); + } + if (candidates.empty()) + return false; + + llvm::SmallDenseSet candidateIdxs; + for (const Candidate &candidate : candidates) + candidateIdxs.insert(candidate.iterArgIdx); + triton::musa::SquadDotWaitOp externalFinalWait = + getCanonicalExternalFinalWait(forOp, candidateIdxs); + + Location loc = forOp.getLoc(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(forOp); + + SmallVector initArgs(forOp.getInitArgs()); + for (const Candidate &candidate : candidates) { + initArgs[candidate.iterArgIdx] = + triton::mtgpu::PackSqmmaAccumulatorOp::create( + rewriter, loc, candidate.carrierType, + initArgs[candidate.iterArgIdx]); + } + + scf::ForOp newFor = + scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), initArgs); + + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newFor.getInductionVar()); + for (unsigned idx = 0; idx < forOp.getNumRegionIterArgs(); ++idx) + mapping.map(forOp.getRegionIterArg(idx), newFor.getRegionIterArg(idx)); + + Block &oldBody = forOp.getRegion().front(); + Block &newBody = newFor.getRegion().front(); + rewriter.setInsertionPointToStart(&newBody); + DenseMap tensorMaterializations; + + for (Operation &op : oldBody.without_terminator()) { + if (auto sqmma = dyn_cast(op)) { + Operation *newOp = + cloneSqmmaOp(sqmma, mapping, tensorMaterializations, rewriter); + mapping.map(sqmma.getResult(), newOp->getResult(0)); + continue; + } + if (auto wait = dyn_cast(op)) { + Operation *newOp = cloneSqmmaWaitOp(wait, mapping, rewriter); + for (auto [oldResult, newResult] : + llvm::zip_equal(wait.getResults(), newOp->getResults())) + mapping.map(oldResult, newResult); + continue; + } + + IRMapping opMapping(mapping); + for (Value operand : op.getOperands()) { + Value remapped = materializeTensorAccumulatorForUse( + operand, op.getLoc(), mapping, tensorMaterializations, rewriter); + if (remapped != mapping.lookupOrDefault(operand)) + opMapping.map(operand, remapped); + } + Operation *newOp = rewriter.clone(op, opMapping); + for (auto [oldResult, newResult] : + llvm::zip_equal(op.getResults(), newOp->getResults())) + mapping.map(oldResult, newResult); + } + + auto oldYield = cast(oldBody.getTerminator()); + SmallVector newYieldOperands; + newYieldOperands.reserve(oldYield.getNumOperands()); + for (Value operand : oldYield.getOperands()) + newYieldOperands.push_back(mapping.lookupOrDefault(operand)); + scf::YieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands); + + rewriter.setInsertionPointAfter(newFor); + DenseMap externalWaitUnpacks; + if (externalFinalWait) { + auto oldWait = externalFinalWait; + SmallVector newInputs; + newInputs.reserve(oldWait.getInputs().size()); + for (Value input : oldWait.getInputs()) { + Value newInput = input; + if (auto result = dyn_cast(input)) { + if (result.getOwner() == forOp.getOperation()) { + newInput = newFor.getResult(result.getResultNumber()); + } else { + Value remapped = mapping.lookupOrDefault(input); + if (remapped) + newInput = remapped; + } + } + newInputs.push_back(newInput); + } + + auto newWait = triton::mtgpu::SqmmaWaitOp::create( + rewriter, oldWait.getLoc(), newInputs); + newWait->setAttrs(oldWait->getAttrs()); + + for (unsigned idx = 0; idx < oldWait.getNumResults(); ++idx) { + Value oldResult = oldWait.getResult(idx); + Value newResult = newWait.getResult(idx); + Value replacement = newResult; + if (auto oldInput = dyn_cast(oldWait.getInputs()[idx])) { + if (oldInput.getOwner() == forOp.getOperation() && + candidateIdxs.contains(oldInput.getResultNumber())) { + auto it = externalWaitUnpacks.find(newResult); + if (it == externalWaitUnpacks.end()) { + replacement = triton::mtgpu::UnpackSqmmaAccumulatorOp::create( + rewriter, oldWait.getLoc(), oldResult.getType(), newResult); + externalWaitUnpacks[newResult] = replacement; + } else { + replacement = it->second; + } + } + } + rewriter.replaceAllUsesWith(oldResult, replacement); + } + + rewriter.eraseOp(oldWait); + rewriter.setInsertionPointAfter(newWait); + } + + SmallVector replacements; + replacements.reserve(forOp.getNumResults()); + for (unsigned idx = 0; idx < forOp.getNumResults(); ++idx) { + Value result = newFor.getResult(idx); + if (candidateIdxs.contains(idx)) { + auto tensorTy = cast(forOp.getResult(idx).getType()); + result = triton::mtgpu::UnpackSqmmaAccumulatorOp::create( + rewriter, loc, tensorTy, result); + } + replacements.push_back(result); + } + rewriter.replaceOp(forOp, replacements); + return true; +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUCONVERTSQMMATOMTGPU +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUConvertSqmmaToMTGPUPass + : impl::TritonMUSAGPUConvertSqmmaToMTGPUBase< + TritonMUSAGPUConvertSqmmaToMTGPUPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + IRRewriter rewriter(&getContext()); + + for (tt::FuncOp func : mod.getOps()) { + bool changed = true; + while (changed) { + changed = false; + SmallVector loops; + func.walk([&](scf::ForOp loop) { loops.push_back(loop); }); + for (scf::ForOp loop : loops) { + if (!loop->getBlock()) + continue; + if (convertLoopCarriedSqmmaAccumulator(loop, rewriter)) + changed = true; + } + } + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/FinalizeBarriers.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/FinalizeBarriers.cpp new file mode 100644 index 0000000000..0871b5050f --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/FinalizeBarriers.cpp @@ -0,0 +1,28 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/BarrierUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUFINALIZEBARRIERS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUFinalizeBarriersPass + : impl::TritonMUSAGPUFinalizeBarriersBase< + TritonMUSAGPUFinalizeBarriersPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + IRRewriter rewriter(&getContext()); + + for (tt::FuncOp func : mod.getOps()) + triton::musa::finalizeBarRecord(func, rewriter); + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/IssueBarrierInsertion.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/IssueBarrierInsertion.cpp new file mode 100644 index 0000000000..add930a9f1 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/IssueBarrierInsertion.cpp @@ -0,0 +1,129 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +namespace { + +static Value peelSqmmaIssueOperand(Value value) { + llvm::SmallPtrSet visited; + while (value) { + if (!visited.insert(value.getAsOpaquePointer()).second) + break; + + if (auto cvt = value.getDefiningOp()) { + value = cvt.getSrc(); + continue; + } + if (auto load = value.getDefiningOp()) { + value = load.getSrc(); + continue; + } + if (auto view = value.getDefiningOp()) { + value = view.getSrc(); + continue; + } + if (auto view = value.getDefiningOp()) { + value = view.getSrc(); + continue; + } + if (auto view = value.getDefiningOp()) { + value = view.getSrc(); + continue; + } + if (auto view = value.getDefiningOp()) { + value = view.getSrc(); + continue; + } + if (auto view = value.getDefiningOp()) { + value = view.getSrc(); + continue; + } + break; + } + return value; +} + +static bool isIssueBarrier(ttg::BarrierOp barrier) { + return barrier && barrier.hasLocal() && + barrier.getAddrSpace() != ttg::AddrSpace::Local; +} + +static bool +shouldInsertIssueBarrierBefore(triton::musa::AsyncTMECopyLocalToGlobalOp op) { + Operation *prev = op->getPrevNode(); + auto prevBarrier = dyn_cast_or_null(prev); + return !isIssueBarrier(prevBarrier); +} + +static bool shouldInsertIssueBarrierBefore(triton::musa::SquadDotOp op) { + Operation *prev = op->getPrevNode(); + auto prevBarrier = dyn_cast_or_null(prev); + if (isIssueBarrier(prevBarrier)) + return false; + + Value aMemDesc = peelSqmmaIssueOperand(op.getA()); + Value bMemDesc = peelSqmmaIssueOperand(op.getB()); + if (!aMemDesc || !bMemDesc) + return true; + if (!isa(aMemDesc.getType()) || + !isa(bMemDesc.getType())) + return true; + + return triton::musa::needsSqmmaIssueBarrier(aMemDesc, bMemDesc); +} + +static void insertIssueBarrierBefore(Operation *op, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + // MUSA lowers non-local TTG barriers to llvm.musa.barrier0. + ttg::BarrierOp::create(rewriter, op->getLoc(), ttg::AddrSpace::All); +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUISSUEBARRIERINSERTION +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUIssueBarrierInsertionPass + : impl::TritonMUSAGPUIssueBarrierInsertionBase< + TritonMUSAGPUIssueBarrierInsertionPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + IRRewriter rewriter(&getContext()); + SmallVector candidates; + + mod.walk([&](Operation *op) { + if (isa(op)) + candidates.push_back(op); + }); + + for (Operation *op : candidates) { + if (auto store = + dyn_cast(op)) { + if (shouldInsertIssueBarrierBefore(store)) + insertIssueBarrierBefore(op, rewriter); + continue; + } + + if (auto sqmma = dyn_cast(op)) { + if (shouldInsertIssueBarrierBefore(sqmma)) + insertIssueBarrierBefore(op, rewriter); + continue; + } + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/MarkInplaceLoads.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/MarkInplaceLoads.cpp new file mode 100644 index 0000000000..52fe2e5681 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/MarkInplaceLoads.cpp @@ -0,0 +1,110 @@ +#include "TritonMUSAGPUTransforms/Passes.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace tt = mlir::triton; + +#define GEN_PASS_DEF_TRITONMUSAGPUMARKINPLACELOADS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +inline constexpr llvm::StringLiteral kInplaceLoadAttr = + "musa.inplace_load_candidate"; + +static bool +areEquivalentValues(Value lhs, Value rhs, + llvm::DenseMap, bool> &cache); + +static bool +areEquivalentOps(Operation *lhs, Operation *rhs, + llvm::DenseMap, bool> &cache) { + if (lhs == rhs) + return true; + if (!lhs || !rhs) + return false; + + auto checkEquivalent = [&](Value left, Value right) -> LogicalResult { + return success(areEquivalentValues(left, right, cache)); + }; + + return OperationEquivalence::isEquivalentTo( + lhs, rhs, checkEquivalent, /*markEquivalent=*/nullptr, + OperationEquivalence::Flags::IgnoreLocations); +} + +static bool +areEquivalentValues(Value lhs, Value rhs, + llvm::DenseMap, bool> &cache) { + if (lhs == rhs) + return true; + if (!lhs || !rhs || lhs.getType() != rhs.getType()) + return false; + + auto key = std::make_pair(lhs, rhs); + auto reverseKey = std::make_pair(rhs, lhs); + if (auto it = cache.find(key); it != cache.end()) + return it->second; + if (auto it = cache.find(reverseKey); it != cache.end()) + return it->second; + + cache[key] = false; + cache[reverseKey] = false; + + if (auto lhsArg = dyn_cast(lhs)) { + auto rhsArg = dyn_cast(rhs); + bool equivalent = rhsArg && lhsArg.getOwner() == rhsArg.getOwner() && + lhsArg.getArgNumber() == rhsArg.getArgNumber(); + cache[key] = equivalent; + cache[reverseKey] = equivalent; + return equivalent; + } + + bool equivalent = + areEquivalentOps(lhs.getDefiningOp(), rhs.getDefiningOp(), cache); + cache[key] = equivalent; + cache[reverseKey] = equivalent; + return equivalent; +} + +static bool hasSameAddressStoreInFunc(tt::LoadOp loadOp, + ArrayRef storeOps) { + llvm::DenseMap, bool> cache; + Value loadPtr = loadOp.getPtr(); + for (tt::StoreOp storeOp : storeOps) { + if (areEquivalentValues(loadPtr, storeOp.getPtr(), cache)) + return true; + } + return false; +} + +struct TritonMUSAGPUMarkInplaceLoadsPass + : impl::TritonMUSAGPUMarkInplaceLoadsBase< + TritonMUSAGPUMarkInplaceLoadsPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + MLIRContext *ctx = &getContext(); + mod.walk([&](tt::FuncOp funcOp) { + llvm::SmallVector storeOps; + funcOp.walk([&](tt::StoreOp storeOp) { storeOps.push_back(storeOp); }); + if (storeOps.empty()) + return; + + funcOp.walk([&](tt::LoadOp loadOp) { + if (hasSameAddressStoreInFunc(loadOp, storeOps)) + loadOp->setAttr(kInplaceLoadAttr, UnitAttr::get(ctx)); + }); + }); + } +}; + +} // namespace +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeAccumulatorInit.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeAccumulatorInit.cpp new file mode 100644 index 0000000000..5c192e37b1 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeAccumulatorInit.cpp @@ -0,0 +1,212 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +static bool isMusaDotOp(Operation *op) { + return isa(op) && + isa(op); +} + +static bool supportsAccumulatorInitOptimization(Operation *op) { + if (auto sqmma = dyn_cast(op)) { + return !sqmma.needsPartialAccumulator(); + } + if (auto wmma = dyn_cast(op)) + return !wmma.needsPartialAccumulator(); + return false; +} + +static Value getAccumulatorValue(Operation *op) { + if (auto sqmma = dyn_cast(op)) + return sqmma.getC(); + if (auto wmma = dyn_cast(op)) + return wmma.getC(); + return {}; +} + +static Value getUseCValue(Operation *op) { + if (auto sqmma = dyn_cast(op)) + return sqmma.getUseC(); + if (auto wmma = dyn_cast(op)) + return wmma.getUseC(); + return {}; +} + +static void setUseCValue(Operation *op, Value useC) { + if (auto sqmma = dyn_cast(op)) { + sqmma.getUseCMutable().assign(useC); + return; + } + if (auto wmma = dyn_cast(op)) { + wmma.getUseCMutable().assign(useC); + return; + } + llvm_unreachable("unexpected MUSA dot op"); +} + +static bool isConstantZeroTensor(Value v) { + return matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat()); +} + +static std::optional> +findZeroInitOp(Value accUse, scf::ForOp forOp, bool &loopArgIsZero) { + Value v = accUse; + if (auto arg = dyn_cast(v)) { + assert(arg.getOwner() == forOp.getBody()); + if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) + loopArgIsZero = true; + v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + + auto defOp = v.getDefiningOp(); + if (!defOp) + return std::nullopt; + if (auto selOp = dyn_cast(defOp)) { + if (!selOp.getCondition().getType().isInteger(1)) + return std::nullopt; + if (isConstantZeroTensor(selOp.getTrueValue()) || + isConstantZeroTensor(selOp.getFalseValue())) { + return std::make_pair(selOp, 0); + } + } + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIndex = cast(v).getResultNumber(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) { + if (thenVal.getParentBlock()->getParentOp() == ifOp || + elseVal.getParentBlock()->getParentOp() == ifOp) { + return std::nullopt; + } + return std::make_pair(ifOp, resultIndex); + } + } + return std::nullopt; +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUOPTIMIZEACCUMULATORINIT +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUOptimizeAccumulatorInitPass + : impl::TritonMUSAGPUOptimizeAccumulatorInitBase< + TritonMUSAGPUOptimizeAccumulatorInitPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + SmallVector dotOps; + mod.walk([&](Operation *op) { + if (isMusaDotOp(op) && supportsAccumulatorInitOptimization(op)) + dotOps.push_back(op); + }); + + for (Operation *dotOp : dotOps) { + auto forOp = dyn_cast(dotOp->getParentOp()); + if (!forOp) + continue; + + Location loc = dotOp->getLoc(); + IRRewriter rewriter(forOp); + rewriter.setInsertionPoint(forOp); + + Value vTrue = + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true)); + Value vFalse = + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)); + + Value accUse = getAccumulatorValue(dotOp); + if (!accUse) + continue; + + Value useCValue = getUseCValue(dotOp); + if (useCValue) { + auto useCConst = tt::getBoolFromConstant(useCValue); + if (!useCConst || !*useCConst) + continue; + } + + if (isConstantZeroTensor(accUse)) { + setUseCValue(dotOp, vFalse); + continue; + } + + bool loopArgIsZero = false; + std::optional> zeroInitOp = + findZeroInitOp(accUse, forOp, loopArgIsZero); + if (!zeroInitOp && !loopArgIsZero) + continue; + + Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue; + forOp = addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue}); + loopArgFlagValue = + forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + + if (zeroInitOp) { + Value condition; + Value oldValue; + bool thenInitsToZero = false; + if (auto selOp = dyn_cast(zeroInitOp->first)) { + condition = selOp.getCondition(); + oldValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getFalseValue() + : selOp.getTrueValue(); + thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue()); + } else { + auto ifOp = cast(zeroInitOp->first); + unsigned resultIndex = zeroInitOp->second; + condition = ifOp.getCondition(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal; + thenInitsToZero = isConstantZeroTensor(thenVal); + } + + rewriter.setInsertionPoint(zeroInitOp->first); + bool zeroingBeforeDot = zeroInitOp->first->isBeforeInBlock(dotOp); + Value prevFlagValue = zeroingBeforeDot ? loopArgFlagValue : vTrue; + auto selectFlagOp = arith::SelectOp::create( + rewriter, loc, condition, thenInitsToZero ? vFalse : prevFlagValue, + thenInitsToZero ? prevFlagValue : vFalse); + setUseCValue(dotOp, zeroingBeforeDot ? selectFlagOp : loopArgFlagValue); + + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), + {zeroingBeforeDot ? vTrue : selectFlagOp}); + + if (auto selOp = dyn_cast(zeroInitOp->first)) { + rewriter.setInsertionPoint(selOp); + rewriter.replaceOp(selOp, oldValue); + } else { + auto ifOp = cast(zeroInitOp->first); + int resultIndex = zeroInitOp->second; + auto zeroingYield = + thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield(); + zeroingYield.setOperand(resultIndex, oldValue); + } + } else if (loopArgIsZero) { + setUseCValue(dotOp, loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), vTrue); + } + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeDescriptorEncoding.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeDescriptorEncoding.cpp new file mode 100644 index 0000000000..7ab6ea41ba --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeDescriptorEncoding.cpp @@ -0,0 +1,587 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MemDescUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PriorityWorklist.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace { + +inline constexpr llvm::StringLiteral kHostTensorDescABIArgsAttr = + "musa.host_tensordesc_abi_args"; + +struct UseInfo { + TypedValue descriptor; + Operation *use; + Attribute desiredSharedEncoding; + SmallVector shape; + SmallVector allocShape; + ttg::CGAEncodingAttr cgaLayout; +}; + +static bool isTMACompatibleEncoding(Attribute enc) { + if (isa_and_nonnull(enc)) + return true; + if (auto nvmma = dyn_cast_or_null(enc)) + return !nvmma.getTransposed() && !nvmma.getFp4Padded(); + return false; +} + +static Attribute findDirectLoadEncodingFromUsers(Operation *op) { + for (Operation *user : op->getUsers()) { + if (auto alloc = dyn_cast(user)) { + auto enc = alloc.getType().getEncoding(); + if (isTMACompatibleEncoding(enc)) + return enc; + } else if (auto store = dyn_cast(user)) { + auto dstTy = dyn_cast(store.getDst().getType()); + auto enc = dstTy ? dstTy.getEncoding() : Attribute(); + if (isTMACompatibleEncoding(enc)) + return enc; + } + } + return {}; +} + +static Attribute canonicalizeDesiredSharedEncoding( + Operation *op, TypedValue desc, Attribute encoding, + ArrayRef shape, unsigned numCTAs) { + if (!encoding) + return {}; + auto blockTy = desc.getType().getSignlessBlockType(); + auto usageShape = shape.empty() + ? SmallVector(blockTy.getShape().begin(), + blockTy.getShape().end()) + : SmallVector(shape.begin(), shape.end()); + auto usageTy = RankedTensorType::get(usageShape, blockTy.getElementType()); + auto canonical = + triton::musa::tryMapTMECompatibleSharedEncodingToCanonicalSwizzled( + op, usageTy, encoding, usageShape, numCTAs); + return canonical ? Attribute(*canonical) : Attribute(); +} + +static Attribute findDescriptorLoadEncoding(tt::DescriptorLoadOp loadOp) { + auto landingTy = + triton::musa::getUniqueCanonicalLandingRootMemDescType(loadOp); + if (!landingTy) + return {}; + Attribute enc = landingTy->getEncoding(); + if (isTMACompatibleEncoding(enc)) + return enc; + return {}; +} + +static SmallVector +findDescriptorLoadAllocShape(tt::DescriptorLoadOp loadOp) { + auto landingTy = + triton::musa::getUniqueCanonicalLandingRootMemDescType(loadOp); + if (!landingTy) + return {}; + auto allocShape = landingTy->getAllocShape(); + if (allocShape.empty()) + return {}; + auto rank = loadOp.getDesc().getType().getBlockType().getRank(); + SmallVector result(rank, 1); + assert(allocShape.size() <= static_cast(rank)); + auto rankDiff = rank - static_cast(allocShape.size()); + std::copy(allocShape.begin(), allocShape.end(), result.begin() + rankDiff); + return result; +} + +static SmallVector expandToRank(ArrayRef shape, int rank) { + SmallVector result(rank, 1); + assert(shape.size() <= static_cast(rank)); + auto rankDiff = rank - static_cast(shape.size()); + std::copy(shape.begin(), shape.end(), result.begin() + rankDiff); + return result; +} + +static std::optional getUseInfo(Operation *op, unsigned numCTAs) { + UseInfo info; + info.use = op; + if (auto load = dyn_cast(op)) { + info.descriptor = load.getDesc(); + auto shape = load.getResult().getType().getShape(); + auto rank = load.getDesc().getType().getBlockType().getRank(); + info.shape = expandToRank(shape, rank); + auto rawEncoding = findDescriptorLoadEncoding(load); + if (!rawEncoding) + rawEncoding = load.getDesc().getType().getBlockType().getEncoding(); + info.desiredSharedEncoding = canonicalizeDesiredSharedEncoding( + load, info.descriptor, rawEncoding, info.shape, numCTAs); + info.allocShape = findDescriptorLoadAllocShape(load); + auto encoding = + info.desiredSharedEncoding ? info.desiredSharedEncoding : rawEncoding; + info.cgaLayout = + encoding ? ttg::getCGALayout(encoding) : ttg::CGAEncodingAttr(); + return info; + } + if (auto gather = dyn_cast(op)) { + info.descriptor = gather.getDesc(); + auto shape = gather.getResult().getType().getShape(); + auto rank = gather.getDesc().getType().getBlockType().getRank(); + info.shape = expandToRank(shape, rank); + auto rawEncoding = findDirectLoadEncodingFromUsers(op); + if (!rawEncoding) + rawEncoding = gather.getDesc().getType().getBlockType().getEncoding(); + info.desiredSharedEncoding = canonicalizeDesiredSharedEncoding( + gather, info.descriptor, rawEncoding, info.shape, numCTAs); + auto encoding = + info.desiredSharedEncoding ? info.desiredSharedEncoding : rawEncoding; + info.cgaLayout = + encoding ? ttg::getCGALayout(encoding) : ttg::CGAEncodingAttr(); + return info; + } + if (auto store = dyn_cast(op)) { + info.descriptor = store.getDesc(); + auto shape = store.getSrc().getType().getShape(); + auto rank = store.getDesc().getType().getBlockType().getRank(); + info.shape = expandToRank(shape, rank); + auto rawEncoding = store.getDesc().getType().getBlockType().getEncoding(); + info.desiredSharedEncoding = + canonicalizeDesiredSharedEncoding(store.getOperation(), info.descriptor, + rawEncoding, info.shape, numCTAs); + auto encoding = + info.desiredSharedEncoding ? info.desiredSharedEncoding : rawEncoding; + info.cgaLayout = + encoding ? ttg::getCGALayout(encoding) : ttg::CGAEncodingAttr(); + return info; + } + return std::nullopt; +} + +struct EncodingInfo { + Attribute desiredEncoding; + ttg::CGAEncodingAttr cgaLayout; + SmallVector shape; + SmallVector allocShape; + bool forcedToDefault = false; + + bool operator==(const EncodingInfo &other) const { + return desiredEncoding == other.desiredEncoding && + cgaLayout == other.cgaLayout && + forcedToDefault == other.forcedToDefault && shape == other.shape && + allocShape == other.allocShape; + } +}; + +} // namespace + +template <> struct std::hash { + size_t operator()(const EncodingInfo &einfo) const { + return llvm::hash_combine(einfo.desiredEncoding, einfo.cgaLayout, + einfo.forcedToDefault, + llvm::ArrayRef(einfo.shape), + llvm::ArrayRef(einfo.allocShape)); + } +}; + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUOPTIMIZEDESCRIPTORENCODING +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +static const EncodingInfo *internEncoding(std::unordered_set &set, + EncodingInfo info) { + return &*set.insert(std::move(info)).first; +} + +static EncodingInfo combineEncodings(const EncodingInfo &lhs, + const EncodingInfo &rhs, unsigned rank) { + EncodingInfo result; + result.forcedToDefault = lhs.forcedToDefault || rhs.forcedToDefault; + if (result.forcedToDefault) + return result; + + if (lhs.shape.empty() || lhs.shape == rhs.shape) + result.shape = rhs.shape; + else if (rhs.shape.empty()) + result.shape = lhs.shape; + else { + assert(lhs.shape.size() == rhs.shape.size()); + result.shape.reserve(lhs.shape.size()); + for (auto [lhsDim, rhsDim] : llvm::zip_equal(lhs.shape, rhs.shape)) + result.shape.push_back(std::min(lhsDim, rhsDim)); + } + + if (lhs.allocShape.empty() || lhs.allocShape == rhs.allocShape) + result.allocShape = rhs.allocShape; + else if (rhs.allocShape.empty()) + result.allocShape = lhs.allocShape; + else + result.forcedToDefault = true; + + llvm::SetVector cgaLayouts; + if (lhs.cgaLayout) + cgaLayouts.insert(lhs.cgaLayout); + if (rhs.cgaLayout) + cgaLayouts.insert(rhs.cgaLayout); + + auto getDefaultLayout = [&](ttg::CGAEncodingAttr encoding) { + auto *ctx = encoding.getContext(); + auto kBlock = StringAttr::get(ctx, "block"); + auto dims = triton::standardOutDimNames(ctx, rank); + auto numCTAs = encoding.getLinearLayout().getInDimSize(kBlock); + triton::LinearLayout llDefault; + for (unsigned i = 0; i < rank - 1; ++i) + llDefault *= triton::LinearLayout::identity1D(1, kBlock, dims[i]); + llDefault *= triton::LinearLayout::identity1D(numCTAs, kBlock, dims.back()); + return ttg::CGAEncodingAttr::get(ctx, llDefault); + }; + + switch (cgaLayouts.size()) { + case 2: + result.cgaLayout = getDefaultLayout(lhs.cgaLayout); + break; + case 1: + result.cgaLayout = cgaLayouts[0]; + break; + default: + break; + } + + llvm::SetVector desiredEncodings; + if (lhs.desiredEncoding) + desiredEncodings.insert(lhs.desiredEncoding); + if (rhs.desiredEncoding) + desiredEncodings.insert(rhs.desiredEncoding); + + switch (desiredEncodings.size()) { + case 2: + result.forcedToDefault = true; + break; + case 1: + result.desiredEncoding = desiredEncodings[0]; + break; + default: + break; + } + return result; +} + +static tt::TensorDescType +getTensorDescTypeWithEncoding(Operation *op, RankedTensorType blockTy, + Attribute encoding) { + auto sharedEnc = cast(encoding); + auto updatedEncoding = ttng::updateEncodingForShape(op, sharedEnc, blockTy); + return tt::TensorDescType::get(blockTy.getContext(), + blockTy.cloneWithEncoding(updatedEncoding)); +} + +static void updateFunctionType(tt::FuncOp func) { + SmallVector argTys(func.getBody().front().getArgumentTypes()); + SmallVector resultTys(func.getResultTypes()); + func.setFunctionType(FunctionType::get(func.getContext(), argTys, resultTys)); +} + +static void assignMemoryLayouts(tt::FuncOp func) { + std::unordered_set encodings; + llvm::MapVector, const EncodingInfo *> + valueToEncodingInfo; + llvm::PriorityWorklist> worklist; + unsigned numCTAs = + std::max(1u, static_cast(ttg::lookupNumCTAs(func))); + + auto updateEncoding = [&](ArrayRef descValues, EncodingInfo info) { + for (Value value : descValues) { + auto typedVal = cast>(value); + auto it = valueToEncodingInfo.find(typedVal); + if (it != valueToEncodingInfo.end()) { + info = combineEncodings(*it->second, info, + typedVal.getType().getBlockType().getRank()); + } + } + + auto *einfo = internEncoding(encodings, std::move(info)); + for (Value value : descValues) { + auto typedVal = cast>(value); + auto res = valueToEncodingInfo.try_emplace(typedVal, einfo); + if (res.second) { + worklist.insert(typedVal); + } else if (res.first->second != einfo) { + res.first->second = einfo; + worklist.insert(typedVal); + } + } + }; + + bool isKernel = triton::isKernel(func); + for (auto blockArg : func.getBlocks().front().getArguments()) { + if (auto desc = dyn_cast>(blockArg)) { + updateEncoding( + {desc}, EncodingInfo{{}, {}, {}, {}, /*forcedToDefault=*/!isKernel}); + } + } + + func.walk([&](Operation *op) { + if (auto info = getUseInfo(op, numCTAs)) { + updateEncoding(info->descriptor, + EncodingInfo{info->desiredSharedEncoding, info->cgaLayout, + info->shape, info->allocShape}); + return; + } + + bool forcedToDefault = + isa(op); + auto *einfo = internEncoding(encodings, + EncodingInfo{{}, {}, {}, {}, forcedToDefault}); + + auto seedEncoding = [&](Value value) { + auto typedVal = cast>(value); + valueToEncodingInfo.try_emplace(typedVal, einfo); + if (forcedToDefault) + worklist.insert(typedVal); + }; + + for (Value result : op->getResults()) { + if (auto desc = dyn_cast>(result)) + seedEncoding(desc); + } + for (Value operand : op->getOperands()) { + if (auto desc = dyn_cast>(operand)) + seedEncoding(desc); + } + }); + + while (!worklist.empty()) { + auto desc = worklist.pop_back_val(); + + for (OpOperand &use : desc.getUses()) { + Operation *op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + updateEncoding(triton::getTiedArgs(op, use.getOperandNumber() - offset), + EncodingInfo{}); + } else if (isa(op)) { + updateEncoding( + triton::getTiedArgs(op->getParentOp(), use.getOperandNumber()), + EncodingInfo{}); + } + } + + if (auto opResult = dyn_cast(desc)) { + Operation *definingOp = opResult.getOwner(); + if (isa(definingOp)) + updateEncoding( + triton::getTiedArgs(definingOp, opResult.getResultNumber()), + EncodingInfo{}); + } else if (auto blockArg = dyn_cast(desc)) { + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + updateEncoding( + triton::getTiedArgs(parentOp, blockArg.getArgNumber() - offset), + EncodingInfo{}); + } + } + } + + for (auto &[desc, einfo] : valueToEncodingInfo) { + auto existingTy = desc.getType().getBlockType(); + auto preferredCGA = einfo->cgaLayout; + if (!preferredCGA && existingTy.getEncoding()) + preferredCGA = ttg::getCGALayout(existingTy.getEncoding()); + auto newEncoding = + triton::musa::normalizeTMECompatibleSharedEncodingOrDefault( + desc.getDefiningOp(), existingTy, einfo->desiredEncoding, + preferredCGA, einfo->shape, einfo->allocShape, numCTAs); + desc.setType(getTensorDescTypeWithEncoding(desc.getDefiningOp(), existingTy, + newEncoding)); + } + + SmallVector argTys(func.getBody().front().getArgumentTypes()); + SmallVector resultTys(func.getResultTypes()); + for (auto [i, resultTy] : llvm::enumerate(resultTys)) { + auto descTy = dyn_cast(resultTy); + if (!descTy) + continue; + auto encoding = triton::musa::getDefaultTMECompatibleSharedEncoding( + descTy.getBlockType(), {}, {}, numCTAs); + resultTys[i] = + getTensorDescTypeWithEncoding(nullptr, descTy.getBlockType(), encoding); + } + func.setFunctionType(FunctionType::get(func.getContext(), argTys, resultTys)); +} + +static bool matchesExpandedTensorDescABI(Block &entry, unsigned descArgIdx, + unsigned rank) { + unsigned suffixCount = rank * 2; + if (descArgIdx + suffixCount >= entry.getNumArguments()) + return false; + for (unsigned i = 0; i < rank; ++i) { + Type argTy = entry.getArgument(descArgIdx + 1 + i).getType(); + if (!argTy.isSignlessInteger(32)) + return false; + } + for (unsigned i = 0; i < rank; ++i) { + Type argTy = entry.getArgument(descArgIdx + 1 + rank + i).getType(); + if (!argTy.isSignlessInteger(64)) + return false; + } + return true; +} + +static void compactUnusedHostTensorDescABI(tt::FuncOp func) { + if (!triton::isKernel(func)) + return; + + Block &entry = func.getBody().front(); + SmallVector descABIArgs; + llvm::BitVector eraseMask(entry.getNumArguments()); + auto emptyAttr = DictionaryAttr::get(func.getContext()); + SmallVector allArgAttrs; + allArgAttrs.reserve(entry.getNumArguments()); + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + DictionaryAttr attr = func.getArgAttrDict(i); + allArgAttrs.push_back(attr ? attr : emptyAttr); + } + + for (unsigned i = 0; i < entry.getNumArguments();) { + auto descTy = dyn_cast(entry.getArgument(i).getType()); + if (!descTy) { + ++i; + continue; + } + + unsigned rank = std::max(1, descTy.getBlockType().getRank()); + unsigned abiArgs = 1; + if (matchesExpandedTensorDescABI(entry, i, rank)) { + abiArgs = 1 + 2 * rank; + bool suffixUnused = true; + for (unsigned j = 0; j < 2 * rank; ++j) { + if (!entry.getArgument(i + 1 + j).use_empty()) { + suffixUnused = false; + break; + } + } + if (suffixUnused) { + abiArgs = 1; + for (unsigned j = 0; j < 2 * rank; ++j) + eraseMask.set(i + 1 + j); + } + } + descABIArgs.push_back(abiArgs); + i += abiArgs == 1 ? 1 : abiArgs; + } + + if (eraseMask.any()) { + entry.eraseArguments(eraseMask); + updateFunctionType(func); + if (!allArgAttrs.empty()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(entry.getNumArguments()); + for (auto [i, attr] : llvm::enumerate(allArgAttrs)) { + if (!eraseMask.test(i)) + newArgAttrs.push_back(attr); + } + func.setAllArgAttrs(newArgAttrs); + } + } + + unsigned descIdx = 0; + for (auto [argIdx, arg] : llvm::enumerate(entry.getArguments())) { + if (!isa(arg.getType())) + continue; + func.setArgAttr(argIdx, kHostTensorDescABIArgsAttr, + IntegerAttr::get(IntegerType::get(func.getContext(), 32), + descABIArgs[descIdx++])); + } +} + +// Rewrite: +// convert_layout(fp_to_fp(x)) #dot_operand +// -> +// fp_to_fp(convert_layout(x) #dot_operand) +// +// For descriptor-driven fp8/f16 dot paths this keeps conversion adjacent to +// descriptor loads and avoids a later blocked->dot shmem staging round-trip. +class HoistFpToFpAcrossDotOperandConvert + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttg::ConvertLayoutOp cvt, + PatternRewriter &rewriter) const override { + auto dstTy = dyn_cast(cvt.getType()); + if (!dstTy) + return failure(); + if (!isa_and_nonnull(dstTy.getEncoding())) + return failure(); + + auto fpToFp = cvt.getSrc().getDefiningOp(); + if (!fpToFp) + return failure(); + + auto midTy = dyn_cast(fpToFp.getType()); + auto srcTy = dyn_cast(fpToFp.getSrc().getType()); + if (!midTy || !srcTy) + return failure(); + if (midTy.getShape() != dstTy.getShape() || + srcTy.getShape() != dstTy.getShape()) + return failure(); + if (midTy.getElementType() != dstTy.getElementType() || + srcTy.getElementType() == dstTy.getElementType()) + return failure(); + + auto newCvtTy = RankedTensorType::get( + dstTy.getShape(), srcTy.getElementType(), dstTy.getEncoding()); + rewriter.setInsertionPoint(cvt); + Value newCvt = ttg::ConvertLayoutOp::create(rewriter, cvt.getLoc(), + newCvtTy, fpToFp.getSrc()); + Value newFpToFp = + tt::FpToFpOp::create(rewriter, fpToFp.getLoc(), dstTy, newCvt); + rewriter.replaceOp(cvt, newFpToFp); + return success(); + } +}; + +} // namespace + +struct TritonMUSAGPUOptimizeDescriptorEncodingPass + : impl::TritonMUSAGPUOptimizeDescriptorEncodingBase< + TritonMUSAGPUOptimizeDescriptorEncodingPass> { + using Base::Base; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + for (Operation &op : mod.getBodyRegion().front()) { + auto func = dyn_cast(&op); + if (!func) + continue; + assignMemoryLayouts(func); + compactUnusedHostTensorDescABI(func); + } + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsGreedily(mod, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeDotOperands.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeDotOperands.cpp new file mode 100644 index 0000000000..349687f0bd --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeDotOperands.cpp @@ -0,0 +1,237 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUOPTIMIZEDOTOPERANDS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +static bool isDescriptorTensorViewChain(Value value) { + while (value) { + if (value.getDefiningOp()) + return true; + if (auto transOp = value.getDefiningOp()) { + value = transOp.getSrc(); + continue; + } + if (auto reshapeOp = value.getDefiningOp()) { + value = reshapeOp.getSrc(); + continue; + } + return false; + } + return false; +} + +static SmallVector +invertTrailingPermutation(ArrayRef allocShape, + ArrayRef order) { + SmallVector result; + auto rank = static_cast(order.size()); + if (allocShape.size() < rank) + return result; + result.assign(allocShape.begin(), allocShape.end() - rank); + + SmallVector tail(allocShape.end() - rank, allocShape.end()); + SmallVector inverse(rank); + for (auto [idx, permutedIdx] : llvm::enumerate(order)) + inverse[permutedIdx] = tail[idx]; + result.append(inverse.begin(), inverse.end()); + return result; +} + +class SwizzleShmemConvert : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttg::ConvertLayoutOp cvtOp, + PatternRewriter &rewriter) const override { + if (!cvtOp->hasOneUse() || !isa(cvtOp->use_begin()->getOwner())) + return failure(); + auto trans = cvtOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef{1, 0} || + !trans->hasOneUse()) + return failure(); + + RankedTensorType srcTy = trans.getSrc().getType(); + if (auto srcCvt = trans.getSrc().getDefiningOp()) + srcTy = srcCvt.getSrc().getType(); + + RankedTensorType sharedLoadTy = cvtOp.getType(); + auto cvtEncoding = + dyn_cast(sharedLoadTy.getEncoding()); + if (!cvtEncoding) + return failure(); + + auto *ctx = getContext(); + auto oldCGALayout = ttg::getCGALayout(srcTy.getEncoding()); + auto newLl = + transposeLinearLayout(oldCGALayout.getLinearLayout(), trans.getOrder()); + auto newCGALayout = ttg::CGAEncodingAttr::get(ctx, std::move(newLl)); + auto newInnerCvtEnc = triton::musa::composeMusaOperandSharedLayout( + cvtEncoding, srcTy.getShape(), + /*order=*/ttg::getOrderForMemory(srcTy), newCGALayout, + srcTy.getElementType(), + /*needTrans=*/true); + if (!newInnerCvtEnc) + return failure(); + + rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(ctx); + auto alloc = ttg::LocalAllocOp::create( + rewriter, trans.getLoc(), + ttg::MemDescType::get(srcTy.getShape(), srcTy.getElementType(), + *newInnerCvtEnc, sharedMemorySpace), + trans.getSrc()); + auto newTrans = ttg::MemDescTransOp::create(rewriter, trans.getLoc(), alloc, + ArrayRef({1, 0})); + auto localLoadOp = ttg::LocalLoadOp::create(rewriter, trans.getLoc(), + sharedLoadTy, newTrans); + rewriter.modifyOpInPlace(cvtOp, [&]() { + cvtOp.getSrcMutable().assign(localLoadOp.getResult()); + }); + return success(); + } +}; + +class NormalizeDescriptorTransLocalAlloc + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttg::LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp.getSrc()) + return failure(); + auto transOp = allocOp.getSrc().getDefiningOp(); + if (!transOp || !isDescriptorTensorViewChain(transOp.getSrc())) + return failure(); + + auto allocTy = dyn_cast(allocOp.getType()); + auto srcTy = dyn_cast(transOp.getSrc().getType()); + if (!allocTy || !srcTy || !allocTy.getEncoding()) + return failure(); + + Dialect &dialect = allocTy.getEncoding().getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (!inferLayoutInterface) + return failure(); + + Attribute sourceEncoding; + if (failed(inferLayoutInterface->inferTransOpEncoding( + allocTy.getEncoding(), srcTy.getShape(), transOp.getOrder(), + sourceEncoding, allocOp.getLoc()))) { + return failure(); + } + + SmallVector sourceAllocShape = + invertTrailingPermutation(allocTy.getAllocShape(), transOp.getOrder()); + if (sourceAllocShape.empty()) + return failure(); + + auto sourceTy = ttg::MemDescType::get( + srcTy.getShape(), srcTy.getElementType(), sourceEncoding, + allocTy.getMemorySpace(), allocTy.getMutableMemory(), sourceAllocShape); + + auto newAlloc = ttg::LocalAllocOp::create(rewriter, allocOp.getLoc(), + sourceTy, transOp.getSrc()); + triton::musa::copySqmmaAttrs(allocOp.getOperation(), + newAlloc.getOperation()); + + Value transposed = ttg::MemDescTransOp::create( + rewriter, transOp.getLoc(), newAlloc, transOp.getOrder()); + triton::musa::copySqmmaAttrs(allocOp.getOperation(), + transposed.getDefiningOp()); + rewriter.replaceOp(allocOp, transposed); + return success(); + } +}; + +class NormalizeDescriptorReshapeLocalAlloc + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttg::LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp.getSrc()) + return failure(); + auto reshapeOp = allocOp.getSrc().getDefiningOp(); + if (!reshapeOp || !isDescriptorTensorViewChain(reshapeOp.getSrc())) + return failure(); + + auto allocTy = dyn_cast(allocOp.getType()); + auto srcTy = dyn_cast(reshapeOp.getSrc().getType()); + if (!allocTy || !srcTy) + return failure(); + + ttg::MemDescType sourceTy; + if (failed(ttg::MemDescReshapeOp::inferReturnTypes( + getContext(), allocOp.getLoc(), allocTy, srcTy.getShape(), + sourceTy))) { + return failure(); + } + + auto newAlloc = ttg::LocalAllocOp::create(rewriter, allocOp.getLoc(), + sourceTy, reshapeOp.getSrc()); + triton::musa::copySqmmaAttrs(allocOp.getOperation(), + newAlloc.getOperation()); + + Value reshaped = ttg::MemDescReshapeOp::create( + rewriter, reshapeOp.getLoc(), newAlloc, allocTy.getShape()); + if (reshaped.getType() != allocOp.getType()) { + reshaped.getDefiningOp()->erase(); + newAlloc.erase(); + return failure(); + } + triton::musa::copySqmmaAttrs(allocOp.getOperation(), + reshaped.getDefiningOp()); + rewriter.replaceOp(allocOp, reshaped); + return success(); + } +}; + +struct TritonMUSAGPUOptimizeDotOperandsPass + : impl::TritonMUSAGPUOptimizeDotOperandsBase< + TritonMUSAGPUOptimizeDotOperandsPass> { + using Base::Base; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + OpPassManager pm; + pm.addPass(mlir::createCanonicalizerPass()); + if (failed(runPipeline(pm, mod))) + return signalPassFailure(); + + RewritePatternSet patterns(context); + patterns.add(context); + ttg::ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsGreedily(mod, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeSqmmaAccumulatorLayout.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeSqmmaAccumulatorLayout.cpp new file mode 100644 index 0000000000..d8fe7179c4 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/OptimizeSqmmaAccumulatorLayout.cpp @@ -0,0 +1,157 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/BitVector.h" + +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +static Value unwrapSqmmaWaitResult(Value value) { + while (auto result = dyn_cast(value)) { + auto wait = dyn_cast(result.getOwner()); + if (!wait) + break; + unsigned idx = result.getResultNumber(); + if (idx >= wait.getInputs().size()) + break; + value = wait.getInputs()[idx]; + } + return value; +} + +static bool isSqmmaAccumulatorLoopArg(Value loopArg) { + for (Operation *user : loopArg.getUsers()) { + auto sqmma = dyn_cast(user); + if (!sqmma) + continue; + if (sqmma.getC() == loopArg) + return true; + } + return false; +} + +static bool sinkLoopCarriedSqmmaAccumulatorConvert(scf::ForOp &forOp, + RewriterBase &rewriter) { + auto yieldOp = dyn_cast(forOp.getBody()->getTerminator()); + if (!yieldOp) + return false; + + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + if (yieldOp.getNumOperands() != numIterArgs) + return false; + + struct Candidate { + unsigned blockedIdx; + unsigned mmaIdx; + Operation *waitToCleanup = nullptr; + }; + + SmallVector candidates; + llvm::SmallBitVector blockedIdxSeen(numIterArgs, false); + for (unsigned blockedIdx = 0; blockedIdx < numIterArgs; ++blockedIdx) { + auto cvt = + yieldOp.getOperand(blockedIdx).getDefiningOp(); + if (!cvt || cvt->getBlock() != forOp.getBody()) + continue; + + Value mmaValue = unwrapSqmmaWaitResult(cvt.getSrc()); + std::optional mmaIdx; + for (unsigned idx = 0; idx < numIterArgs; ++idx) { + if (unwrapSqmmaWaitResult(yieldOp.getOperand(idx)) == mmaValue) { + mmaIdx = idx; + break; + } + } + if (!mmaIdx || *mmaIdx == blockedIdx) + continue; + if (blockedIdxSeen.test(blockedIdx)) + continue; + if (forOp.getResult(blockedIdx).getType() != cvt.getType()) + continue; + if (forOp.getResult(*mmaIdx).getType() != mmaValue.getType()) + continue; + if (!isSqmmaAccumulatorLoopArg(forOp.getRegionIterArg(*mmaIdx))) + continue; + + blockedIdxSeen.set(blockedIdx); + Operation *waitToCleanup = nullptr; + if (auto waitResult = dyn_cast(cvt.getSrc())) + waitToCleanup = + dyn_cast(waitResult.getOwner()) + .getOperation(); + + candidates.push_back({blockedIdx, *mmaIdx, waitToCleanup}); + } + + if (candidates.empty()) + return false; + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(forOp); + for (const Candidate &candidate : candidates) { + Value converted = ttg::ConvertLayoutOp::create( + rewriter, forOp.getLoc(), + forOp.getResult(candidate.blockedIdx).getType(), + forOp.getResult(candidate.mmaIdx)); + forOp.getResult(candidate.blockedIdx).replaceAllUsesWith(converted); + } + + llvm::BitVector eraseBits(numIterArgs); + for (const Candidate &candidate : candidates) + eraseBits.set(candidate.blockedIdx); + eraseLoopCarriedValues(forOp, eraseBits); + + llvm::SmallDenseSet waitsToCleanup; + for (const Candidate &candidate : candidates) { + if (candidate.waitToCleanup) + waitsToCleanup.insert(candidate.waitToCleanup); + } + for (Operation *waitOp : waitsToCleanup) { + if (waitOp->use_empty()) + rewriter.eraseOp(waitOp); + } + return true; +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUOPTIMIZESQMMAACCUMULATORLAYOUT +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUOptimizeSqmmaAccumulatorLayoutPass + : impl::TritonMUSAGPUOptimizeSqmmaAccumulatorLayoutBase< + TritonMUSAGPUOptimizeSqmmaAccumulatorLayoutPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + IRRewriter rewriter(&getContext()); + + for (tt::FuncOp func : mod.getOps()) { + bool changed = true; + while (changed) { + changed = false; + SmallVector forOps; + func.walk([&](scf::ForOp loop) { forOps.push_back(loop); }); + for (scf::ForOp loop : forOps) { + if (!loop->getBlock()) + continue; + if (sinkLoopCarriedSqmmaAccumulatorConvert(loop, rewriter)) + changed = true; + } + } + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/Pipeline.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/Pipeline.cpp new file mode 100644 index 0000000000..7996497768 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/Pipeline.cpp @@ -0,0 +1,1409 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "SqmmaPipelineUtils.h" +#include "TMEPipelineUtils.h" +#include "TritonMUSACommon/BarrierUtils.h" +#include "TritonMUSACommon/MemDescUtils.h" +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/Triton/Transforms/LoopPeeling.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace { + +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) + mask = tt::SplatOp::create(rewriter, loc, maskType, pred); + if (currentMask) + mask = arith::AndIOp::create(rewriter, loc, mask, currentMask); + return mask; +} + +static Operation *predicateVoidOpWithIf(RewriterBase &rewriter, Operation *op, + Value pred) { + if (isConstantIntValue(pred, 1)) + return op; + if (!op->getResults().empty()) { + op->emitOpError("MUSA pipeliner can only branch-predicate void ops"); + llvm::report_fatal_error("Fatal pipeliner error"); + } + rewriter.setInsertionPoint(op); + auto ifOp = scf::IfOp::create(rewriter, op->getLoc(), pred, false); + rewriter.setInsertionPointToStart(ifOp.thenBlock()); + Operation *cloned = rewriter.clone(*op); + rewriter.eraseOp(op); + return cloned; +} + +static Operation *musaPredicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isConstantIntValue(pred, 1)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (op->hasTrait()) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + if (auto asyncCopyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(asyncCopyOp); + Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(), + asyncCopyOp.getMask(), pred); + asyncCopyOp.getMaskMutable().assign(mask); + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto addTransOp = dyn_cast(op)) { + rewriter.setInsertionPoint(addTransOp); + Value mask = getPredMask(rewriter, addTransOp.getPred().getType(), + addTransOp.getPred(), pred); + addTransOp.getPredMutable().assign(mask); + return op; + } + if (auto arriveOp = dyn_cast(op)) { + rewriter.setInsertionPoint(arriveOp); + Value mask = getPredMask(rewriter, arriveOp.getPred().getType(), + arriveOp.getPred(), pred); + arriveOp.getPredMutable().assign(mask); + return op; + } + if (isa( + op)) { + return predicateVoidOpWithIf(rewriter, op, pred); + } + if (!op->isRegistered()) + return op; + + op->emitOpError("MUSA pipeliner doesn't know how to predicate this op."); + llvm::report_fatal_error("Fatal pipeliner error"); +} + +static Operation *musaWrapInMaskOp(RewriterBase &rewriter, Operation *op, + Value pred) { + auto mask = + ttg::MaskOp::create(rewriter, op->getLoc(), op->getResultTypes(), pred); + rewriter.createBlock(&mask->getRegion(0)); + rewriter.setInsertionPointToStart(&mask->getRegion(0).front()); + auto newOp = rewriter.clone(*op); + ttg::MaskReturnOp::create(rewriter, op->getLoc(), newOp->getResults()); + op->replaceAllUsesWith(mask->getResults()); + rewriter.eraseOp(op); + return mask; +} + +static void musaResolveMaskOp(ModuleOp moduleOp) { + IRRewriter rewriter(moduleOp); + + auto arithDialect = + moduleOp.getContext()->getLoadedDialect(); + RewritePatternSet patterns(moduleOp.getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsGreedily(moduleOp, std::move(patterns)).failed()) + llvm::report_fatal_error("Failed to canonicalize the IR"); + + SmallVector maskOps; + moduleOp.walk([&](ttg::MaskOp maskOp) { maskOps.push_back(maskOp); }); + for (ttg::MaskOp maskOp : maskOps) { + rewriter.setInsertionPoint(maskOp); + while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) { + Operation *op = &maskOp.getBody()->front(); + rewriter.moveOpBefore(op, maskOp); + (void)musaPredicateOp(rewriter, op, maskOp.getPred()); + } + maskOp->replaceAllUsesWith( + maskOp.getBody()->getTerminator()->getOperands()); + maskOp->erase(); + } +} + +// ========================= +// Lower Loads / Descriptors +// ========================= + +static bool mustLoadToRegisters(Operation *op) { + if (auto loadOp = dyn_cast(op)) { + if (loadOp.getOther() && !isZeroConst(loadOp.getOther())) + return true; + } + + if (auto descLoad = dyn_cast(op)) { + return failed( + triton::musa::resolveDescriptorLoadLandingMemDescType(descLoad)); + } + + if (!op->hasOneUse()) + return true; + auto alloc = dyn_cast(*op->getUsers().begin()); + if (!alloc) + return true; + + Attribute loadEncoding; + if (auto descGather = dyn_cast(op)) { + loadEncoding = ttng::getEncodingFromDescriptor(op, descGather.getType(), + descGather.getDesc()); + } + return loadEncoding && (loadEncoding != alloc.getType().getEncoding()); +} + +static ttg::MemDescType +getGenericMusaLoadMemDescType(Operation *loadOp, ttg::SharedEncodingTrait enc) { + auto tensorTy = cast(loadOp->getResultTypes().front()); + auto sharedSpace = ttg::SharedMemorySpaceAttr::get(loadOp->getContext()); + return ttg::MemDescType::get(tensorTy.getShape(), tensorTy.getElementType(), + enc, sharedSpace, /*mutableMemory=*/true); +} + +static ttg::MemDescType getMusaMultiBufferedType(ttg::MemDescType viewTy, + int32_t depth) { + SmallVector shape(viewTy.getShape().begin(), + viewTy.getShape().end()); + SmallVector allocShape(viewTy.getAllocShape().begin(), + viewTy.getAllocShape().end()); + shape.insert(shape.begin(), depth); + allocShape.insert(allocShape.begin(), depth); + return ttg::MemDescType::get(shape, viewTy.getElementType(), + viewTy.getEncoding(), viewTy.getMemorySpace(), + /*mutableMemory=*/true, allocShape); +} + +static Value createMusaDescriptorAlloc(scf::ForOp forOp, Location loc, + ttg::MemDescType landingTy, + unsigned distance) { + OpBuilder builder(forOp); + auto allocTy = getMusaMultiBufferedType(landingTy, distance); + Value alloc = ttg::LocalAllocOp::create(builder, loc, allocTy); + builder.setInsertionPointAfter(forOp); + ttg::LocalDeallocOp::create(builder, loc, alloc); + return alloc; +} + +static TypedValue +createMusaSingleBufferView(OpBuilder &builder, Value alloc, Value idx) { + auto allocTy = cast(alloc.getType()); + SmallVector viewShape(allocTy.getShape().drop_front().begin(), + allocTy.getShape().drop_front().end()); + SmallVector viewAllocShape( + allocTy.getAllocShape().drop_front().begin(), + allocTy.getAllocShape().drop_front().end()); + auto viewTy = ttg::MemDescType::get( + viewShape, allocTy.getElementType(), allocTy.getEncoding(), + allocTy.getMemorySpace(), allocTy.getMutableMemory(), viewAllocShape); + return ttg::MemDescIndexOp::create(builder, alloc.getLoc(), viewTy, alloc, + idx); +} + +static bool hasMusaSqmmaRoot(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->without_terminator(), [](Operation &op) { + return isa(op); + }); +} + +static bool hasMusaDescriptorLoadRoot(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->without_terminator(), [](Operation &op) { + return isa(op); + }); +} + +static llvm::MapVector> +loadOpsToMusaSqmmaIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, + int numStages) { + llvm::MapVector> loadOpToIndLevel; + DenseSet seen; + DenseSet excluded; + + std::function dfs = + [&](Operation *op, Operation *finalUser, int distance) { + if (!seen.insert(op).second || excluded.count(op)) + return; + if (isa(op)) { + if (loadOpToIndLevel.count(op)) { + int level = loadOpToIndLevel[op].first; + if (level != distance) { + loadOpToIndLevel.erase(op); + excluded.insert(op); + return; + } + } else { + loadOpToIndLevel[op] = {distance, finalUser}; + } + finalUser = op; + distance++; + } + + for (Value operand : getNestedOperands(op)) { + if (auto dotOp = dyn_cast(op)) { + if (operand == dotOp->getOperand(2)) + continue; + } + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) + dfs(defOp, finalUser, distance); + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, &op, 0); + } + + if (pipelineWithoutDot) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + dfs(&op, &op, 0); + } + } + + for (auto iter = loadOpToIndLevel.begin(); iter != loadOpToIndLevel.end();) { + if (iter->second.first >= numStages - 1) + iter = loadOpToIndLevel.erase(iter); + else + ++iter; + } + + return loadOpToIndLevel; +} + +static tt::CoarseSchedule::Cluster scheduleMusaSqmmaPrologueAndEpilogue( + scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + DenseMap ifsToStage; + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + llvm::SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.omitUsesFromAbove = false; + (void)getBackwardSlice(op, &backwardSlice, opt); + for (Operation *sliceOp : backwardSlice) { + if (auto ifOp = dyn_cast(sliceOp)) + ifsToStage.insert({ifOp, stage}); + } + } + } + if (!ifsToStage.empty()) { + tt::CoarseSchedule::Cluster prologueCluster = + schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) + schedule.insert(ifOp, stage, prologueCluster); + } + + tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (Operation *rootUser : rootUsers) { + llvm::SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + int stage = schedule[rootUser].first; + for (Operation *sliceOp : forwardSlice) { + scf::IfOp ifOp = dyn_cast(sliceOp); + if (!ifOp) { + Operation *parentOp = sliceOp->getParentOp(); + if (parentOp && parentOp->getParentOp() == forOp.getOperation()) + ifOp = dyn_cast(parentOp); + } + if (ifOp) + schedule.insertIfAbsent(ifOp, stage, epilogueCluster); + } + } + + return afterPrologue; +} + +static void scheduleMusaSqmmaDependencies(scf::ForOp forOp, + tt::CoarseSchedule &schedule, + int numStages) { + auto opsInOrder = schedule.getOpsInOrder(forOp); + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, /*includeArg=*/false); + } + } +} + +static void scheduleMusaSqmmaDistanceOneDependencies( + scf::ForOp forOp, tt::CoarseSchedule &schedule, int numStages) { + DenseMap + dist1Cluster; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value yielded = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = yielded.getDefiningOp(); + if (!defOp || schedule.count(defOp)) + continue; + if (isa(defOp)) { + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, /*includeArg=*/true, + /*insertIfEarlier=*/true); + continue; + } + auto clusterHash = tt::CoarseSchedule::hashCluster(cluster); + if (!dist1Cluster.count(clusterHash)) + dist1Cluster[clusterHash] = schedule.clusters.newBefore(cluster); + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[clusterHash]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[clusterHash], + /*includeArg=*/true, + /*insertIfEarlier=*/true); + } + } +} + +static void scheduleMusaSqmmaRemainingToLastStage( + scf::ForOp forOp, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster afterPrologue, int numStages) { + DenseMap opToCluster; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + opToCluster[&op] = afterPrologue; + } + + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage == numStages - 1) + queue.push_back(op); + } + + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (Operation *user : op->getUsers()) { + if (!opToCluster.count(user)) + continue; + auto userCluster = opToCluster[user]; + auto opCluster = + schedule.count(op) ? schedule[op].second : opToCluster[op]; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + + for (auto [op, cluster] : opToCluster) + schedule.insert(op, numStages - 1, cluster); +} + +static FailureOr +synthesizeMusaSqmmaSchedule(scf::ForOp forOp, int defaultNumStages) { + if (!triton::gpu::isSafeToPipeline(forOp)) + return failure(); + if (!hasMusaSqmmaRoot(forOp)) + return failure(); + + int numStages = tt::getNumStagesOrDefault(forOp, defaultNumStages); + if (numStages <= 1) + return failure(); + + auto loadOpToIndLevel = loadOpsToMusaSqmmaIndirectionLevel( + forOp, forOp->hasAttr(tt::kNumStagesAttrName), numStages); + if (loadOpToIndLevel.empty()) + return failure(); + + tt::CoarseSchedule schedule(numStages); + DenseSet rootUsers; + int maxIndirectionLevel = -1; + for (auto &[loadOp, info] : loadOpToIndLevel) + maxIndirectionLevel = std::max(maxIndirectionLevel, info.first); + if (maxIndirectionLevel < 0) + return failure(); + + auto rootUsersCluster = schedule.clusters.newAtFront(); + for (auto &[loadOp, info] : loadOpToIndLevel) { + Operation *use = info.second; + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + if (rootUsers.empty()) + return failure(); + + unsigned stagesBetweenLoads = 0; + if (numStages > 2) { + stagesBetweenLoads = + (static_cast(numStages - 2 + maxIndirectionLevel)) / + static_cast(maxIndirectionLevel + 1); + } + + SmallVector loadClusters; + for (int i = 0; i < maxIndirectionLevel + 1; ++i) + loadClusters.push_back(schedule.clusters.newAtBack()); + + for (auto &[loadOp, info] : loadOpToIndLevel) { + int stage = (maxIndirectionLevel - info.first) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadClusters[info.first]); + } + + auto afterPrologue = scheduleMusaSqmmaPrologueAndEpilogue( + forOp, schedule, rootUsers, numStages); + scheduleMusaSqmmaDependencies(forOp, schedule, numStages); + scheduleMusaSqmmaDistanceOneDependencies(forOp, schedule, numStages); + scheduleMusaSqmmaRemainingToLastStage(forOp, schedule, afterPrologue, + numStages); + return schedule; +} + +static int getDefUseStageDiff(Operation *op, scf::ForOp forOp, + tt::CoarseSchedule &schedule) { + assert(schedule.count(op) && "Op not found in the schedule"); + int defStage = schedule[op].first; + tt::CoarseSchedule::Cluster defCluster = schedule[op].second; + std::optional useStage; + DenseSet topLevelUsers = + triton::getTopLevelUsersInLoop(op, forOp); + if (isa(op)) { + DenseSet allocUsers; + for (Operation *topLevelUser : topLevelUsers) { + if (auto localAlloc = dyn_cast(topLevelUser)) { + DenseSet users = + triton::getTopLevelUsersInLoop(localAlloc, forOp); + allocUsers.insert(users.begin(), users.end()); + } + } + topLevelUsers.insert(allocUsers.begin(), allocUsers.end()); + } + DenseSet topLevelWaitUsers; + for (Operation *topLevelUser : topLevelUsers) { + if (isa(topLevelUser)) + topLevelWaitUsers.insert(topLevelUser); + } + for (Operation *topLevelUser : topLevelUsers) { + if (!schedule.count(topLevelUser)) { + topLevelUser->emitOpError("top-level user missing from MUSA pipeline " + "schedule"); + llvm::report_fatal_error("Fatal pipeliner error"); + } + int curUseStage = schedule[topLevelUser].first; + tt::CoarseSchedule::Cluster curUseCluster = schedule[topLevelUser].second; + if (*curUseCluster > *defCluster) + curUseStage++; + useStage = std::min(curUseStage, useStage.value_or(curUseStage)); + } + for (Operation *topLevelUser : topLevelWaitUsers) { + int curUseStage = schedule[topLevelUser].first; + useStage = std::max(curUseStage, useStage.value_or(curUseStage)); + } + if (!useStage) + return 0; + assert(useStage >= defStage && "Op used before defined"); + return useStage.value() - defStage; +} + +static Value createAlloc(scf::ForOp forOp, Operation *loadOp, + ttg::MemDescType landingTy, unsigned distance) { + return createMusaDescriptorAlloc(forOp, loadOp->getLoc(), landingTy, + distance); +} + +static ttg::LocalAllocOp getSqmmaOperandLocalAllocUser(tt::LoadOp loadOp) { + if (mustLoadToRegisters(loadOp)) + return {}; + auto localAlloc = dyn_cast(*loadOp->getUsers().begin()); + if (!localAlloc || + !triton::musa::hasSqmmaOpIdxAttr(localAlloc.getOperation())) { + return {}; + } + return localAlloc; +} + +static ttg::MemDescType getSqmmaOperandLandingMemDescType(tt::LoadOp loadOp) { + auto localAlloc = getSqmmaOperandLocalAllocUser(loadOp); + if (!localAlloc) + return {}; + return dyn_cast(localAlloc.getType()); +} + +static void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, int contiguity, + tt::CoarseSchedule &schedule) { + tt::OpBuilderForStage builder(loadOp.getLoc(), forOp, schedule); + Operation *firstUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule); + assert(firstUse && "LoadOp has no users"); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(loadOp); + builder.setStageCluster(schedule[loadOp]); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + + Value view = createMusaSingleBufferView(builder, alloc, insertIdx); + auto sqmmaLocalAlloc = getSqmmaOperandLocalAllocUser(loadOp); + if (sqmmaLocalAlloc) + triton::musa::copySqmmaAttrs(sqmmaLocalAlloc.getOperation(), + view.getDefiningOp()); + Operation *copy = ttg::AsyncCopyGlobalToLocalOp::create( + builder, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile(), contiguity); + Operation *commit = + ttg::AsyncCommitGroupOp::create(builder, copy->getResult(0)); + + builder.setStageCluster(schedule[firstUse]); + auto wait = ttg::AsyncWaitOp::create(builder, commit->getResult(0), 0); + auto viewLoad = createMusaSingleBufferView(builder, alloc, extractIdx); + if (sqmmaLocalAlloc) + triton::musa::copySqmmaAttrs(sqmmaLocalAlloc.getOperation(), + viewLoad.getDefiningOp()); + + if (sqmmaLocalAlloc) { + replaceUsesAndPropagateType(builder, sqmmaLocalAlloc, viewLoad); + sqmmaLocalAlloc.erase(); + } else if (!loadOp.getOther() || isZeroConst(loadOp.getOther())) { + replaceUsesWithLocalLoad(builder, loadOp->getResult(0), viewLoad, + wait.getResult()); + } else if (loadOp->use_begin() != loadOp->use_end()) { + auto sharedLoad = ttg::LocalLoadOp::create(builder, loadOp.getType(), + viewLoad, wait.getResult()); + auto select = + arith::SelectOp::create(builder, loadOp.getType(), loadOp.getMask(), + sharedLoad.getResult(), other); + loadOp->replaceAllUsesWith(select->getResults()); + } + schedule.erase(loadOp); + loadOp->erase(); +} + +static void replaceDescriptorUsesWithMemDescOrLocalLoad( + tt::DescriptorLoadOp loadOp, Value sourceMemDesc, RewriterBase &rewriter) { + SmallVector users(loadOp->getUsers().begin(), + loadOp->getUsers().end()); + Value tensorValue = loadOp.getResult(); + Value localLoadValue; + auto getLocalLoadValue = [&]() -> Value { + if (!localLoadValue) { + localLoadValue = ttg::LocalLoadOp::create(rewriter, loadOp.getLoc(), + loadOp.getType(), sourceMemDesc) + .getResult(); + } + return localLoadValue; + }; + + for (Operation *user : users) { + bool isTensorViewUser = isa(user); + if (triton::musa::tryReplaceTensorUserWithMemDesc(rewriter, tensorValue, + sourceMemDesc, user) && + !isTensorViewUser) + continue; + if (isTensorViewUser && user->use_empty()) { + rewriter.eraseOp(user); + continue; + } + rewriter.setInsertionPoint(user); + user->replaceUsesOfWith(tensorValue, getLocalLoadValue()); + if (isTensorViewUser && user->use_empty()) + rewriter.eraseOp(user); + } +} + +struct AsyncLoad { + int stageDiff; + int contiguity = 1; + Value alloc; + Value barrier; + Operation *waitOp = nullptr; + ttg::MemDescType landingMemDescTy; +}; + +struct LoadGroupInfo { + Value insertIdx; + Value extractIdx; + Value phase; + Value yieldPhase; + bool hasTMALoad = false; + int32_t barrierBase = 0; +}; + +static void scheduleScalarPipelineValue(tt::CoarseSchedule &schedule, Value v, + int stage, + tt::CoarseSchedule::Cluster cluster) { + Operation *op = v.getDefiningOp(); + if (!op) + return; + schedule.insert(op, stage, cluster); + schedule.insertDepsOfOp(op, stage, cluster, /*includeArg=*/false, + /*insertIfEarlier=*/true); +} + +static void collectNonViewUsers(Operation *op, + SmallVectorImpl &out) { + for (Operation *user : op->getUsers()) { + if (user->hasTrait()) { + collectNonViewUsers(user, out); + continue; + } + out.push_back(user); + } +} + +static LogicalResult +validateMusaScheduleClusters(scf::ForOp forOp, tt::CoarseSchedule &schedule) { + DenseSet validClusters; + for (auto it = schedule.clusters.begin(); it != schedule.clusters.end(); ++it) + validClusters.insert(tt::CoarseSchedule::hashCluster(it)); + + for (Operation &op : forOp.getBody()->without_terminator()) { + auto it = schedule.find(&op); + if (it == schedule.end()) + continue; + auto [stage, cluster] = it->second; + auto clusterHash = tt::CoarseSchedule::hashCluster(cluster); + if (validClusters.count(clusterHash)) + continue; + op.emitError() << "scheduled into foreign cluster hash=" << clusterHash + << " at stage " << stage; + return failure(); + } + return success(); +} + +static Value materializeBarrierIdValue(OpBuilder &builder, Location loc, + int32_t base, Value slotIdx) { + Value baseVal = arith::ConstantIntOp::create(builder, loc, base, 32); + if (auto constIdx = slotIdx.getDefiningOp()) { + auto intAttr = dyn_cast(constIdx.getValueAttr()); + if (intAttr) + return arith::ConstantIntOp::create(builder, loc, base + intAttr.getInt(), + 32); + } + return arith::AddIOp::create(builder, loc, baseVal, slotIdx); +} + +static void convertScalarToTensorLoad(Operation *op, + tt::CoarseSchedule &schedule, + scf::ForOp forOp) { + auto scalarLoad = cast(op); + Type scalarTy = scalarLoad.getType(); + tt::OpBuilderForStage builder(op->getLoc(), op, schedule); + builder.setInsertionPoint(op); + MLIRContext *ctx = op->getContext(); + auto nWarps = ttg::lookupNumWarps(op->getParentRegion()); + ModuleOp mod = forOp->getParentOfType(); + auto threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + auto numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + auto blockedEnc = + ttg::getDefaultBlockedEncoding(ctx, {1}, nWarps, threadsPerWarp, numCTAs); + auto newPtrTy = + RankedTensorType::get({1}, scalarLoad.getPtr().getType(), blockedEnc); + auto newPtr = + tt::SplatOp::create(builder, op->getLoc(), newPtrTy, scalarLoad.getPtr()); + scalarLoad.getPtrMutable().assign(newPtr); + if (scalarLoad.getMask()) { + auto newMaskTy = + RankedTensorType::get({1}, scalarLoad.getMask().getType(), blockedEnc); + auto newMask = tt::SplatOp::create(builder, op->getLoc(), newMaskTy, + scalarLoad.getMask()); + scalarLoad.getMaskMutable().assign(newMask); + } + if (scalarLoad.getOther()) { + auto newOtherTy = + RankedTensorType::get({1}, scalarLoad.getOther().getType(), blockedEnc); + auto newOther = tt::SplatOp::create(builder, op->getLoc(), newOtherTy, + scalarLoad.getOther()); + scalarLoad.getOtherMutable().assign(newOther); + } + auto newDstTy = RankedTensorType::get({1}, scalarTy, blockedEnc); + scalarLoad.getResult().setType(newDstTy); + builder.setInsertionPointAfter(op); + Operation *firstUse = getFirstUseOfPipelinedOp({op}, forOp, schedule); + builder.setStageCluster(schedule[firstUse]); + Operation *unsplat = tt::UnsplatOp::create(builder, op->getLoc(), scalarTy, + scalarLoad.getResult()); + scalarLoad.getResult().replaceAllUsesExcept(unsplat->getResult(0), unsplat); +} + +static void +createMUSATMABarrierAndWait(scf::ForOp forOp, + llvm::MapVector &asyncLoads, + llvm::MapVector &loadGroups, + tt::CoarseSchedule &schedule) { + OpBuilder initBuilder(forOp); + initBuilder.setInsertionPoint(forOp); + Location loopLoc = forOp.getLoc(); + Value zeroPhase = arith::ConstantIntOp::create(initBuilder, loopLoc, 0, 32); + Value arriveCount = arith::ConstantIntOp::create(initBuilder, loopLoc, 1, 32); + + SmallVector> commonWaitGroups; + llvm::SmallDenseSet visited; + for (auto &[loadOp, asyncLoad] : asyncLoads) { + if (!tt::isTMALoad(loadOp) || visited.count(loadOp)) + continue; + llvm::SmallDenseSet users; + SmallVector group; + Block *loadBlock = loadOp->getBlock(); + auto addToGroup = [&](Operation *groupLoadOp) { + group.push_back(groupLoadOp); + visited.insert(groupLoadOp); + bool sharedFirst = !mustLoadToRegisters(groupLoadOp); + for (Operation *user : groupLoadOp->getUsers()) { + if (sharedFirst) { + auto alloc = dyn_cast(user); + if (alloc && alloc->getBlock() == loadBlock) { + SmallVector nonViewUsers; + collectNonViewUsers(alloc, nonViewUsers); + for (Operation *nonViewUser : nonViewUsers) { + Operation *userInBlock = + loadBlock->findAncestorOpInBlock(*nonViewUser); + if (userInBlock) + users.insert(userInBlock); + } + continue; + } + } + Operation *userInBlock = loadBlock->findAncestorOpInBlock(*user); + if (userInBlock) + users.insert(userInBlock); + } + }; + addToGroup(loadOp); + Operation *nextOp = loadOp->getNextNode(); + int numBuffers = asyncLoad.stageDiff; + while (nextOp) { + if (users.count(nextOp) || visited.count(nextOp)) + break; + if (tt::isTMALoad(nextOp) && asyncLoads.count(nextOp)) { + if (asyncLoads[nextOp].stageDiff != numBuffers) + break; + if (group.size() > 0 && schedule[group[0]] == schedule[nextOp]) { + addToGroup(nextOp); + } + } + nextOp = nextOp->getNextNode(); + } + commonWaitGroups.push_back(group); + } + + for (SmallVector &group : commonWaitGroups) { + int64_t sizeInBytes = 0; + int numBuffers = asyncLoads[group[0]].stageDiff; + auto reserved = triton::musa::reserveBarrierIdRange(forOp, numBuffers); + if (failed(reserved)) { + forOp.emitOpError("unable to reserve MUSA async barrier ids for " + "pipelined TME load group"); + llvm::report_fatal_error("Fatal pipeliner error"); + } + int32_t barrierBase = *reserved; + for (int32_t slot = 0; slot < numBuffers; ++slot) { + Value barId = arith::ConstantIntOp::create(initBuilder, loopLoc, + barrierBase + slot, 32); + triton::musa::InitArrivalOp::create(initBuilder, loopLoc, barId, + arriveCount, zeroPhase); + } + LoadGroupInfo &loadGroup = loadGroups.find(numBuffers)->second; + for (Operation *op : group) { + auto tensorTy = cast(op->getResultTypes()[0]); + int64_t loadSize = product(ttg::getShapePerCTA(tensorTy)); + sizeInBytes += loadSize * tensorTy.getElementTypeBitWidth() / 8; + } + + tt::OpBuilderForStage builder(forOp.getLoc(), group[0], schedule); + builder.setInsertionPoint(group[0]); + builder.setStageCluster(schedule[group[0]]); + Value issueBarId = materializeBarrierIdValue( + builder, group[0]->getLoc(), barrierBase, loadGroup.insertIdx); + Value transBytes = arith::ConstantIntOp::create(builder, group[0]->getLoc(), + sizeInBytes, 32); + Value pred = + arith::ConstantIntOp::create(builder, group[0]->getLoc(), 1, 1); + auto addTrans = triton::musa::BarrierAddTransOp::create( + builder, group[0]->getLoc(), issueBarId, transBytes, pred); + + builder.setInsertionPointAfter(group.back()); + builder.setStageCluster(schedule[group.back()]); + auto arrive = triton::musa::ArriveBarrierNoRetOp::create( + builder, group.back()->getLoc(), issueBarId, pred); + + Operation *firstUse = getFirstUseOfPipelinedOp(group, forOp, schedule); + builder.setInsertionPointAfter(arrive); + builder.setStageCluster(schedule[firstUse]); + Value waitBarId = materializeBarrierIdValue( + builder, firstUse->getLoc(), barrierBase, loadGroup.extractIdx); + Value waitPhase = + loadGroup.yieldPhase ? loadGroup.yieldPhase : loadGroup.phase; + auto wait = triton::musa::WaitBarrierOp::create(builder, firstUse->getLoc(), + waitBarId, waitPhase); + + for (Operation *op : group) { + asyncLoads[op].barrier = issueBarId; + asyncLoads[op].waitOp = wait; + } + } +} + +static bool loadRequiresAdditionalBuffer(Operation *loadOp) { + auto isMusaTarget = [&]() { + auto module = loadOp->getParentOfType(); + if (!module) + return false; + auto targetAttr = module->getAttrOfType(ttg::AttrTargetName); + return targetAttr && targetAttr.getValue().starts_with("musa:"); + }; + + std::function &)> collectNonViewUsers = + [&](Value value, SmallVector &out) { + for (Operation *user : value.getUsers()) { + if (user->hasTrait()) { + for (Value result : user->getResults()) + if (isa(result.getType())) + collectNonViewUsers(result, out); + } else { + out.push_back(user); + } + } + }; + std::function &)> hasDotConsumer = + [&](Operation *op, DenseSet &visited) -> bool { + if (!visited.insert(op).second) + return false; + if (isa(op)) + return true; + for (Operation *user : op->getUsers()) { + if (hasDotConsumer(user, visited)) + return true; + } + return false; + }; + if (!mustLoadToRegisters(loadOp)) { + SmallVector landingMemDescs; + if (auto descLoad = dyn_cast(loadOp)) { + if (failed( + triton::musa::resolveDescriptorLoadLandingMemDescType(descLoad))) + return false; + landingMemDescs = + triton::musa::collectCanonicalLandingMemDescRoots(descLoad); + } else if (loadOp->hasOneUse()) { + if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) + landingMemDescs.push_back(alloc.getResult()); + } + for (Value memDesc : landingMemDescs) { + SmallVector nonViewUsers; + collectNonViewUsers(memDesc, nonViewUsers); + if (llvm::any_of(nonViewUsers, [&](Operation *op) { + if (isa(op)) + return true; + if (isMusaTarget()) { + DenseSet visited; + return hasDotConsumer(op, visited); + } + return false; + })) { + return true; + } + } + } + return false; +} + +static FailureOr +lowerLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + auto module = forOp->getParentOfType(); + auto targetAttr = module + ? module->getAttrOfType(ttg::AttrTargetName) + : StringAttr(); + bool isMusaTarget = targetAttr && targetAttr.getValue().starts_with("musa:"); + auto hasDotConsumer = [&](Operation *loadOp) { + DenseSet visited; + std::function dfs = [&](Operation *op) -> bool { + if (!visited.insert(op).second) + return false; + if (isa(op)) + return true; + for (Operation *user : op->getUsers()) { + if (dfs(user)) + return true; + } + return false; + }; + for (Value result : loadOp->getResults()) { + for (Operation *user : result.getUsers()) { + if (dfs(user)) + return true; + } + } + return false; + }; + + llvm::MapVector asyncLoads; + llvm::MapVector loadGroups; + llvm::SmallVector scalarLoads; + for (auto &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + + if (isa(op)) { + op.emitOpError("pipelined descriptor_gather is not supported on MUSA"); + return failure(); + } + + int stageDiff = getDefUseStageDiff(&op, forOp, schedule); + if (stageDiff == 0) + continue; + + ttg::SharedEncodingTrait sharedEncoding; + ttg::MemDescType landingMemDescTy; + bool canUseAsyncCp = false; + int contiguity = 1; + if (!isa(op.getResultTypes()[0])) { + canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32; + auto numCTAs = ttg::lookupNumCTAs(forOp); + sharedEncoding = ttg::SwizzledSharedEncodingAttr::get( + forOp.getContext(), 1, 1, 1, {0}, + ttg::CGAEncodingAttr::get1DLayout(forOp.getContext(), numCTAs)); + auto sharedSpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + landingMemDescTy = ttg::MemDescType::get( + ArrayRef({1}), op.getResultTypes()[0], sharedEncoding, + sharedSpace, /*mutableMemory=*/true); + if (canUseAsyncCp) + scalarLoads.push_back(&op); + } else { + if (isMusaTarget && isa(op)) { + auto resolvedLandingTy = + triton::musa::resolveDescriptorLoadLandingMemDescType( + cast(&op)); + if (failed(resolvedLandingTy)) { + op.emitOpError("pipelined descriptor load requires normalized " + "canonical landing memdesc encoding"); + return failure(); + } + landingMemDescTy = *resolvedLandingTy; + sharedEncoding = + dyn_cast(landingMemDescTy.getEncoding()); + } else { + sharedEncoding = tt::getSharedEncoding(&op); + landingMemDescTy = getGenericMusaLoadMemDescType(&op, sharedEncoding); + } + canUseAsyncCp = + isa(op) && + canBeConvertedToAsyncLoad(cast(op), axisInfoAnalysis); + int copyVecBytes = tt::getCopyVecBytes( + cast(op.getResultTypes()[0]), sharedEncoding); + canUseAsyncCp &= copyVecBytes >= 4; + if (canUseAsyncCp) { + auto loadOp = cast(op); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = + std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + contiguity = vec; + } + } + + if (canUseAsyncCp || tt::isTMALoad(&op)) { + if (loadRequiresAdditionalBuffer(&op)) + stageDiff += 1; + if (isMusaTarget && hasDotConsumer(&op)) { + if (tt::isTMALoad(&op)) + stageDiff += 1; + if (auto descLoad = dyn_cast(&op)) { + if (auto sqmmaLandingTy = + triton::musa::getUniqueCanonicalLandingSqmmaMemDescType( + descLoad)) { + landingMemDescTy = *sqmmaLandingTy; + sharedEncoding = dyn_cast( + landingMemDescTy.getEncoding()); + } + } + if (isa(op)) { + if (auto sqmmaLandingTy = + getSqmmaOperandLandingMemDescType(cast(&op))) { + stageDiff += 1; + landingMemDescTy = sqmmaLandingTy; + sharedEncoding = dyn_cast( + landingMemDescTy.getEncoding()); + } + } + } + auto &asyncLoad = asyncLoads[&op]; + asyncLoad.stageDiff = stageDiff; + asyncLoad.contiguity = contiguity; + asyncLoad.landingMemDescTy = landingMemDescTy; + } else if (stageDiff > 1) { + op.emitRemark() << "Pipelining load that cannot use vectorized copy. " + "This will likely lead to pipelining in registers and " + "severe performance degradation."; + } + } + for (Operation *op : scalarLoads) + convertScalarToTensorLoad(op, schedule, forOp); + + if (asyncLoads.empty()) + return forOp; + + for (auto &[loadOp, asyncLoad] : asyncLoads) { + Value alloc = createAlloc(forOp, loadOp, asyncLoad.landingMemDescTy, + asyncLoad.stageDiff); + asyncLoad.alloc = alloc; + loadGroups.insert({asyncLoad.stageDiff, {}}); + if (tt::isTMALoad(loadOp)) + loadGroups[asyncLoad.stageDiff].hasTMALoad = true; + } + IRRewriter builder(forOp); + builder.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + Value minusOne = arith::ConstantIntOp::create(builder, loc, -1, 32); + Value zero = arith::ConstantIntOp::create(builder, loc, 0, 32); + Value one = arith::ConstantIntOp::create(builder, loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + for (auto [numBuffers, loadGroup] : loadGroups) { + Value initCounter = minusOne; + newOperands.push_back(initCounter); + newOperands.push_back(initCounter); + if (loadGroup.hasTMALoad) + newOperands.push_back(zero); + } + + forOp = addIterArgsToLoop(builder, forOp, newOperands); + + auto forYield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) + forYield.getResultsMutable().append(newOperands[i]); + + builder.setInsertionPoint(forOp); + loc = forOp.getLoc(); + int argIdx = newOperandIndex; + auto findGroupIssueAnchor = [&](int numBuffers) -> Operation * { + for (auto &[op, asyncLoad] : asyncLoads) + if (asyncLoad.stageDiff == numBuffers) + return op; + return nullptr; + }; + auto findGroupConsumerAnchor = [&](int numBuffers) -> Operation * { + SmallVector group; + for (auto &[op, asyncLoad] : asyncLoads) + if (asyncLoad.stageDiff == numBuffers) + group.push_back(op); + if (group.empty()) + return nullptr; + return getFirstUseOfPipelinedOp(group, forOp, schedule); + }; + for (auto &[numBuffers, loadGroup] : loadGroups) { + Value insertIdx = forOp.getBody()->getArgument(argIdx++); + Value extractIdx = forOp.getBody()->getArgument(argIdx++); + Value phase; + if (loadGroup.hasTMALoad) + phase = forOp.getBody()->getArgument(argIdx++); + loadGroup.phase = phase; + + builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); + + Value numBuffersVal = + arith::ConstantIntOp::create(builder, loc, numBuffers, 32); + loadGroup.insertIdx = tt::createIncrementModulo(builder, loc, insertIdx, + numBuffersVal, zero, one); + Value cndExt; + loadGroup.extractIdx = tt::createIncrementModulo( + builder, loc, extractIdx, numBuffersVal, zero, one, &cndExt); + if (phase) { + Value nextPhase = arith::XOrIOp::create(builder, loc, phase, one); + loadGroup.yieldPhase = + arith::SelectOp::create(builder, loc, cndExt, nextPhase, phase); + } + + (void)findGroupIssueAnchor(numBuffers); + Operation *consumerAnchor = findGroupConsumerAnchor(numBuffers); + if (consumerAnchor && loadGroup.yieldPhase) { + auto [consumerStage, consumerCluster] = schedule[consumerAnchor]; + scheduleScalarPipelineValue(schedule, loadGroup.yieldPhase, consumerStage, + consumerCluster); + } + } + + createMUSATMABarrierAndWait(forOp, asyncLoads, loadGroups, schedule); + + bool hasAsyncLoads = false; + for (auto &[op, asyncLoad] : asyncLoads) { + LoadGroupInfo &loadGroup = loadGroups[asyncLoad.stageDiff]; + Value insertIdx = loadGroup.insertIdx; + Value extractIdx = loadGroup.extractIdx; + if (auto loadOp = dyn_cast(op)) { + createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + asyncLoad.contiguity, schedule); + hasAsyncLoads = true; + } else if (auto loadOp = dyn_cast(op)) { + tt::OpBuilderForStage copyBuilder(loadOp.getLoc(), forOp, schedule); + copyBuilder.setInsertionPoint(loadOp); + copyBuilder.setStageCluster(schedule[loadOp]); + auto view = + createMusaSingleBufferView(copyBuilder, asyncLoad.alloc, insertIdx); + if (auto allocOp = asyncLoad.alloc.getDefiningOp()) + triton::musa::copyCanonicalLandingSqmmaAttrs(loadOp, allocOp); + if (auto viewOp = view.getDefiningOp()) + triton::musa::copyCanonicalLandingSqmmaAttrs(loadOp, viewOp); + Value pred = + arith::ConstantIntOp::create(copyBuilder, loadOp.getLoc(), 1, 1); + auto blockTy = loadOp.getDesc().getType().getSignlessBlockType(); + auto coord = triton::musa::materializeTMECoordValues( + loadOp.getLoc(), loadOp.getIndices(), copyBuilder); + if (failed(coord)) { + loadOp.emitOpError("unable to materialize pipelined TME block info"); + return failure(); + } + auto config = triton::musa::resolveFinalTMECopyConfig( + cast(view.getType()), blockTy.getShape(), + triton::musa::TMECopyKind::GlobalToLocal); + if (failed(config)) { + loadOp.emitOpError("unable to resolve pipelined TME load config"); + return failure(); + } + triton::musa::createAsyncTMECopyGlobalToLocal( + copyBuilder, loadOp.getLoc(), loadOp.getDesc(), *coord, + asyncLoad.barrier, view, pred, *config); + + copyBuilder.setInsertionPointAfter(asyncLoad.waitOp); + copyBuilder.setStageCluster( + schedule[getFirstUseOfPipelinedOp({loadOp}, forOp, schedule)]); + auto viewLoad = + createMusaSingleBufferView(copyBuilder, asyncLoad.alloc, extractIdx); + if (auto viewLoadOp = viewLoad.getDefiningOp()) + triton::musa::copyCanonicalLandingSqmmaAttrs(loadOp, viewLoadOp); + IRRewriter rewriter(loadOp.getContext(), ©Builder); + rewriter.setInsertionPoint(copyBuilder.getInsertionBlock(), + copyBuilder.getInsertionPoint()); + replaceDescriptorUsesWithMemDescOrLocalLoad(loadOp, viewLoad, rewriter); + schedule.erase(loadOp); + loadOp->erase(); + } + } + + argIdx = newOperandIndex - 1; + for (auto &[numBuffers, loadGroup] : loadGroups) { + forYield.setOperand(argIdx++, loadGroup.insertIdx); + forYield.setOperand(argIdx++, loadGroup.extractIdx); + if (loadGroup.phase) + forYield.setOperand(argIdx++, loadGroup.yieldPhase ? loadGroup.yieldPhase + : loadGroup.phase); + } + + if (failed(validateMusaScheduleClusters(forOp, schedule))) { + forOp.emitOpError("invalid cluster mapping before scheduling MUSA pipeline " + "dependencies"); + return failure(); + } + scheduleDependencies(forOp, schedule); + + if (hasAsyncLoads) { + builder.setInsertionPointAfter(forOp); + ttg::AsyncWaitOp::create(builder, loc, ValueRange({}), 0); + } + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!schedule.count(&op)) { + op.emitError() << "op not found in the schedule"; + return failure(); + } + } + return forOp; +} + +static LogicalResult musaLowerLoops(ModuleOp moduleOp, int defaultNumStages) { + triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + SmallVector loops; + moduleOp.walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + auto lowerLoopWithSchedule = + [&](scf::ForOp forOp, tt::CoarseSchedule &schedule) -> LogicalResult { + auto lowered = lowerLoads(forOp, schedule, axisInfoAnalysis); + if (failed(lowered)) + return failure(); + scf::ForOp newForOp = + triton::musa::pipeline::lowerTMADescriptors(*lowered, schedule); + schedule.serialize(newForOp); + return success(); + }; + + for (scf::ForOp forOp : loops) { + tt::CoarseSchedule schedule; + if (succeeded(schedule.deSerialize(forOp))) { + if (failed(lowerLoopWithSchedule(forOp, schedule))) + return failure(); + continue; + } + + { + auto synthesized = synthesizeMusaSqmmaSchedule(forOp, defaultNumStages); + if (failed(synthesized)) + continue; + if (failed(lowerLoopWithSchedule(forOp, *synthesized))) + return failure(); + } + } + return success(); +} + +static void expandLoops(ModuleOp moduleOp) { + SmallVector loops; + moduleOp.walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + tt::CoarseSchedule schedule; + if (failed(schedule.deSerialize(forOp))) + continue; + + std::vector> finalSchedule = + schedule.createFinalSchedule(forOp); + tt::PipeliningOption options; + options.supportDynamicLoops = true; + options.peelEpilogue = false; + options.predicateFn = musaWrapInMaskOp; + options.getScheduleFn = + [&](scf::ForOp, + std::vector> &loopSchedule) { + loopSchedule = finalSchedule; + }; + + bool keepPredicateStage = forOp->hasAttr("__test_keep_predicate_stage"); + if (keepPredicateStage) { + options.emitPredicateStageFn = [](RewriterBase &rewriter, + Value inductionVar, Value upperBound, + Value step, uint64_t maxStage, + uint64_t stage) { + return ttg::PredicateStageOp::create(rewriter, inductionVar.getLoc(), + inductionVar, upperBound, step, + maxStage, stage); + }; + } + + IRRewriter rewriter(forOp); + if (failed(tt::pipelineForLoop(rewriter, forOp, options))) + continue; + } + + assert(moduleOp.getOps().empty() && + "PredicateStageOp should be resolved after the pipeline expansion"); + assert(verify(moduleOp).succeeded()); + musaResolveMaskOp(moduleOp); +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUPIPELINE +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUPipelinePass + : impl::TritonMUSAGPUPipelineBase { + using Base::Base; + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + + if (failed(musaLowerLoops(moduleOp, numStages))) + return signalPassFailure(); + if (dumpIntermediateSteps) { + llvm::dbgs() + << "// -----// TritonMUSAGPUPipeline internal IR Dump After: " + "LowerLoops\n" + << moduleOp << "\n\n\n"; + } + + expandLoops(moduleOp); + if (dumpIntermediateSteps) { + llvm::dbgs() + << "// -----// TritonMUSAGPUPipeline internal IR Dump After: " + "ExpandLoops\n" + << moduleOp << "\n\n\n"; + } + + tt::removePipeliningAttributes(moduleOp); + triton::musa::pipeline::pipelineSqmma(moduleOp, numStages); + tt::updateWaits(moduleOp); + + auto *arithDialect = + moduleOp.getContext()->getLoadedDialect(); + RewritePatternSet patterns(moduleOp.getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsGreedily(moduleOp, std::move(patterns)).failed()) + return signalPassFailure(); + + SmallVector loops; + moduleOp.walk([&](scf::ForOp forOp) { + if (tt::getNumStagesOrDefault(forOp, numStages) > 1) + loops.push_back(forOp); + }); + for (scf::ForOp forOp : loops) { + auto pipelined = triton::musa::pipeline::pipelineTMEStores(forOp); + if (failed(pipelined)) + return signalPassFailure(); + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/SqmmaPipelineUtils.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/SqmmaPipelineUtils.cpp new file mode 100644 index 0000000000..7fe1205bd5 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/SqmmaPipelineUtils.cpp @@ -0,0 +1,508 @@ +#include "SqmmaPipelineUtils.h" + +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MMAOperandUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir::triton::musa::pipeline { +namespace { + +using ProperlyAsyncDots = llvm::MapVector; + +bool isInCurrentPipelineLoop(Operation *op, scf::ForOp loop) { + return op->getParentOfType() == loop; +} + +Value getTransparentSqmmaOperandSource(Operation *op) { + if (auto cvt = dyn_cast(op)) + return cvt.getSrc(); + if (auto localLoad = dyn_cast(op)) + return localLoad.getSrc(); + if (auto index = dyn_cast(op)) + return index.getSrc(); + if (auto subslice = dyn_cast(op)) + return subslice.getSrc(); + if (auto trans = dyn_cast(op)) + return trans.getSrc(); + if (auto reshape = dyn_cast(op)) + return reshape.getSrc(); + if (auto reinterpret = dyn_cast(op)) + return reinterpret.getSrc(); + if (isNoop(op) && op->getNumOperands() == 1 && op->getNumResults() == 1) + return op->getOperand(0); + return Value(); +} + +struct ResolvedSqmmaValue { + Value root; + bool loopCarryCycle = false; +}; + +ResolvedSqmmaValue resolveSqmmaValueRoot(Value value, scf::ForOp loop) { + llvm::SmallPtrSet visited; + while (value) { + if (loop.isDefinedOutsideOfLoop(value)) + return {value, false}; + + if (!visited.insert(value.getAsOpaquePointer()).second) + return {value, true}; + + if (Operation *defOp = value.getDefiningOp()) { + if (Value source = getTransparentSqmmaOperandSource(defOp)) { + value = source; + continue; + } + return {value, false}; + } + + auto blockArg = dyn_cast(value); + if (!blockArg || blockArg.getOwner() != loop.getBody()) + return {value, false}; + + OpOperand *yielded = loop.getTiedLoopYieldedValue(blockArg); + if (!yielded) + return {value, false}; + value = yielded->get(); + } + return {}; +} + +bool sqmmaOperandCanBeProperlyAsync(Value operand, scf::ForOp loop) { + ResolvedSqmmaValue resolved = resolveSqmmaValueRoot(operand, loop); + return resolved.loopCarryCycle || + (resolved.root && loop.isDefinedOutsideOfLoop(resolved.root)); +} + +Value getWaitThreadableMemDesc(Value operand, scf::ForOp loop) { + ResolvedSqmmaValue resolved = resolveSqmmaValueRoot(operand, loop); + if (!resolved.root || !isa(resolved.root.getType())) + return Value(); + return resolved.root; +} + +bool hasLoopLocalSharedReuse(Value operand, scf::ForOp loop) { + ResolvedSqmmaValue resolved = resolveSqmmaValueRoot(operand, loop); + return resolved.root && isa(resolved.root.getType()) && + !resolved.loopCarryCycle && + !loop.isDefinedOutsideOfLoop(resolved.root); +} + +void collectDependentSqmmaDots(Value value, scf::ForOp loop, + llvm::SetVector &dots, + llvm::SmallPtrSetImpl &visited); + +void collectIfResultSqmmaDots(scf::IfOp ifOp, unsigned resultIdx, + scf::ForOp loop, + llvm::SetVector &dots, + llvm::SmallPtrSetImpl &visited) { + auto collectFromRegion = [&](Region ®ion) { + auto yieldOp = dyn_cast(region.front().getTerminator()); + if (!yieldOp || resultIdx >= yieldOp.getNumOperands()) + return; + collectDependentSqmmaDots(yieldOp.getOperand(resultIdx), loop, dots, + visited); + }; + collectFromRegion(ifOp.getThenRegion()); + if (ifOp.elseBlock()) + collectFromRegion(ifOp.getElseRegion()); +} + +void collectDependentSqmmaDots(Value value, scf::ForOp loop, + llvm::SetVector &dots, + llvm::SmallPtrSetImpl &visited) { + if (!value || !visited.insert(value.getAsOpaquePointer()).second) + return; + + if (loop.isDefinedOutsideOfLoop(value)) + return; + + if (auto blockArg = dyn_cast(value)) { + if (blockArg.getOwner() == loop.getBody()) { + if (OpOperand *yielded = loop.getTiedLoopYieldedValue(blockArg)) + collectDependentSqmmaDots(yielded->get(), loop, dots, visited); + } + return; + } + + auto result = dyn_cast(value); + if (!result) + return; + + Operation *defOp = result.getOwner(); + if (auto sqmma = dyn_cast(defOp)) { + if (!isInCurrentPipelineLoop(sqmma, loop)) + return; + dots.insert(sqmma); + collectDependentSqmmaDots(sqmma.getC(), loop, dots, visited); + return; + } + + if (auto wait = dyn_cast(defOp)) { + unsigned idx = result.getResultNumber(); + if (idx < wait.getInputs().size()) + collectDependentSqmmaDots(wait.getInputs()[idx], loop, dots, visited); + return; + } + + if (auto ifOp = dyn_cast(defOp)) { + collectIfResultSqmmaDots(ifOp, result.getResultNumber(), loop, dots, + visited); + return; + } + + if (auto forResult = dyn_cast(defOp)) { + if (forResult == loop && + result.getResultNumber() < loop.getNumRegionIterArgs()) { + collectDependentSqmmaDots(loop.getRegionIterArg(result.getResultNumber()), + loop, dots, visited); + } + return; + } + + if (Value source = getTransparentSqmmaOperandSource(defOp)) { + collectDependentSqmmaDots(source, loop, dots, visited); + return; + } +} + +void collectSqmmaMemDescDependencies(ArrayRef values, + Operation *waitAnchor, scf::ForOp loop, + llvm::SetVector &memDescs) { + llvm::SetVector dots; + llvm::SmallPtrSet visited; + for (Value value : values) + collectDependentSqmmaDots(value, loop, dots, visited); + + Operation *domRoot = waitAnchor->getParentOfType(); + if (!domRoot) + domRoot = waitAnchor->getParentOfType(); + DominanceInfo domInfo(domRoot ? domRoot : waitAnchor->getParentOp()); + + for (triton::musa::SquadDotOp sqmma : dots) { + for (Value operand : {sqmma.getA(), sqmma.getB()}) { + Value memDesc = getWaitThreadableMemDesc(operand, loop); + if (!memDesc || !domInfo.properlyDominates(memDesc, waitAnchor)) + continue; + memDescs.insert(memDesc); + } + } +} + +triton::musa::SquadDotWaitOp +threadValuesThroughWait(triton::musa::SquadDotWaitOp wait, + ArrayRef values, scf::ForOp loop) { + IRRewriter rewriter(wait.getContext()); + rewriter.setInsertionPoint(wait); + + const unsigned origNumOperands = wait.getNumOperands(); + llvm::SetVector operands(wait.getInputs().begin(), + wait.getInputs().end()); + operands.insert(values.begin(), values.end()); + + llvm::SetVector memDescs; + SmallVector dependencyRoots(operands.begin(), operands.end()); + collectSqmmaMemDescDependencies(dependencyRoots, wait, loop, memDescs); + operands.insert(memDescs.begin(), memDescs.end()); + + if (operands.size() == origNumOperands) + return wait; + + SmallVector newOperands(operands.begin(), operands.end()); + auto newWait = triton::musa::SquadDotWaitOp::create(rewriter, wait.getLoc(), + newOperands); + newWait->setAttrs(wait->getAttrs()); + + auto dominatedByNewWait = [&](OpOperand &operand) { + auto *topLevel = + newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); + return topLevel && newWait->isBeforeInBlock(topLevel); + }; + + for (unsigned idx = 0; idx < origNumOperands; ++idx) { + Value oldResult = wait.getResult(idx); + if (!isa(oldResult.getType())) + oldResult.replaceAllUsesWith(newWait.getResult(idx)); + } + for (unsigned idx = origNumOperands; idx < newOperands.size(); ++idx) { + Value operand = newWait.getOperand(idx); + if (!isa(operand.getType())) + operand.replaceUsesWithIf(newWait.getResult(idx), dominatedByNewWait); + } + + rewriter.eraseOp(wait); + return newWait; +} + +bool isTransitivelySqmmaCUse(OpOperand &use) { + Operation *user = use.getOwner(); + if (isa(user)) + return use.getOperandNumber() == 2; + if (isNoop(user) && user->getNumResults() == 1) + return llvm::all_of(user->getResult(0).getUses(), isTransitivelySqmmaCUse); + return false; +} + +std::optional dotCanBeProperlyAsync(triton::musa::SquadDotOp sqmma, + scf::ForOp loop) { + if (!sqmmaOperandCanBeProperlyAsync(sqmma.getA(), loop) || + !sqmmaOperandCanBeProperlyAsync(sqmma.getB(), loop)) { + return std::nullopt; + } + + if (auto cArg = dyn_cast(sqmma.getC())) { + if (cArg.getOwner() == loop.getBody() && cArg.getArgNumber() > 0) { + if (OpOperand *yielded = loop.getTiedLoopYieldedValue(cArg)) { + if (yielded->get() == sqmma.getResult() && + llvm::all_of(cArg.getUses(), isTransitivelySqmmaCUse)) { + return static_cast(cArg.getArgNumber() - 1); + } + } + } + } + + SmallVector> yieldedIterArgs; + auto recordYieldedIterArg = [&](int operandIdx) -> bool { + if (operandIdx >= loop.getNumRegionIterArgs()) + return false; + if (llvm::none_of(yieldedIterArgs, [&](const auto &entry) { + return entry.first == operandIdx; + })) { + yieldedIterArgs.push_back( + {operandIdx, loop.getRegionIterArg(operandIdx)}); + } + return true; + }; + + SmallVector> queue; + for (OpOperand &use : sqmma->getUses()) + queue.push_back({use.getOwner(), static_cast(use.getOperandNumber())}); + + while (!queue.empty()) { + auto [user, operandIdx] = queue.pop_back_val(); + if (user->getParentOp() == loop) { + if (isNoop(user) && user->getNumResults() == 1) { + for (OpOperand &use : user->getResult(0).getUses()) + queue.push_back( + {use.getOwner(), static_cast(use.getOperandNumber())}); + continue; + } + if (isa(user)) { + if (!recordYieldedIterArg(operandIdx)) + return std::nullopt; + continue; + } + return std::nullopt; + } + + auto ifOp = dyn_cast(user->getParentOp()); + if (!ifOp) + return std::nullopt; + + if (isa(user)) { + for (OpOperand &use : ifOp.getResult(operandIdx).getUses()) + queue.push_back( + {use.getOwner(), static_cast(use.getOperandNumber())}); + } + } + + if (yieldedIterArgs.empty()) + return std::nullopt; + + auto selectIterArg = + [&](auto &&predicate) -> std::optional> { + std::optional> selected; + for (const auto &entry : yieldedIterArgs) { + if (!predicate(entry.second)) + continue; + if (selected) + return std::nullopt; + selected = entry; + } + return selected; + }; + + auto otherYieldedIterArgsAreLoopDead = [&](int selectedIdx) { + return llvm::all_of(yieldedIterArgs, [&](const auto &entry) { + return entry.first == selectedIdx || entry.second.use_empty(); + }); + }; + + if (auto selected = selectIterArg([&](Value candidate) { + return candidate == sqmma.getC() && + llvm::all_of(candidate.getUses(), isTransitivelySqmmaCUse); + }); + selected && otherYieldedIterArgsAreLoopDead(selected->first)) { + return selected->first; + } + + if (auto selected = selectIterArg([&](Value candidate) { + return llvm::all_of(candidate.getUses(), isTransitivelySqmmaCUse); + }); + selected && otherYieldedIterArgsAreLoopDead(selected->first)) { + return selected->first; + } + + auto waitOps = loop.getBody()->getOps(); + auto firstWait = + llvm::find_if(waitOps, [](triton::musa::SquadDotWaitOp) { return true; }); + auto iterArgUsersAreAfterFirstWait = [&](Value candidate) { + return llvm::all_of(candidate.getUsers(), [&](Operation *user) { + assert(loop->isAncestor(user)); + while (user->getParentOp() != loop) + user = user->getParentOp(); + return (*firstWait)->isBeforeInBlock(user); + }); + }; + if (firstWait != waitOps.end()) { + auto selected = selectIterArg(iterArgUsersAreAfterFirstWait); + if (selected && otherYieldedIterArgsAreLoopDead(selected->first)) { + threadValuesThroughWait(*firstWait, {selected->second}, loop); + return selected->first; + } + } + + return std::nullopt; +} + +triton::musa::SquadDotWaitOp getOrCreateWaitBefore(Operation *op) { + if (auto wait = + dyn_cast_or_null(op->getPrevNode())) + return wait; + + OpBuilder builder(op); + return triton::musa::SquadDotWaitOp::create(builder, op->getLoc(), + ArrayRef{}); +} + +triton::musa::SquadDotWaitOp getOrCreateWaitAfter(Operation *op) { + if (auto wait = + dyn_cast_or_null(op->getNextNode())) + return wait; + + OpBuilder builder(op); + builder.setInsertionPointAfter(op); + return triton::musa::SquadDotWaitOp::create(builder, op->getLoc(), + ArrayRef{}); +} + +void insertAsyncSqmmaWaitsInLoop(scf::ForOp loop, + const ProperlyAsyncDots &properlyAsyncDots) { + if (properlyAsyncDots.empty()) + return; + + for (auto [asyncDotOp, iterArgIdx] : properlyAsyncDots) { + (void)iterArgIdx; + auto asyncDot = cast(asyncDotOp); + DenseMap> blockToUses; + for (OpOperand &use : asyncDot->getUses()) { + if (isa(use.getOwner())) + continue; + blockToUses[use.getOwner()->getBlock()].push_back(&use); + } + + for (auto &entry : blockToUses) { + auto &uses = entry.second; + std::sort(uses.begin(), uses.end(), [](OpOperand *lhs, OpOperand *rhs) { + return lhs->getOwner()->isBeforeInBlock(rhs->getOwner()); + }); + + auto firstUse = + std::find_if_not(uses.begin(), uses.end(), [](OpOperand *use) { + return isa(use->getOwner()) && + use->getOperandNumber() == 2; + }); + if (firstUse == uses.end()) + continue; + + Operation *firstConsumer = (*firstUse)->getOwner(); + SmallVector waitOperands; + for (auto useIt = firstUse; useIt != uses.end(); ++useIt) + waitOperands.push_back((*useIt)->get()); + + auto wait = getOrCreateWaitBefore(firstConsumer); + threadValuesThroughWait(wait, waitOperands, loop); + } + } +} + +void insertFinalWaitAfterLoop(scf::ForOp loop, + const ProperlyAsyncDots &properlyAsyncDots) { + if (properlyAsyncDots.empty()) + return; + + OpBuilder builder(loop); + builder.setInsertionPointAfter(loop); + auto wait = triton::musa::SquadDotWaitOp::create(builder, loop.getLoc(), + ArrayRef{}); + + SmallVector waitOperands; + waitOperands.reserve(properlyAsyncDots.size()); + for (auto [asyncDotOp, iterArgIdx] : properlyAsyncDots) + waitOperands.push_back(loop.getResult(iterArgIdx)); + threadValuesThroughWait(wait, waitOperands, loop); +} + +SmallVector +collectPipelineSqmmaDots(scf::ForOp loop) { + SmallVector dots; + loop.getBody()->walk([&](Operation *op) { + if (auto sqmma = dyn_cast(op)) { + dots.push_back(sqmma); + return WalkResult::advance(); + } + if (isa(op)) + return WalkResult::skip(); + return WalkResult::advance(); + }); + return dots; +} + +} // namespace + +void pipelineSqmma(ModuleOp moduleOp, unsigned numStages) { + SmallVector loops; + moduleOp.walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + for (scf::ForOp loop : loops) { + if (tt::getNumStagesOrDefault(loop, numStages) < 1) + continue; + + SmallVector sqmmaDots = + collectPipelineSqmmaDots(loop); + if (sqmmaDots.empty()) + continue; + + ProperlyAsyncDots properlyAsyncDots; + for (triton::musa::SquadDotOp sqmma : sqmmaDots) { + sqmma->setAttr("isAsync", BoolAttr::get(moduleOp.getContext(), true)); + if (auto iterArgIdx = dotCanBeProperlyAsync(sqmma, loop)) { + properlyAsyncDots[sqmma] = *iterArgIdx; + continue; + } + + auto wait = getOrCreateWaitAfter(sqmma); + (void)threadValuesThroughWait(wait, {sqmma.getResult()}, loop); + } + + insertAsyncSqmmaWaitsInLoop(loop, properlyAsyncDots); + insertFinalWaitAfterLoop(loop, properlyAsyncDots); + } +} + +} // namespace mlir::triton::musa::pipeline diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/SqmmaPipelineUtils.h b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/SqmmaPipelineUtils.h new file mode 100644 index 0000000000..7e954d5f96 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/SqmmaPipelineUtils.h @@ -0,0 +1,12 @@ +#ifndef TRITONMUSA_TRANSFORMS_SQMMA_PIPELINE_UTILS_H +#define TRITONMUSA_TRANSFORMS_SQMMA_PIPELINE_UTILS_H + +#include "mlir/IR/BuiltinOps.h" + +namespace mlir::triton::musa::pipeline { + +void pipelineSqmma(ModuleOp moduleOp, unsigned numStages); + +} // namespace mlir::triton::musa::pipeline + +#endif // TRITONMUSA_TRANSFORMS_SQMMA_PIPELINE_UTILS_H diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp new file mode 100644 index 0000000000..153b4c43f9 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp @@ -0,0 +1,203 @@ +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/BarrierUtils.h" +#include "TritonMUSACommon/MemDescUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" + +#include +#include +#include +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +static Value materializeStaticTMETransactionBytes(Location loc, + ArrayRef shape, + Type elemType, + RewriterBase &rewriter) { + int64_t totalElements = 1; + for (int64_t dim : shape) { + if (dim <= 0) + return {}; + totalElements *= dim; + } + int64_t elemBits = elemType.getIntOrFloatBitWidth(); + if (elemBits <= 0 || (elemBits % 8) != 0) + return {}; + return arith::ConstantIntOp::create(rewriter, loc, + totalElements * (elemBits / 8), 32); +} + +static LogicalResult lowerDescriptorLoad(tt::DescriptorLoadOp op, + RewriterBase &rewriter) { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + auto descBlockTy = descTy.getSignlessBlockType(); + auto memDescTy = triton::musa::resolveDescriptorLoadLandingMemDescType(op); + if (failed(memDescTy)) + return op.emitOpError("descriptor load requires normalized canonical " + "landing memdesc encoding"); + + rewriter.setInsertionPoint(op); + auto coord = + triton::musa::materializeTMECoordValues(loc, op.getIndices(), rewriter); + if (failed(coord)) + return op.emitOpError("unsupported descriptor block rank for TME load"); + + auto alloc = ttg::LocalAllocOp::create(rewriter, loc, *memDescTy); + triton::musa::copyCanonicalLandingSqmmaAttrs(op, alloc.getOperation()); + auto config = triton::musa::resolveFinalTMECopyConfig( + *memDescTy, descBlockTy.getShape(), + triton::musa::TMECopyKind::GlobalToLocal); + if (failed(config)) + return op.emitOpError("unable to resolve final TME load config"); + + Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + auto barId = triton::musa::reserveFreshBarrierId(op); + if (failed(barId)) + return op.emitOpError("exhausted MUSA async barrier ids"); + Value barIdValue = arith::ConstantIntOp::create(rewriter, loc, *barId, 32); + + Value phaseInit = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value arriveCnt = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + Value alwaysIssue = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + Value totalBytes = materializeStaticTMETransactionBytes( + loc, descBlockTy.getShape(), op.getType().getElementType(), rewriter); + if (!totalBytes) + return op.emitOpError("unable to materialize descriptor load transaction " + "bytes"); + + triton::musa::InitArrivalOp::create(rewriter, loc, barIdValue, arriveCnt, + phaseInit); + triton::musa::BarrierAddTransOp::create(rewriter, loc, barIdValue, totalBytes, + alwaysIssue); + triton::musa::createAsyncTMECopyGlobalToLocal( + rewriter, loc, op.getDesc(), *coord, barIdValue, alloc, pred, *config); + triton::musa::ArriveBarrierNoRetOp::create(rewriter, loc, barIdValue, + alwaysIssue); + triton::musa::WaitBarrierOp::create(rewriter, loc, barIdValue, phaseInit); + + SmallVector users(op->getUsers().begin(), op->getUsers().end()); + Value tensorValue = op.getResult(); + Value localLoadValue; + auto getLocalLoadValue = [&]() -> Value { + if (!localLoadValue) { + localLoadValue = + ttg::LocalLoadOp::create(rewriter, loc, op.getType(), alloc) + .getResult(); + } + return localLoadValue; + }; + + for (Operation *user : users) { + bool isTensorViewUser = isa(user); + if (triton::musa::tryReplaceTensorUserWithMemDesc( + rewriter, tensorValue, alloc.getResult(), user) && + !isTensorViewUser) + continue; + if (isTensorViewUser && user->use_empty()) { + rewriter.eraseOp(user); + continue; + } + rewriter.setInsertionPoint(user); + user->replaceUsesOfWith(tensorValue, getLocalLoadValue()); + if (isTensorViewUser && user->use_empty()) + rewriter.eraseOp(user); + } + + rewriter.eraseOp(op); + return success(); +} + +static LogicalResult lowerDescriptorStore(tt::DescriptorStoreOp op, + RewriterBase &rewriter) { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + auto descBlockTy = descTy.getSignlessBlockType(); + auto memDescTy = triton::musa::resolveDescriptorStoreLandingMemDescType( + op, /*mutableMemory=*/false); + if (failed(memDescTy)) + return op.emitOpError("descriptor store requires normalized canonical " + "landing memdesc encoding"); + + rewriter.setInsertionPoint(op); + auto coord = + triton::musa::materializeTMECoordValues(loc, op.getIndices(), rewriter); + if (failed(coord)) + return op.emitOpError("unsupported descriptor block rank for TME store"); + + auto alloc = + ttg::LocalAllocOp::create(rewriter, loc, *memDescTy, op.getSrc()); + auto config = triton::musa::resolveFinalTMECopyConfig( + *memDescTy, descBlockTy.getShape(), + triton::musa::TMECopyKind::LocalToGlobal); + if (failed(config)) + return op.emitOpError("unable to resolve final TME store config"); + + Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + triton::musa::createAsyncTMECopyLocalToGlobal(rewriter, loc, op.getDesc(), + *coord, alloc, pred, *config); + triton::musa::TMEStoreCommitOp::create(rewriter, loc); + triton::musa::TMEStoreReadWaitOp::create(rewriter, loc); + + rewriter.eraseOp(op); + return success(); +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTMELOWERING +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUTMELoweringPass + : impl::TritonMUSAGPUTMELoweringBase { + void runOnOperation() override { + ModuleOp mod = getOperation(); + getContext().getOrLoadDialect(); + IRRewriter rewriter(&getContext()); + + for (tt::FuncOp func : mod.getOps()) { + SmallVector loadOps; + func.walk([&](tt::DescriptorLoadOp op) { loadOps.push_back(op); }); + for (tt::DescriptorLoadOp op : loadOps) { + if (!op->getBlock()) + continue; + if (failed(lowerDescriptorLoad(op, rewriter))) { + signalPassFailure(); + return; + } + } + + SmallVector storeOps; + func.walk([&](tt::DescriptorStoreOp op) { storeOps.push_back(op); }); + for (tt::DescriptorStoreOp op : storeOps) { + if (!op->getBlock()) + continue; + if (failed(lowerDescriptorStore(op, rewriter))) { + signalPassFailure(); + return; + } + } + } + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMEPipelineUtils.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMEPipelineUtils.cpp new file mode 100644 index 0000000000..895bb966a4 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMEPipelineUtils.cpp @@ -0,0 +1,278 @@ +#include "TMEPipelineUtils.h" + +#include "Dialect/MUSA/IR/Dialect.h" +#include "TritonMUSACommon/MemDescUtils.h" +#include "TritonMUSACommon/TMEUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir::triton::musa::pipeline { +namespace { + +static void +allocTMABuffers(scf::ForOp forOp, + llvm::MapVector &tmaBufferMapping, + int maxStage) { + IRRewriter rewriter(forOp); + forOp.walk([&](tt::MakeTensorDescOp op) { + auto loc = op.getLoc(); + Value alloc = ttg::GlobalScratchAllocOp::create( + rewriter, loc, triton::getPointerType(rewriter.getI8Type()), + maxStage * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN); + tmaBufferMapping[op.getOperation()] = alloc; + }); +} + +static Value subviewTMADescriptor(OpBuilder &builder, Location loc, Value alloc, + Value counter) { + Value tmaSizeVal = + arith::ConstantIntOp::create(builder, loc, ttng::TMA_SIZE_BYTES, 32); + Value offset = arith::MulIOp::create(builder, loc, tmaSizeVal, counter); + return tt::AddPtrOp::create(builder, loc, alloc.getType(), alloc, offset); +} + +static LogicalResult rewriteTMABufferUpdates( + scf::ForOp forOp, + const llvm::MapVector &tmaBufferMapping, + ArrayRef tmaCounters, int numBuffers, Value one, Value zero, + tt::CoarseSchedule &schedule) { + assert(tmaBufferMapping.size() == tmaCounters.size()); + + OpBuilder auxBuilder(forOp); + Value numBuffersVal = + arith::ConstantIntOp::create(auxBuilder, forOp.getLoc(), numBuffers, 32); + + for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) { + auto &[op, alloc] = pair; + auto makeDescOp = cast(op); + + tt::OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp, schedule); + BlockArgument counter = tmaCounters[iOp]; + Value nextBuf = + subviewTMADescriptor(builder, builder.getLoc(), alloc, counter); + if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) + return failure(); + ttng::TensormapFenceproxyAcquireOp::create(builder, nextBuf); + Value nextDesc = ttng::ReinterpretTensorDescOp::create( + builder, makeDescOp.getType(), nextBuf); + + makeDescOp.getResult().replaceAllUsesWith(nextDesc); + + Value nextCounter = createIncrementModulo( + builder, builder.getLoc(), counter, numBuffersVal, zero, one); + + IRRewriter rewriter(forOp); + nextCounter = triton::sinkValueRedefinition(rewriter, counter, nextCounter, + op->getBlock()); + + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield.setOperand(counter.getArgNumber() - 1, nextCounter); + makeDescOp.erase(); + } + return success(); +} + +struct TMEStore { + Operation *op; + mlir::TypedValue desc; + mlir::TypedValue src; +}; + +static SmallVector getTMEStores(scf::ForOp forOp) { + SmallVector tmaStores; + forOp.getBody()->walk([&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + tmaStores.push_back({storeOp, storeOp.getDesc(), storeOp.getSrc()}); + } else if (isa(op)) { + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + return tmaStores; +} + +static FailureOr createStoreAlloc(scf::ForOp &forOp, + const TMEStore &store) { + OpBuilder builder(forOp); + auto memDescTy = triton::musa::resolveDescriptorStoreLandingMemDescType( + cast(store.op), + /*mutableMemory=*/true); + if (failed(memDescTy)) { + store.op->emitOpError("pipelined descriptor store requires normalized " + "canonical landing memdesc encoding"); + return failure(); + } + return ttg::LocalAllocOp::create(builder, store.op->getLoc(), *memDescTy) + .getResult(); +} + +static LogicalResult createMUSATMEStoreAsyncCopy(const TMEStore &store, + Value alloc) { + OpBuilder builder(store.op); + Location loc = store.op->getLoc(); + + auto descBlockTy = store.desc.getType().getSignlessBlockType(); + auto coord = triton::musa::materializeTMECoordValues( + loc, cast(store.op).getIndices(), builder); + if (failed(coord)) { + store.op->emitOpError("unable to materialize pipelined TME store block " + "info"); + return failure(); + } + auto issueMemDescTy = triton::musa::resolveDescriptorStoreLandingMemDescType( + cast(store.op), + /*mutableMemory=*/false); + if (failed(issueMemDescTy)) { + store.op->emitOpError("pipelined descriptor store requires normalized " + "immutable landing memdesc encoding"); + return failure(); + } + Value issueAlloc = alloc; + auto allocTy = cast(alloc.getType()); + if (allocTy != *issueMemDescTy) { + issueAlloc = + ttg::MemDescReinterpretOp::create(builder, loc, *issueMemDescTy, alloc); + } + auto config = triton::musa::resolveFinalTMECopyConfig( + *issueMemDescTy, descBlockTy.getShape(), + triton::musa::TMECopyKind::LocalToGlobal); + if (failed(config)) { + store.op->emitOpError("unable to resolve pipelined TME store config"); + return failure(); + } + + Value pred = arith::ConstantIntOp::create(builder, loc, 1, 1); + + triton::musa::TMEStoreReadWaitOp::create(builder, loc); + ttg::LocalStoreOp::create(builder, loc, store.src, alloc); + triton::musa::createAsyncTMECopyLocalToGlobal( + builder, loc, store.desc, *coord, issueAlloc, pred, *config); + triton::musa::TMEStoreCommitOp::create(builder, loc); + + store.op->erase(); + return success(); +} + +static void lowerTMADescriptorCreation(scf::ForOp forOp) { + tt::CoarseSchedule schedule(3); + (void)mlir::triton::musa::pipeline::lowerTMADescriptors(forOp, schedule); +} + +} // namespace + +scf::ForOp lowerTMADescriptors(scf::ForOp forOp, tt::CoarseSchedule &schedule) { + llvm::MapVector tmaBufferMapping; + int maxStage = schedule.getNumStages() - 1; + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(&op)) { + maxStage += 1; + break; + } + } + allocTMABuffers(forOp, tmaBufferMapping, maxStage); + if (tmaBufferMapping.empty()) + return forOp; + + IRRewriter builder(forOp); + Location loc = forOp.getLoc(); + Value zero = arith::ConstantIntOp::create(builder, loc, 0, 32); + Value one = arith::ConstantIntOp::create(builder, loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size(); + for (int i = 0; i < static_cast(tmaBufferMapping.size()); ++i) + newOperands.push_back(zero); + + forOp = addIterArgsToLoop(builder, forOp, newOperands); + + auto tmaCounters = ArrayRef(forOp.getBody()->getArguments()) + .slice(tmaCounterArgsStartIdx); + + auto forYield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) + forYield.getResultsMutable().append(newOperands[i]); + + if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters, + maxStage, one, zero, schedule))) { + llvm::report_fatal_error("Failed to rewrite MUSA TMA descriptor updates"); + } + return forOp; +} + +FailureOr pipelineTMEStores(scf::ForOp forOp) { + SmallVector tmaStores = getTMEStores(forOp); + if (tmaStores.empty()) + return false; + + struct StoreAllocEntry { + ttg::MemDescType memDescTy; + Value alloc; + }; + + DenseMap storeToAlloc; + SmallVector allocs; + for (const TMEStore &store : tmaStores) { + if (!isa(store.op)) { + store.op->emitOpError("pipelined descriptor scatter/reduce is not " + "supported on MUSA"); + return failure(); + } + auto allocMemDescTy = + triton::musa::resolveDescriptorStoreLandingMemDescType( + cast(store.op), + /*mutableMemory=*/true); + if (failed(allocMemDescTy)) { + store.op->emitOpError("pipelined descriptor store requires normalized " + "canonical landing memdesc encoding"); + return failure(); + } + Value alloc; + for (const StoreAllocEntry &entry : allocs) { + if (entry.memDescTy == *allocMemDescTy) { + alloc = entry.alloc; + break; + } + } + if (!alloc) { + auto createdAlloc = createStoreAlloc(forOp, store); + if (failed(createdAlloc)) + return failure(); + alloc = *createdAlloc; + allocs.push_back(StoreAllocEntry{*allocMemDescTy, alloc}); + } + storeToAlloc[store.op] = alloc; + } + + bool hasDeviceSideTMA = llvm::any_of(tmaStores, [](const TMEStore &store) { + return !triton::isHostSideDescriptor(store.desc); + }); + for (const TMEStore &store : tmaStores) { + if (failed(createMUSATMEStoreAsyncCopy(store, storeToAlloc[store.op]))) + return failure(); + } + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + triton::musa::TMEStoreReadWaitOp::create(builder, forOp->getLoc()); + for (const StoreAllocEntry &entry : allocs) + ttg::LocalDeallocOp::create(builder, forOp->getLoc(), entry.alloc); + + if (hasDeviceSideTMA) + lowerTMADescriptorCreation(forOp); + return true; +} + +} // namespace mlir::triton::musa::pipeline diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMEPipelineUtils.h b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMEPipelineUtils.h new file mode 100644 index 0000000000..5db9b9d9d6 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMEPipelineUtils.h @@ -0,0 +1,16 @@ +#ifndef TRITONMUSA_TRANSFORMS_TME_PIPELINE_UTILS_H +#define TRITONMUSA_TRANSFORMS_TME_PIPELINE_UTILS_H + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" + +namespace mlir::triton::musa::pipeline { + +scf::ForOp lowerTMADescriptors(scf::ForOp forOp, + mlir::triton::CoarseSchedule &schedule); +FailureOr pipelineTMEStores(scf::ForOp forOp); + +} // namespace mlir::triton::musa::pipeline + +#endif // TRITONMUSA_TRANSFORMS_TME_PIPELINE_UTILS_H diff --git a/third_party/mthreads/proton/.gitignore b/third_party/mthreads/proton/.gitignore new file mode 100644 index 0000000000..9d7fd1f325 --- /dev/null +++ b/third_party/mthreads/proton/.gitignore @@ -0,0 +1,6 @@ +build/ +proton.egg-info +proton/_C/libproton.so + +*.hatchet +*.chrome_trace diff --git a/third_party/mthreads/proton/CMakeLists.txt b/third_party/mthreads/proton/CMakeLists.txt new file mode 100644 index 0000000000..abd90d689f --- /dev/null +++ b/third_party/mthreads/proton/CMakeLists.txt @@ -0,0 +1,84 @@ +project(Proton LANGUAGES CXX) + +set(PROTON_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}/csrc") +set(PROTON_COMMON_DIR "${CMAKE_CURRENT_SOURCE_DIR}/common") + +# ============ Check for includes ============= +if(NOT CUPTI_INCLUDE_DIR) + message(FATAL_ERROR "CUPTI include directory not defined") +endif() +if(NOT ROCTRACER_INCLUDE_DIR) + message(FATAL_ERROR "ROCTRACER include directory not defined") +endif() +if(NOT JSON_INCLUDE_DIR) + message(FATAL_ERROR "JSON include directory not defined") +endif() + +# ============ Dependencies ============= +find_package(Python3 REQUIRED Interpreter Development.Module) +find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") + +# ============ Define a GLOBAL property to store object-libraries ============ +set_property(GLOBAL PROPERTY PROTON_LIBS "") + +# ============ Define a function to create object libraries ============ +function(add_proton_library name) + add_library(${name} OBJECT ${ARGN}) + + target_link_libraries(${name} PRIVATE Python3::Module pybind11::headers) + + # Use system to skip warnings caused by legacy clang compilers + target_include_directories(${name} + SYSTEM PRIVATE + "${ROCTRACER_INCLUDE_DIR}" + ) + + target_include_directories(${name} + PRIVATE + "${CUPTI_INCLUDE_DIR}" + "${JSON_INCLUDE_DIR}" + "${PROTON_COMMON_DIR}/include" + "${PROTON_SRC_DIR}/include" + ) + + # If HIP is AMD-based + target_compile_definitions(${name} PRIVATE __HIP_PLATFORM_AMD__) + + # Append this library name to the GLOBAL property "PROTON_LIBS" + set_property(GLOBAL APPEND PROPERTY PROTON_LIBS ${name}) +endfunction() + +# ============ Add subdirectory with actual code that calls add_proton_library ============ +add_subdirectory("${PROTON_COMMON_DIR}") +add_subdirectory("${PROTON_SRC_DIR}") + +# ============ Add subdirectory with proton tests ============ +add_subdirectory(test) + +# ============ Possibly handle macOS specifics ============ +if(APPLE) + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + # Other platforms build with -flto, but we found that this adds significant overhead to our macos CI without providing a major benefit. + set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup") +endif() + +# ============ Collect all object libraries from property and build final shared lib ============ +get_property(_proton_obj_libs GLOBAL PROPERTY PROTON_LIBS) + +if(NOT _proton_obj_libs) + message(WARNING "No object libraries were defined in 'PROTON_LIBS'!") +endif() + +set(_proton_obj_sources "") +foreach(_lib IN LISTS _proton_obj_libs) + list(APPEND _proton_obj_sources $) + message(STATUS "Collecting object files from ${_lib}") +endforeach() + +add_library(proton SHARED ${_proton_obj_sources}) + +target_link_libraries(proton PRIVATE Python3::Module) +# Apply any macOS linker flags or extra link options +if(PROTON_PYTHON_LDFLAGS) + target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS}) +endif() diff --git a/third_party/mthreads/proton/Dialect/CMakeLists.txt b/third_party/mthreads/proton/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..9fcd0be2e1 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/CMakeLists.txt @@ -0,0 +1,17 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + set(PROTON_PLUGIN_LINK_LIBS + ProtonToProtonGPU + ProtonGPUToLLVM + ProtonNVIDIAGPUToLLVM + ProtonAnalysis + ) + if(TRITON_ENABLE_AMD) + list(APPEND PROTON_PLUGIN_LINK_LIBS ProtonAMDGPUToLLVM) + endif() + add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ${PROTON_PLUGIN_LINK_LIBS}) + target_link_libraries(TritonProton PRIVATE Python3::Module pybind11::headers) +endif() diff --git a/third_party/mthreads/proton/Dialect/include/Analysis/ScopeIdAllocation.h b/third_party/mthreads/proton/Dialect/include/Analysis/ScopeIdAllocation.h new file mode 100644 index 0000000000..8a0fa34bc6 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Analysis/ScopeIdAllocation.h @@ -0,0 +1,91 @@ +#ifndef PROTON_ANALYSIS_SCOPE_ID_ALLOCATION_H +#define PROTON_ANALYSIS_SCOPE_ID_ALLOCATION_H + +#include "mlir/IR/Operation.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "triton/Analysis/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringMap.h" +#include +#include +#include +#include + +namespace mlir { +namespace triton::proton { + +class ScopeIdAllocation { +public: + using ScopeId = size_t; + // id -> name + using ScopeIdName = std::vector>; + // id -> parent id + using ScopeIdParent = std::vector>; + + ScopeIdAllocation() = default; + explicit ScopeIdAllocation(FunctionOpInterface op) : funcOp(op) { run(); } + + ScopeId getOpScopeId(Operation *op) const { + if (auto recordOp = dyn_cast(op)) { + return opToIdMap.lookup(recordOp); + } + llvm_unreachable("unexpected operation type"); + } + + ScopeIdName getScopeIdNames() const { + ScopeIdName scopeIdNames; + for (const auto &[id, name] : idToNameMap) { + scopeIdNames.push_back({id, name.str()}); + } + return scopeIdNames; + } + + ScopeIdParent getScopeIdParents() const { return scopeParentIds; } + + size_t getNumScopes() const { return idToNameMap.size(); } + +private: + using VirtualBlock = std::pair; + + void run(); + void reachability(); + void liveness(); + void dominance(); + void visitTerminator(Operation *op, SmallVector &successors); + + FunctionOpInterface funcOp; + llvm::DenseMap idToNameMap; + llvm::DenseMap opToIdMap; + ScopeIdParent scopeParentIds; +}; + +class ModuleScopeIdAllocation : public triton::CallGraph { +public: + using FuncOffsetMapT = + llvm::DenseMap; + // Alias for per-function name and parent maps + using ScopeIdNameMap = + llvm::DenseMap; + using ScopeIdParentMap = + llvm::DenseMap; + + explicit ModuleScopeIdAllocation(ModuleOp moduleOp); + + ScopeIdAllocation::ScopeId getOpScopeId(Operation *op) const; + ScopeIdAllocation::ScopeIdName getScopeIdNames(triton::FuncOp funcOp) const; + ScopeIdAllocation::ScopeIdName getScopeIdNames() const; + ScopeIdAllocation::ScopeIdParent + getScopeIdParents(triton::FuncOp funcOp) const; + ScopeIdAllocation::ScopeIdParent getScopeIdParents() const; + +private: + FuncOffsetMapT funcScopeIdMap; + // Precomputed per-function mappings + ScopeIdNameMap scopeIdNames; + ScopeIdParentMap scopeIdParents; +}; + +} // namespace triton::proton +} // namespace mlir + +#endif // PROTON_ANALYSIS_SCOPE_ID_ALLOCATION_H diff --git a/third_party/mthreads/proton/Dialect/include/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/CMakeLists.txt new file mode 100644 index 0000000000..5c6ab30594 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Dialect) +add_subdirectory(Conversion) diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Conversion/CMakeLists.txt new file mode 100644 index 0000000000..13d8b85e52 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(ProtonToProtonGPU) +add_subdirectory(ProtonGPUToLLVM) diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..4860521328 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name ProtonGPUToLLVM) +add_public_tablegen_target(ProtonGPUConversionPassIncGen) + +add_subdirectory(ProtonNvidiaGPUToLLVM) +add_subdirectory(ProtonAMDGPUToLLVM) diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h new file mode 100644 index 0000000000..04cba105af --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h @@ -0,0 +1,28 @@ +#ifndef PROTONGPU_TO_LLVM_PASSES_H +#define PROTONGPU_TO_LLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; + +namespace triton::proton::gpu { + +#define GEN_PASS_DECL +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h.inc" + +std::unique_ptr> createAddSchedBarriersPass(); + +#define GEN_PASS_REGISTRATION +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h.inc" + +} // namespace triton::proton::gpu + +} // namespace mlir + +#endif // PROTONGPU_TO_LLVM_PASSES_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.td b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.td new file mode 100644 index 0000000000..800694b1be --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.td @@ -0,0 +1,34 @@ +#ifndef PROTONGPU_TO_LLVM_PASSES +#define PROTONGPU_TO_LLVM_PASSES + +include "mlir/Pass/PassBase.td" + +def AllocateProtonSharedMemoryPass : Pass<"allocate-proton-shared-memory", "mlir::ModuleOp"> { + let summary = "Update metadata for proton shared memory allocation"; + let description = [{ + This pass updates the amount of shared/local memory used by + proton intra kernel profiling. + }]; + + let dependentDialects = ["ProtonDialect", + "gpu::ProtonGPUDialect"]; +} + +def AllocateProtonGlobalScratchBufferPass : Pass<"allocate-proton-global-scratch-buffer", "mlir::ModuleOp"> { + let summary = "Update metadata for proton global scratch buffer allocation"; + let description = [{ + This pass updates the amount of global memory used by + proton intra kernel profiling. + }]; + + let dependentDialects = ["ProtonDialect", + "gpu::ProtonGPUDialect"]; +} + +def AddSchedBarriers : Pass<"add-sched-barriers", "mlir::ModuleOp"> { + let constructor = "mlir::triton::proton::gpu::createAddSchedBarriersPass()"; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect"]; +} + +#endif // PROTONGPU_TO_LLVM_PASSES diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.h new file mode 100644 index 0000000000..869d4864bc --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.h @@ -0,0 +1,22 @@ +#ifndef PROTONGPU_TO_LLVM_PATTERN_PROTONGPUOP_TO_LLVM_H +#define PROTONGPU_TO_LLVM_PATTERN_PROTONGPUOP_TO_LLVM_H + +#include "Conversion/ProtonGPUToLLVM/TargetInfoBase.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::triton { +namespace proton::gpu { + +void populateProtonGPUOpPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateTypeConversions(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo); + +} // namespace proton::gpu +} // namespace mlir::triton + +#endif // PROTONGPU_TO_LLVM_PATTERN_PROTONGPUOP_TO_LLVM_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.h new file mode 100644 index 0000000000..9e0304dbe5 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.h @@ -0,0 +1,20 @@ +#ifndef PROTONGPU_TO_LLVM_AMD_PATTERN_PROTONGPUOP_TO_LLVM_H +#define PROTONGPU_TO_LLVM_AMD_PATTERN_PROTONGPUOP_TO_LLVM_H + +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace mlir::triton { +namespace proton::gpu { +namespace AMD { + +void populateProtonGPUOpAMDPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + +} // namespace AMD +} // namespace proton::gpu +} // namespace mlir::triton + +#endif // PROTONGPU_TO_LLVM_AMD_PATTERN_PROTONGPUOP_TO_LLVM_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..b3b7f514b5 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name ProtonAMDGPUToLLVM) +add_public_tablegen_target(ProtonAMDGPUConversionPassIncGen) diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h new file mode 100644 index 0000000000..3cd281b5c9 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h @@ -0,0 +1,30 @@ +#ifndef PROTONGPU_TO_LLVM_PROTONAMDGPU_TO_LLVM_PASSES_H +#define PROTONGPU_TO_LLVM_PROTONAMDGPU_TO_LLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton::proton::gpu { + +#define GEN_PASS_DECL +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h.inc" + +std::unique_ptr> +createConvertProtonAMDGPUToLLVMPass(std::string arch = ""); + +#define GEN_PASS_REGISTRATION +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h.inc" + +} // namespace triton::proton::gpu + +} // namespace mlir + +#endif // PROTONGPU_TO_LLVM_PROTONAMDGPU_TO_LLVM_PASSES_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.td b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.td new file mode 100644 index 0000000000..632d424dd5 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.td @@ -0,0 +1,31 @@ +#ifndef PROTONAMDGPU_TO_LLVM_PASSES +#define PROTONAMDGPU_TO_LLVM_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertProtonAMDGPUToLLVM : Pass<"convert-proton-amd-gpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert ProtonGPU to LLVM"; + let description = [{ + Convert ProtonGPU to LLVM using AMD-specific lowering patterns. + }]; + let constructor = "mlir::triton::proton::gpu::createConvertProtonAMDGPUToLLVMPass(\"\")"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect", + "mlir::triton::proton::ProtonDialect", + "mlir::triton::proton::gpu::ProtonGPUDialect"]; + + let options = [ + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942"> + ]; +} + +#endif // PROTONAMDGPU_TO_LLVM_PASSES diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h new file mode 100644 index 0000000000..4ebf97d4fa --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h @@ -0,0 +1,40 @@ +#ifndef PROTONGPU_TO_LLVM_TARGETINFO_AMD_H +#define PROTONGPU_TO_LLVM_TARGETINFO_AMD_H + +#include "Conversion/ProtonGPUToLLVM/TargetInfoBase.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" // TODO(fywkevin): move amd TargetInfo.h to include/ +#include + +namespace mlir::triton::proton::gpu::AMD { +class TargetInfo : public mlir::triton::proton::gpu::TargetInfoBase { +public: + explicit TargetInfo(const mlir::triton::AMD::TargetInfo &helper, + std::string arch) + : mlir::triton::proton::gpu::TargetInfoBase(helper), + arch(std::move(arch)) {} + + const mlir::triton::AMD::TargetInfo &getTritonTargetInfo() const override { + return static_cast(helper); + } + + Value clock(ConversionPatternRewriter &rewriter, Location loc, + bool isClock64) const override; + + Value globalTime(ConversionPatternRewriter &rewriter, + Location loc) const override; + + Value processorId(ConversionPatternRewriter &rewriter, + Location loc) const override; + + int getAddressSpace(Attribute addressSpace) const override; + + int getIndexPtrAddrSpace() const override; + + ~TargetInfo() = default; + +private: + std::string arch; +}; +} // namespace mlir::triton::proton::gpu::AMD + +#endif // PROTONGPU_TO_LLVM_TARGETINFO_AMD_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..baca67454b --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name ProtonNvidiaGPUToLLVM) +add_public_tablegen_target(ProtonNvidiaGPUConversionPassIncGen) diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.h new file mode 100644 index 0000000000..083d344b79 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.h @@ -0,0 +1,20 @@ +#ifndef PROTONGPU_TO_LLVM_NVIDIA_PATTERN_PROTONGPUOP_TO_LLVM_H +#define PROTONGPU_TO_LLVM_NVIDIA_PATTERN_PROTONGPUOP_TO_LLVM_H + +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace mlir::triton { +namespace proton::gpu { +namespace NVIDIA { + +void populateProtonGPUOpNvidiaPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + +} // namespace NVIDIA +} // namespace proton::gpu +} // namespace mlir::triton + +#endif // PROTONGPU_TO_LLVM_NVIDIA_PATTERN_PROTONGPUOP_TO_LLVM_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h new file mode 100644 index 0000000000..0b891448d1 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h @@ -0,0 +1,31 @@ +#ifndef PROTONGPU_TO_LLVM_PROTONNVIDIAGPU_TO_LLVM_PASSES_H +#define PROTONGPU_TO_LLVM_PROTONNVIDIAGPU_TO_LLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton::proton::gpu { + +#define GEN_PASS_DECL +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h.inc" + +std::unique_ptr> +createConvertProtonNvidiaGPUToLLVMPass(int32_t computeCapability = 80, + int32_t ptxVersion = 80); + +#define GEN_PASS_REGISTRATION +#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h.inc" + +} // namespace triton::proton::gpu + +} // namespace mlir + +#endif // PROTONGPU_TO_LLVM_PROTONNVIDIAGPU_TO_LLVM_PASSES_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.td b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.td new file mode 100644 index 0000000000..21f055f1ce --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.td @@ -0,0 +1,34 @@ +#ifndef PROTONNVIDIAGPU_TO_LLVM_PASSES +#define PROTONNVIDIAGPU_TO_LLVM_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertProtonNvidiaGPUToLLVM : Pass<"convert-proton-nvidia-gpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert ProtonGPU to LLVM"; + let description = [{ + Convert ProtonGPU to LLVM using Nvidia-specific lowering patterns. + }]; + let constructor = "mlir::triton::proton::gpu::createConvertProtonNvidiaGPUToLLVMPass(80, 80)"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::NVVM::NVVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::proton::ProtonDialect", + "mlir::triton::proton::gpu::ProtonGPUDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability">, + Option<"ptxVersion", "ptx-version", + "int32_t", /*default*/"80", + "PTX version">, + ]; +} + +#endif // PROTONNVIDIAGPU_TO_LLVM_PASSES diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.h new file mode 100644 index 0000000000..277445911a --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.h @@ -0,0 +1,34 @@ +#ifndef PROTONGPU_TO_LLVM_TARGETINFO_NVIDIA_H +#define PROTONGPU_TO_LLVM_TARGETINFO_NVIDIA_H + +#include "Conversion/ProtonGPUToLLVM/TargetInfoBase.h" +#include "compat/TargetInfo.h" + +namespace mlir::triton::proton::gpu::NVIDIA { +class TargetInfo : public mlir::triton::proton::gpu::TargetInfoBase { +public: + explicit TargetInfo(const mlir::triton::NVIDIA::TargetInfo &helper) + : mlir::triton::proton::gpu::TargetInfoBase(helper) {} + + const mlir::triton::NVIDIA::TargetInfo &getTritonTargetInfo() const override { + return static_cast(helper); + } + + Value clock(ConversionPatternRewriter &rewriter, Location loc, + bool isClock64) const override; + + Value globalTime(ConversionPatternRewriter &rewriter, + Location loc) const override; + + Value processorId(ConversionPatternRewriter &rewriter, + Location loc) const override; + + int getAddressSpace(Attribute addressSpace) const override; + + int getIndexPtrAddrSpace() const override; + + ~TargetInfo() {} +}; +} // namespace mlir::triton::proton::gpu::NVIDIA + +#endif // PROTONGPU_TO_LLVM_TARGETINFO_NVIDIA_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/TargetInfoBase.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/TargetInfoBase.h new file mode 100644 index 0000000000..b915575b53 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/TargetInfoBase.h @@ -0,0 +1,43 @@ +#ifndef PROTONGPU_TO_LLVM_TARGETINFO_BASE_H +#define PROTONGPU_TO_LLVM_TARGETINFO_BASE_H + +#include "mlir/IR/Attributes.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir::triton::proton::gpu { + +class TargetInfoBase { +public: + explicit TargetInfoBase(const mlir::triton::TargetInfoBase &helper) + : helper(helper) {} + + virtual const mlir::triton::TargetInfoBase &getTritonTargetInfo() const { + return helper; + } + + // Return the local cycle counter value. + virtual Value clock(ConversionPatternRewriter &rewriter, Location loc, + bool isClock64) const = 0; + + // Return the global cycle counter value (i.e., synchronized across SMs) in + // nanoseconds, regardless of the clock frequency. + virtual Value globalTime(ConversionPatternRewriter &rewriter, + Location loc) const = 0; + + virtual Value processorId(ConversionPatternRewriter &rewriter, + Location loc) const = 0; + + virtual int getAddressSpace(Attribute addressSpace) const = 0; + + virtual int getIndexPtrAddrSpace() const = 0; + + virtual ~TargetInfoBase() = default; + +protected: + const mlir::triton::TargetInfoBase &helper; +}; +} // namespace mlir::triton::proton::gpu + +#endif // PROTONGPU_TO_LLVM_TARGETINFO_BASE_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Utility.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Utility.h new file mode 100644 index 0000000000..b459f0511f --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonGPUToLLVM/Utility.h @@ -0,0 +1,55 @@ +#ifndef PROTONGPU_TO_LLVM_UTILITY_H +#define PROTONGPU_TO_LLVM_UTILITY_H + +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +Value getRawThreadId(OpBuilder &rewriter, Location loc); + +namespace LLVM { + +struct SegmentObject { + Value base; + Value segmentBase; + Value indexPtr; + + SegmentObject(Value base, Value segmentBase, Value indexPtr) + : base(base), segmentBase(segmentBase), indexPtr(indexPtr) {} + + Value getStruct(Location loc, ConversionPatternRewriter &rewriter); + + static LLVMStructType getStructType(MLIRContext *ctx, int memorySpace, + int indexPtrAddrSpace); + + static SegmentObject fromStruct(Location loc, Value segmentStruct, + ConversionPatternRewriter &rewriter); +}; + +} // namespace LLVM + +namespace triton { +namespace proton::gpu { + +struct CircularStoreDataPack { + Value isWriter; + Value record; + Value ptr; + uint32_t addrSpace; +}; + +CircularStoreDataPack +lowerCircularStoreOpHelper(CircularStoreOp op, Value segmentStruct, + ConversionPatternRewriter &rewriter); + +SmallVector getTritonFunctions(ModuleOp mod); + +} // namespace proton::gpu +} // namespace triton + +} // namespace mlir + +#endif // PROTONGPU_TO_LLVM_UTILITY_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/CMakeLists.txt new file mode 100644 index 0000000000..e58639d317 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name ProtonToProtonGPU) +add_public_tablegen_target(ProtonToProtonGPUIncGen) diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h new file mode 100644 index 0000000000..14bcab1885 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h @@ -0,0 +1,31 @@ +#ifndef PROTON_TO_PROTONGPU_PASSES_H +#define PROTON_TO_PROTONGPU_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h" + +namespace mlir::triton::proton { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h.inc" + +std::unique_ptr> createConvertProtonToProtonGPUPass( + MetricType metricType = MetricType::CYCLE, + SamplingStrategy samplingStrategy = SamplingStrategy::NONE, + llvm::StringRef samplingOptions = "", + gpu::Granularity granularity = gpu::Granularity::WARP, + gpu::BufferStrategy bufferStrategy = gpu::BufferStrategy::CIRCULAR, + gpu::BufferType bufferType = gpu::BufferType::SHARED, + int32_t bufferSize = 0, int32_t maxSharedMemSize = 32768, + int64_t profileScratchSize = 32768, int32_t profileScratchAlignment = 128, + bool clkExt = false); + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h.inc" + +} // namespace mlir::triton::proton + +#endif // PROTON_TO_PROTONGPU_PASSES_H diff --git a/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.td b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.td new file mode 100644 index 0000000000..bf5b22b1cf --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.td @@ -0,0 +1,80 @@ +#ifndef PROTON_TO_PROTONGPU_PASSES +#define PROTON_TO_PROTONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertProtonToProtonGPU: Pass<"convert-proton-to-protongpu", "mlir::ModuleOp"> { + let summary = "Lowering pass of ProtonIR to ProtonGPU IR"; + + let description = "Convert the Proton Op into ProtonGPU Op. This includes scaffolding operations" + "such as allocation for internal profiling buffers, resources binding, and final cleanup."; + + let constructor = "createConvertProtonToProtonGPUPass()"; + + let dependentDialects = ["ProtonDialect", + "gpu::ProtonGPUDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"metricType", "metric-type", + "MetricType", /*default*/"MetricType::CYCLE", + "The performance counter metric type we are profiling", + /*parser*/[{::llvm::cl::values( + clEnumValN(MetricType::CYCLE, "cycle", "Cycle") + )}]>, + Option<"granularity", "granularity", + "gpu::Granularity", /*default*/"gpu::Granularity::WARP", + "Profiling granularity: warp, warp_group, or cta", + /*parser*/[{::llvm::cl::values( + clEnumValN(gpu::Granularity::THREAD, "thread", "Thread"), + clEnumValN(gpu::Granularity::WARP, "warp", "Warp"), + clEnumValN(gpu::Granularity::WARP_2, "warp-2", "2 Warps"), + clEnumValN(gpu::Granularity::WARP_4, "warp-4", "4 Warps"), + clEnumValN(gpu::Granularity::WARP_8, "warp-8", "8 Warps"), + clEnumValN(gpu::Granularity::CTA, "cta", "CTA"), + clEnumValN(gpu::Granularity::WARP_GROUP, "warp-group", "Warp Group"), + clEnumValN(gpu::Granularity::WARP_GROUP_2, "warp-group-2", "2 Warp Groups"), + clEnumValN(gpu::Granularity::WARP_GROUP_4, "warp-group-4", "4 Warp Groups"), + clEnumValN(gpu::Granularity::WARP_GROUP_8, "warp-group-8", "8 Warp Groups") + )}]>, + Option<"samplingStrategy", "sampling-strategy", + "SamplingStrategy", /*default*/"SamplingStrategy::NONE", + "Profiling sampling strategy", + /*parser*/[{::llvm::cl::values( + clEnumValN(SamplingStrategy::NONE, "none", "No Sampling"), + clEnumValN(SamplingStrategy::SELECTIVE, "selective", "Selective Sampling") + )}]>, + Option<"samplingOptions", "sampling-options", + "std::string", /*default*/"\"\"", + "Profiling sampling options">, + Option<"bufferStrategy", "buffer-strategy", "gpu::BufferStrategy", /*default*/"gpu::BufferStrategy::CIRCULAR", + "Profiler buffer recording strategy (circular or flush)", + /*parser*/[{::llvm::cl::values( + clEnumValN(gpu::BufferStrategy::CIRCULAR, "circular", "Circular Buffer"), + clEnumValN(gpu::BufferStrategy::FLUSH, "flush", "Flush Buffer") + )}]>, + Option<"bufferType", "buffer-type", "gpu::BufferType", /*default*/"gpu::BufferType::SHARED", + "Internal buffer type (SHARED, GLOBAL) that stores the profiling data", + /*parser*/[{::llvm::cl::values( + clEnumValN(gpu::BufferType::SHARED, "shared", "Shared Memory"), + clEnumValN(gpu::BufferType::GLOBAL, "global", "Global Memory") + )}]>, + Option<"bufferSize", "buffer-size", "int32_t", /*default*/"0", + "Internal buffer byte size that stores the profiling data. 0 means auto-size based on the device's `maxSharedMemSize`">, + Option<"maxSharedMemSize", "max-shared-mem-size", + "int32_t", /*default*/"32768", + "Maximum available shared memory size per CTA">, + Option<"profileScratchSize", "scratch-mem-size", + "int64_t", /*default*/"32768", + "Profiler global scratch memory size per CTA">, + Option<"profileScratchAlignment", "scratch-mem-alignment", + "int32_t", /*default*/"128", + "Profiler global scratch memory alignment">, + Option<"clockExtension", "clock-extension", + "bool", /*default*/"false", + "Use long clock if true, otherwise use 32-bit clock">, + ]; +} + +#endif diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..adc6d78575 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Proton) +add_subdirectory(ProtonGPU) diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/Proton/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 0000000000..9cb56c8d7b --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton) +add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc) +add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc) +add_public_tablegen_target(ProtonTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td) +mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonAttrDefs ProtonAttrDefs dialects/ -gen-attrdef-doc) +add_public_tablegen_target(ProtonAttrDefsIncGen) diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/Dialect.h b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/Dialect.h new file mode 100644 index 0000000000..cfa4e771d8 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/Dialect.h @@ -0,0 +1,14 @@ +#ifndef DIALECT_PROTON_IR_DIALECT_H_ +#define DIALECT_PROTON_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h.inc" +#include "proton/Dialect/include/Dialect/Proton/IR/OpsEnums.h.inc" + +#define GET_OP_CLASSES +#include "proton/Dialect/include/Dialect/Proton/IR/Ops.h.inc" + +#endif // DIALECT_PROTON_IR_DIALECT_H_ diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td new file mode 100644 index 0000000000..7c55136dfe --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td @@ -0,0 +1,46 @@ +#ifndef PROTON_ATTR_DEFS +#define PROTON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +def MetricTypeAttr : I32EnumAttr< + "MetricType", "The type of metric to be profiled", + [ + I32EnumAttrCase<"CYCLE", 0, "cycle">, + ]> { + let cppNamespace = "::mlir::triton::proton"; + let description = [{ + Attribute to indicate the metric to be profiled. + The following metrics are supported: + - CYCLE: Cycle count metric. + }]; +} + +def SamplingStrategyAttr : I32EnumAttr< + "SamplingStrategy", "The strategy for sampling the profiling data", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"SELECTIVE", 1, "selective">, + ]> { + let cppNamespace = "::mlir::triton::proton"; + let description = [{ + Attribute to indicate the sampling strategy for profiling. + The following sampling strategies are supported: + - NONE: No sampling. + - SELECTIVE: Manually select a couple of instances to profile. + }]; +} + +def ModeAttr : I32EnumAttr< + "Mode", "The mode of profiling", + [ + I32EnumAttrCase<"DEFAULT", 0, "default">, + I32EnumAttrCase<"MMA", 1, "mma">, + ]> { + let cppNamespace = "::mlir::triton::proton"; + let description = [{ + Attribute to indicate the mode of profiling, which specifies passes and instructions to monitor. + }]; +} + +#endif // PROTON_ATTR_DEFS diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonDialect.td b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonDialect.td new file mode 100644 index 0000000000..8d35a5ec10 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonDialect.td @@ -0,0 +1,20 @@ +#ifndef PROTON_DIALECT +#define PROTON_DIALECT + +include "mlir/IR/OpBase.td" + +def Proton_Dialect : Dialect { + let name = "proton"; + let cppNamespace = "::mlir::triton::proton"; + + let description = [{ + Proton Dialect provides core ops for building third-party compiler-based + performance profiling and analysis tools. + }]; + + let dependentDialects = []; + + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonOps.td b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonOps.td new file mode 100644 index 0000000000..265ce96f07 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/Proton/IR/ProtonOps.td @@ -0,0 +1,45 @@ +#ifndef PROTON_OPS +#define PROTON_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "proton/Dialect/include/Dialect/Proton/IR/ProtonDialect.td" +include "proton/Dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td" + +class PT_Op traits = []> : + Op { +} + +def PT_RecordOp : PT_Op<"record", [ + MemoryEffects<[MemRead, MemWrite]> +]> { + let summary = "Record an event"; + + let description = [{ + This operation annotates a region of IR where events are recorded. + Events can be classified as hardware or software events. + Hardware events are provided by the hardware performance counters obtained in later passes that convert Triton to target-specific IR. + Software events are provided by the user or the compiler. + + Example: + + ```mlir + proton.record start "name0" + ... + proton.record end "name0" + ``` + + Scope names cannot be reused within the same function. + }]; + let arguments = ( + ins UnitAttr: $isStart, + StrAttr: $name + ); + + let assemblyFormat = "(`start` $isStart^):(`end`)? $name attr-dict"; +} + +#endif // PROTON_OPS diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..05222e610d --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/CMakeLists.txt @@ -0,0 +1,23 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonGPUOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton_gpu) +add_mlir_doc(ProtonGPUOps ProtonGPUOps dialects/ -gen-op-doc) +add_mlir_doc(ProtonGPUDialect ProtonGPUDialect dialects/ -gen-dialect-doc) +add_public_tablegen_target(ProtonGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonGPUAttrDefs.td) +mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonGPUAttrDefs ProtonGPUAttrDefs dialects/ -gen-attrdef-doc) +add_public_tablegen_target(ProtonGPUAttrDefsIncGen) + +set(LLVM_TARGET_DEFINITIONS ProtonGPUTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(ProtonGPUTypesIncGen) diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h new file mode 100644 index 0000000000..ed45cbedf4 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h @@ -0,0 +1,35 @@ +#ifndef DIALECT_PROTONGPU_IR_DIALECT_H_ +#define DIALECT_PROTONGPU_IR_DIALECT_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h.inc" +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Ops.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/AttrDefs.h.inc" + +namespace mlir { +namespace triton { +namespace proton { +namespace gpu { + +const int getBytesPerClockEntry(); + +const int getCircularHeaderSize(); + +const int getTotalNumWarps(ModuleOp mod); + +} // namespace gpu +} // namespace proton +} // namespace triton +} // namespace mlir + +#endif // DIALECT_PROTONGPU_IR_DIALECT_H_ diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUAttrDefs.td b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUAttrDefs.td new file mode 100644 index 0000000000..b6f759a309 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUAttrDefs.td @@ -0,0 +1,71 @@ +#ifndef PROTONGPU_ATTR_DEFS +#define PROTONGPU_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" +include "proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUDialect.td" + +def GranularityAttr : I32EnumAttr< + "Granularity", "The granularity of the profiling metric", + [ + I32EnumAttrCase<"THREAD", 0, "thread">, + I32EnumAttrCase<"WARP", 1, "warp">, + I32EnumAttrCase<"WARP_2", 2, "warp_2">, + I32EnumAttrCase<"WARP_4", 3, "warp_4">, + I32EnumAttrCase<"WARP_8", 4, "warp_8">, + I32EnumAttrCase<"CTA", 5, "cta">, + I32EnumAttrCase<"WARP_GROUP", 6, "warp_group">, + I32EnumAttrCase<"WARP_GROUP_2", 7, "warp_group_2">, + I32EnumAttrCase<"WARP_GROUP_4", 8, "warp_group_4">, + I32EnumAttrCase<"WARP_GROUP_8", 9, "warp_group_8">, + ]> { + let cppNamespace = "::mlir::triton::proton::gpu"; + let description = [{ + The granularity can be per CTA, per warp, or per warp group. + The following granularity levels are supported: + - THREAD: Metrics are recorded per thread. + - CTA: Metrics are recorded per CTA. + - WARP: Metrics are recorded per warp. + - WARP_2, WARP_4, WARP_8: Metrics are recorded for every 2, 4, or 8 warps, respectively. + - WARP_GROUP: Metrics are recorded per warp group. + - WARP_GROUP_2, WARP_GROUP_4, WARP_GROUP_8: Metrics are recorded for every 2, 4, or 8 warp groups, respectively. + }]; +} + +def BufferStrategyAttr : I32EnumAttr< + "BufferStrategy", "The strategy for buffer management", + [ + I32EnumAttrCase<"CIRCULAR", 0, "circular">, + I32EnumAttrCase<"FLUSH", 1, "flush">, + ]> { + let cppNamespace = "::mlir::triton::proton::gpu"; + let description = [{ + The following buffer management strategies are supported: + - CIRCULAR: Circular buffer management strategy. Out of space is handled by overwriting the oldest data. + - FLUSH: Flush buffer management strategy. Once the GPU buffer is full, data is flushed to the host. + }]; +} + +def BufferTypeAttr : I32EnumAttr< + "BufferType", "The type of internal buffer to be used", + [ + I32EnumAttrCase<"SHARED", 1, "shared">, + I32EnumAttrCase<"GLOBAL", 2, "global">, + ]> { + let cppNamespace = "::mlir::triton::proton::gpu"; + let description = [{ + The following buffer types are supported: + - SHARED: Shared memory buffer type. + - GLOBAL: Profiling data get stored directly in global memory, but may be cached in L2/L1. + }]; +} + +def PTG_GlobalMemorySpace : AttrDef { + let cppNamespace = "::mlir::triton::proton::gpu"; + let mnemonic = "global_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to global memory. + }]; +} + +#endif // PROTONGPU_ATTR_DEFS diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUDialect.td b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUDialect.td new file mode 100644 index 0000000000..9fe146f5e7 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUDialect.td @@ -0,0 +1,28 @@ +#ifndef PROTONGPU_DIALECT +#define PROTONGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def ProtonGPU_Dialect : Dialect { + let name = "proton_gpu"; + let cppNamespace = "::mlir::triton::proton::gpu"; + + let description = [{ + Proton GPU dialect. + }]; + + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + "triton::proton::ProtonDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif // PROTONGPU_DIALECT diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUOps.td b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUOps.td new file mode 100644 index 0000000000..a25755aff5 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUOps.td @@ -0,0 +1,204 @@ +#ifndef PROTONGPU_OPS +#define PROTONGPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "proton/Dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td" +include "proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUDialect.td" +include "proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUAttrDefs.td" +include "proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUTypes.td" + +//===----------------------------------------------------------------------===// +// Resources +//===----------------------------------------------------------------------===// + +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +//===----------------------------------------------------------------------===// +// Base Class +//===----------------------------------------------------------------------===// + +class PTG_Op traits = []> : + Op { +} + +//===----------------------------------------------------------------------===// +// ProtonGPU Operations +//===----------------------------------------------------------------------===// + +def PTG_CircularStoreOp : PTG_Op<"circular_store", [ + MemoryEffects<[MemRead, MemWrite]> +]> { + let summary = "Store the value into a circular buffer"; + + let description = [{ + Store a metric `counter` into a circular buffer backed by the internal memory `segment`. + automatically updated. Older metric counters are dropped if the `segment` buffer is full. + }]; + + let arguments = (ins + PTG_SegmentType:$segment, + AnyTypeOf<[I32, I64]>:$counter, + UnitAttr:$isStart, + I32Attr:$scopeId + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + (`start` $isStart^):(`end`)? $segment `,` $counter attr-dict `:` + qualified(type($segment)) `,` type($counter) + }]; +} + +def PTG_ReadCounterOp : PTG_Op<"read_counter", [ + MemoryEffects<[MemRead, MemWrite]> +]> { + let summary = "Read a GPU metric counter into a scalar register"; + + let description = [{ + Read a GPU metric counter into a scalar register. + }]; + + let arguments = (ins + DefaultValuedAttr:$metric + ); + + let results = (outs AnyTypeOf<[I32, I64]>:$counter); + + let assemblyFormat = [{ + attr-dict `:` type($counter) + }]; +} + +def PTG_InitializeOp : PTG_Op<"initialize", [ + MemoryEffects<[MemWrite]> +]> { + let summary = "Initialize the intra kernel profiler"; + + let description = [{ + Initialize the intra kernel profiler by filling the auxiliary metadata to the header. + `scratchPtr` is the base address of the profiling scratch buffer where the header is stored. + }]; + + let arguments = (ins + TT_Ptr:$scratchPtr + ); + + let assemblyFormat = "$scratchPtr attr-dict `:` qualified(type($scratchPtr))"; +} + + +def PTG_FinalizeOp : PTG_Op<"finalize", [ + MemoryEffects<[MemRead]>, // FIXME: it shouldn't always have shared memory effects + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]> +]> { + let summary = "Finalize the intra kernel profiler"; + + let description = [{ + Write back the metadata and profile to global memory. + `segment` is the segment of the internal profiling buffer that contains the profiling data. + `scratchPtr` is the address of the profiling scratch buffer. + }]; + + let arguments = (ins + PTG_SegmentType:$segment, + TT_Ptr:$scratchPtr + ); + + let assemblyFormat = [{ + $segment `,` $scratchPtr attr-dict `:` qualified(type($segment)) `,` qualified(type($scratchPtr)) + }]; +} + +def PTG_SegmentAllocOp : PTG_Op<"segment_alloc", [Pure]> { + let summary = "Get the base offset of the segment of the internal buffer"; + + let description = [{ + The internal buffer is partitioned into segments for each profiling "unit". + This operation gets the location of the memory segment in the internal buffer. + }]; + + let arguments = (ins + AnyTypeOf<[TTG_MemDescType, TT_Ptr]>:$buffer + ); + + let results = (outs PTG_SegmentType:$segment); + + let hasVerifier = 1; + + let assemblyFormat = "$buffer attr-dict `:` qualified(type($buffer)) `->` type($segment)"; +} + +def PTG_InitCtxOp : PTG_Op<"init_ctx", [ + MemoryEffects<[MemWrite]> +]> { + let summary = "Initialize the intra kernel profiler warp-level contexts"; + + let description = [{ + Initialize the intra kernel profiler warp-level contexts for all warps in + `scratchPtr` (base address of the profiling scratch buffer). It can't be + called inside `ttg.warp_specialize`. + }]; + + let arguments = (ins + TT_Ptr:$scratchPtr + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scratchPtr attr-dict `:` qualified(type($scratchPtr)) + }]; +} + +def PTG_RestoreCtxOp : PTG_Op<"restore_ctx", [ + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]> +]> { + let summary = "Restore the current warp-level context"; + + let description = [{ + Restore the current warp context in `$segment` from + `scratchPtr` (base address of the profiling scratch buffer). + }]; + + let arguments = (ins + PTG_SegmentType:$segment, + TT_Ptr:$scratchPtr + ); + + let assemblyFormat = [{ + $segment `,` $scratchPtr attr-dict `:` qualified(type($segment)) `,` qualified(type($scratchPtr)) + }]; +} + +def PTG_SaveCtxOp : PTG_Op<"save_ctx", [ + MemoryEffects<[MemWrite]> +]> { + let summary = "Save the current warp-level context"; + + let description = [{ + Save the current warp context from `$segment` to + `scratchPtr` (base address of the profiling scratch buffer). + }]; + + let arguments = (ins + PTG_SegmentType:$segment, + TT_Ptr:$scratchPtr + ); + + let assemblyFormat = [{ + $segment `,` $scratchPtr attr-dict `:` qualified(type($segment)) `,` qualified(type($scratchPtr)) + }]; +} + +#endif // PROTONGPU_OPS diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUTypes.td b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUTypes.td new file mode 100644 index 0000000000..d750239628 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUTypes.td @@ -0,0 +1,43 @@ +#ifndef PROTONGPU_TYPES +#define PROTONGPU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUDialect.td" +include "proton/Dialect/include/Dialect/ProtonGPU/IR/ProtonGPUAttrDefs.td" + +class PTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def PTG_SegmentType : PTG_TypeDef<"Segment", "segment", []> { + let summary = "A segment in the internal buffer"; + let description = [{ + The `proton_gpu.segment` type represents a segment returned by `PTG_SegmentOp`. + + Each segment is private to a profiling unit as defined by the `granularity` attribute. + The selected segments, specified by the `selectIds` attribute, collectively total `nBytes` bytes. + + When lowered to LLVM, a segment becomes a struct containing: + - `base`: pointer to the start of the internal buffer + - `segmentBase`: pointer to each segment's start in the internal buffer + - `indexPtr`: pointer to the current index within the segment + + The segment can reside in global memory or shared memory depending on the `memorySpace` attribute. + }]; + + let parameters = (ins + "int32_t":$nBytes, + "Attribute":$memorySpace, + EnumParameter:$granularity, + OptionalArrayRefParameter<"int32_t">:$selectIds + ); + + let assemblyFormat = [{ + `<` $nBytes `,` $memorySpace `,` $granularity (`,` `[` $selectIds^ `]`)? `>` + }]; +} + +#endif diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/Types.h b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/Types.h new file mode 100644 index 0000000000..0073b17df2 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/IR/Types.h @@ -0,0 +1,15 @@ +#ifndef PROTONGPU_IR_TYPES_H_ +#define PROTONGPU_IR_TYPES_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/OpsEnums.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Types.h.inc" + +#endif // PROTONGPU_IR_TYPES_H_ diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/CMakeLists.txt b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..3b2c8e1560 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name ProtonGPU) +add_public_tablegen_target(ProtonGPUTransformsIncGen) diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h new file mode 100644 index 0000000000..59ece41aaf --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h @@ -0,0 +1,17 @@ +#ifndef PROTONGPU_TRANSFORMS_PASSES_H_ +#define PROTONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir::triton::proton::gpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h.inc" + +} // namespace mlir::triton::proton::gpu + +#endif // PROTONGPU_TRANSFORMS_PASSES_H_ diff --git a/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.td b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.td new file mode 100644 index 0000000000..42088b3ed9 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.td @@ -0,0 +1,15 @@ +#ifndef PROTONGPU_TRANSFORMS_PASSES +#define PROTONGPU_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def ScheduleBufferStorePass: Pass<"proton-schedule-buffer-store", "mlir::ModuleOp"> { + let summary = "Pass to move all Proton buffer stores to the end of the function"; + + let description = "This pass makes the measurement more accurate by moving the expensive " + "shared memory stores to the end of the measured region after the measurements."; + + let dependentDialects = ["gpu::ProtonGPUDialect"]; +} + +#endif // PROTONGPU_TRANSFORMS_PASSES diff --git a/third_party/mthreads/proton/Dialect/include/compat/PTXAsmFormat.h b/third_party/mthreads/proton/Dialect/include/compat/PTXAsmFormat.h new file mode 100644 index 0000000000..2cccf747d1 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/compat/PTXAsmFormat.h @@ -0,0 +1,347 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_PTX_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_PTX_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +struct PTXInstr; +struct PTXInstrCommon; +struct PTXInstrExecution; + +// PTXBuilder helps to manage a PTX asm program consists of one or multiple +// instructions. +// +// A helper for building an ASM program, the objective of PTXBuilder is to give +// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear. +// Currently, several factors are introduced to reduce the need for mixing +// string and C++ if-else code. +// +// Usage: +// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), +// "b"(p)); +// +// PTXBuilder builder; +// auto& add = ::create(builder, ); +// add.predicate(pVal).o("lo").o("u32"); // add any suffix +// // predicate here binds %0 to pVal, pVal is a mlir::Value +// +// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal +// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal +// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal +// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate +// +// To get the asm code: +// builder.dump() +// +// To get all the mlir::Value used in the PTX code, +// +// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} +// +// To get the string containing all the constraints with "," separated, +// builder.getConstraints() // get "=r,r,k" +// +// PTXBuilder can build a PTX asm with multiple instructions, sample code: +// +// PTXBuilder builder; +// auto& mov = builder.create("mov"); +// auto& cp = builder.create("cp"); +// mov(...); +// cp(...); +// This will get a PTX code with two instructions. +// +// Similar to a C function, a declared PTXInstr instance can be launched +// multiple times with different operands, e.g. +// +// auto& mov = builder.create("mov"); +// mov(... some operands ...); +// mov(... some different operands ...); +// +// Finally, we will get a PTX code with two mov instructions. +// +// There are several derived instruction type for typical instructions, for +// example, the PtxIOInstr for ld and st instructions. +struct PTXBuilder { + struct Operand { + std::string constraint; + Value value; + int idx{-1}; + llvm::SmallVector list; + std::function repr; + + // for list + Operand() = default; + Operand(const Operation &) = delete; + Operand(Value value, StringRef constraint) + : constraint(constraint), value(value) {} + + bool isList() const { return !value && constraint.empty(); } + + Operand *listAppend(Operand *arg) { + list.push_back(arg); + return this; + } + + Operand *listGet(size_t nth) const { + assert(nth < list.size()); + return list[nth]; + } + + std::string dump() const; + }; + + template + INSTR *create(Args &&...args) { + instrs.emplace_back(std::make_unique(this, args...)); + return static_cast(instrs.back().get()); + } + + // Create a list of operands. + Operand *newListOperand() { return newOperand(); } + + Operand *newListOperand(ArrayRef> items) { + auto *list = newOperand(); + for (auto &item : items) { + list->listAppend(newOperand(item.first, item.second)); + } + return list; + } + + Operand *newListOperand(unsigned count, mlir::Value val, + const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(val, constraint)); + } + return list; + } + + Operand *newListOperand(unsigned count, const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(constraint)); + } + return list; + } + + // Create a new operand. It will not add to operand list. + // @value: the MLIR value bind to this operand. + // @constraint: ASM operand constraint, .e.g. "=r" + // @formatter: extra format to represent this operand in ASM code, default is + // "%{0}".format(operand.idx). + Operand *newOperand(mlir::Value value, StringRef constraint, + std::function formatter = nullptr); + + // Create a new operand which is written to, that is, the constraint starts + // with "=", e.g. "=r". + // If the operand will be used in predicated execution, + // users may want to initialize it before use. + // Otherwise if the register is only used in the true branch or the false + // branch but not both, the register is undefined and ptxas can perform + // aggressive optimizations that may lead to incorrect results. + Operand *newOperand(StringRef constraint, bool init = false); + + // Create a new operand that is tied to a previous operand. In this case the + // asm would be permitted to write to an input register. Instead of providing + // constraint code for this operand, the constraint code of the tied operand + // is used. + Operand *newOperand(unsigned operandIndex); + + // Create a constant integer operand. + Operand *newConstantOperand(int64_t v); + // Create a constant operand with explicit code specified. + Operand *newConstantOperand(const std::string &v); + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); + + llvm::SmallVector getAllArgs() const; + + llvm::SmallVector getAllMLIRArgs() const; + + std::string getConstraints() const; + + std::string dump() const; + + mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect = true, bool isAlignStack = false, + ArrayRef attrs = {}) const; + +private: + Operand *newOperand() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + void initOperand(Operand *opr); + + // Make the operands in argArchive follow the provided \param order. + void reorderArgArchive(ArrayRef order) { + assert(order.size() == argArchive.size()); + // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but + // it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are + // determined by PTX code snippet passed from external. + sort(argArchive.begin(), argArchive.end(), + [&](std::unique_ptr &a, std::unique_ptr &b) { + auto ida = std::find(order.begin(), order.end(), a.get()); + auto idb = std::find(order.begin(), order.end(), b.get()); + assert(ida != order.end()); + assert(idb != order.end()); + return ida < idb; + }); + } + + friend struct PTXInstr; + friend struct PTXInstrCommon; + +protected: + llvm::SmallVector, 6> argArchive; + llvm::SmallVector, 2> instrs; + llvm::SmallVector, 4> executions; + int oprCounter{}; +}; + +// PTX instruction common interface. +// Put the generic logic for all the instructions here. +struct PTXInstrCommon { + explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {} + + using Operand = PTXBuilder::Operand; + + // clang-format off + PTXInstrExecution& operator()() { return call({}); } + PTXInstrExecution& operator()(Operand* a) { return call({a}); } + PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); } + // clang-format on + + // Set operands of this instruction. + PTXInstrExecution &operator()(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + +protected: + // "Call" the instruction with operands. + // \param oprs The operands of this instruction. + // \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments + // to the inline Asm without generating the operand ids(such as $0, $1) in PTX + // code. + PTXInstrExecution &call(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + + PTXBuilder *builder{}; + llvm::SmallVector instrParts; + + friend struct PTXInstrExecution; +}; + +template struct PTXInstrBase : public PTXInstrCommon { + using Operand = PTXBuilder::Operand; + + explicit PTXInstrBase(PTXBuilder *builder, const std::string &name) + : PTXInstrCommon(builder) { + o(name); + } + + // Append a suffix to the instruction. + // e.g. PTXInstr("add").o("s32") get a add.s32. + // A predicate is used to tell whether to apply the suffix, so that no if-else + // code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will + // get a `add.s32` if isS32 is true. + ConcreteT &o(const std::string &suffix, bool predicate = true) { + if (predicate) + instrParts.push_back(suffix); + return *static_cast(this); + } +}; + +struct PTXInstr : public PTXInstrBase { + using PTXInstrBase::PTXInstrBase; + + // Append a ".global" to the instruction. + PTXInstr &global(); + + // Append a ".shared" to the instruction. + PTXInstr &shared(); + + // Append a ".v[0-9]+" to the instruction + PTXInstr &v(int vecWidth, bool predicate = true); + + // Append a".b[0-9]+" to the instruction + PTXInstr &b(int width); +}; + +// Record the operands and context for "launching" a PtxInstr. +struct PTXInstrExecution { + using Operand = PTXBuilder::Operand; + + llvm::SmallVector argsInOrder; + + PTXInstrExecution() = default; + explicit PTXInstrExecution(PTXInstrCommon *instr, + llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs) + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), + onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} + + // Prefix a predicate to the instruction. + PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { + assert(value); + pred = instr->builder->newOperand(value, constraint); + return *this; + } + + // Prefix a predicate to the instruction, if non-null + PTXInstrExecution &maybePredicate(mlir::Value value, + StringRef constraint = "b") { + if (value) + predicate(value, constraint); + return *this; + } + + // Prefix a !predicate to the instruction. + PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { + pred = instr->builder->newOperand(value, constraint); + pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; + return *this; + } + + std::string dump() const; + + SmallVector getArgList() const; + + PTXInstrCommon *instr{}; + Operand *pred{}; + bool onlyAttachMLIRArgs{}; +}; + +/// ====== Some instruction wrappers ====== +// We add the wrappers to make the usage more intuitive by avoiding mixing the +// PTX code with some trivial C++ code. + +struct PTXCpAsyncLoadInstr : PTXInstrBase { + explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, + triton::CacheModifier modifier) + : PTXInstrBase(builder, "cp.async") { + o(triton::stringifyCacheModifier(modifier).str()); + o("shared"); + o("global"); + } +}; + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/proton/Dialect/include/compat/TargetInfo.h b/third_party/mthreads/proton/Dialect/include/compat/TargetInfo.h new file mode 100644 index 0000000000..05c70c7430 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/compat/TargetInfo.h @@ -0,0 +1,86 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H + +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + +namespace mlir::triton::NVIDIA { + +class TargetInfo : public mlir::triton::TargetInfoBase { +public: + TargetInfo(int computeCapability, int ptxVersion) + : computeCapability(computeCapability), ptxVersion(ptxVersion) {} + + bool supportMaximumMinimum() const override; + + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + + Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const override; + + void barrier(Location loc, RewriterBase &rewriter, + triton::gpu::AddrSpace targets) const override; + + void warpSync(Location loc, RewriterBase &rewriter) const override; + + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, Value pred, + Operation *localLoadOp = nullptr) const override; + + bool supportLdMatrix() const override { return computeCapability >= 75; } + bool supportStMatrix() const override { return computeCapability >= 90; } + bool supportLdStMatrixB8() const override { return computeCapability >= 100; } + + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const override; + + Value permute(RewriterBase &rewriter, Location loc, Value a, Value b, + Value selector) const override; + + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + ProgramIDDim axis) const override; + + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; + + std::string getMulhiFuncName(Type resultElementTy) const override; + + void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args, + ArrayRef isSigned = {}) const override; + + void printf(RewriterBase &rewriter, StringRef msg, ValueRange args, + + ArrayRef isSigned = {}) const override; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; + + int getSharedAddressSpace() const override; + + int getAddressSpace(Attribute addressSpace) const override; + + bool supportVectorizedAtomics() const override; + + int getPtxVersion() const { return ptxVersion; } + int getComputeCapability() const { return computeCapability; } + + bool isCuda() const override { return true; } + +private: + int computeCapability; + int ptxVersion; +}; + +} // namespace mlir::triton::NVIDIA + +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H diff --git a/third_party/mthreads/proton/Dialect/include/compat/Utility.h b/third_party/mthreads/proton/Dialect/include/compat/Utility.h new file mode 100644 index 0000000000..908540e9d0 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/include/compat/Utility.h @@ -0,0 +1,71 @@ +#ifndef TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_UTILITY_H + +#include + +#include "compat/PTXAsmFormat.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "compat/TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "third_party/mthreads/include/triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +// Operators + +namespace mlir::triton::gpu { +class MemDescType; +} + +namespace mlir { +namespace LLVM { + +namespace NVIDIA { +class TargetInfo; + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + ProgramIDDim axis); + +/// Create a predicate with just single active thread. +Value createElectPredicate(Location loc, OpBuilder &rewriter); +Value createElectPredicateWarp0(Location loc, OpBuilder &rewriter); + +// Create bar.warp.sync +void createSyncWarp(Location loc, OpBuilder &builder); + +// Lower ldmatrix and stmatrix +LogicalResult lowerLdStMatrix( + Location loc, LinearLayout cvt, bool transpose, + SmallVector &vals, // Input for stmatrix, output for ldmatrix + Value smemBase, Value affineOffset, uint64_t maskSpanAffineOffset, + Type llvmElemTy, ConversionPatternRewriter &rewriter, + const mlir::triton::NVIDIA::TargetInfo &targetInfo); + +// Given a broadcast mask and the number of CTAs, create a mask of ones +// where for ctaId, it sets as 1's the positions that are in the same broadcast +// group +Value createTMAMulticastMask(Location loc, ConversionPatternRewriter &rewriter, + uint16_t broadcastBits); +} // namespace NVIDIA +} // namespace LLVM + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/proton/Dialect/lib/Analysis/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..a08c3c6e40 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Analysis/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(ProtonAnalysis + ScopeIdAllocation.cpp + + DEPENDS + ProtonTableGen + + LINK_LIBS PUBLIC + ProtonIR + TritonAnalysis +) diff --git a/third_party/mthreads/proton/Dialect/lib/Analysis/ScopeIdAllocation.cpp b/third_party/mthreads/proton/Dialect/lib/Analysis/ScopeIdAllocation.cpp new file mode 100644 index 0000000000..94a0a62df7 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Analysis/ScopeIdAllocation.cpp @@ -0,0 +1,403 @@ +#include "Analysis/ScopeIdAllocation.h" +#include "mlir/Analysis/TopologicalSortUtils.h" + +namespace mlir { +namespace triton::proton { + +#define DEBUG_TYPE "proton-scope-id-allocation" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using VirtualBlock = std::pair; + +struct BlockInfo { + using ScopeId = ScopeIdAllocation::ScopeId; + + llvm::DenseSet activeScopes; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + void join(const BlockInfo &other) { + for (auto &scope : other.activeScopes) { + this->activeScopes.insert(scope); + } + } + + bool contains(ScopeId scopeId) const { + return this->activeScopes.contains(scopeId); + } + + void erase(ScopeId scopeId) { this->activeScopes.erase(scopeId); } + + void insert(ScopeId scopeId) { this->activeScopes.insert(scopeId); } + + bool operator==(const BlockInfo &other) const { + return this->activeScopes == other.activeScopes; + } + + void dump() const { + auto &err = llvm::errs(); + err << "Active Scopes:\n"; + for (auto &scope : activeScopes) { + err << " " << scope << "\n"; + } + } +}; + +void ScopeIdAllocation::run() { + // We execute the following analysis stages in the order to verify if + // `proton.record` operations are well-formed and associate scope IDs for each + // pair of start/end records. + // + // 1. liveness() + // + // Pair start/end records that share a name and assign a numeric + // identifier that later passes reuse. The current implementation pairs + // each start with the nearest matching end. + // + // proton.record start @"foo" // scopeId = 0 + // … + // proton.record end @"foo" // scopeId = 0 + // … + // proton.record start @"foo" // scopeId = 1 + // … + // proton.record end @"foo" // scopeId = 1 + // + // 2. reachability() + // + // Track active scopes across CFG boundaries and surface + // malformed lifetimes once the dataflow converges. + // + // scf.if %cond { + // proton.record start @"foo" + // } + // + // Because `"foo"` never ends on the `then` branch, reachability() emits + // "The scope name 'foo' is not closed properly". + // + // scf.if %cond { + // proton.record start @"foo" + // } + // proton.record end @"foo" + // + // No diagnostic is emitted: the pass assumes the branch may execute and + // leaves semantic responsibility to the caller. + // + // 3. dominance(): + // + // (a) Ensure that each start dominates its matching end. + // + // proton.record end @"foo" + // … + // proton.record start @"foo" + // + // Because the end dominates the start, dominance() reports an error. + // + // (b) Infer parent/child scope relationships using dominance facts. + // + // proton.record start @"outer" + // scf.if %cond { + // proton.record start @"inner" + // … + // proton.record end @"inner" + // } + // proton.record end @"outer" + // + // `"outer"` dominates `"inner"`, so dominance() records + // `(innerId -> outerId)` in `scopeParentIds`. + liveness(); + reachability(); + dominance(); +} + +void ScopeIdAllocation::liveness() { + llvm::DenseMap> + nameToIdMap; + llvm::DenseMap idToOpMap; + ScopeId scopeId = 0; + + funcOp->walk([&](RecordOp recordOp) { + auto name = recordOp.getName(); + LDBG("Processing RecordOp: " << recordOp); + if (!nameToIdMap.contains(name)) { + nameToIdMap[name] = {scopeId, /*isStart=*/recordOp.getIsStart()}; + idToNameMap[scopeId] = name; + LDBG("Assigning new scope scopeId " << scopeId << " to op '" << recordOp + << "'"); + opToIdMap[recordOp] = scopeId; + idToOpMap[scopeId] = recordOp; + scopeId++; + } else { + auto &[existingId, isStart] = nameToIdMap[name]; + if (isStart == recordOp.getIsStart()) { + // Error: duplicate start or end + mlir::emitError(recordOp.getLoc(), "The scope name '") + << name << "' has duplicate " + << (recordOp.getIsStart() ? "start" : "end") << " record"; + } else { + // Matching pair found + LDBG("Found matching pair for scope name '" << name << "' with scopeId " + << existingId); + opToIdMap[recordOp] = existingId; + idToOpMap[existingId] = recordOp; + nameToIdMap.erase(name); + } + } + }); + + if (!nameToIdMap.empty()) { + for (auto &[name, idIsStartPair] : nameToIdMap) { + auto &[id, isStart] = idIsStartPair; + auto unclosedOp = idToOpMap.lookup(id); + mlir::emitError(unclosedOp.getLoc(), "The scope name '") + << name << "' is not properly closed (missing " + << (isStart ? "end" : "start") << " record)"; + } + } +} + +void ScopeIdAllocation::reachability() { + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; + + std::deque virtualBlockList; + virtualBlockList.emplace_back(&funcOp.getBlocks().front(), Block::iterator()); + + while (!virtualBlockList.empty()) { + VirtualBlock virtualBlock = virtualBlockList.front(); + virtualBlockList.pop_front(); + // Evaluate the transfer function for this block starting from the cached + // input state. + auto inputBlockInfo = inputBlockInfoMap[virtualBlock]; + SmallVector successors; + Block::iterator startIt = virtualBlock.second.isValid() + ? std::next(virtualBlock.second) + : virtualBlock.first->begin(); + for (Operation &op : llvm::make_range(startIt, virtualBlock.first->end())) { + if (op.hasTrait() || + isa(op)) { + visitTerminator(&op, successors); + break; + } + if (auto recordOp = dyn_cast(&op)) { + auto scopeId = opToIdMap.lookup(recordOp); + if (recordOp.getIsStart()) { + inputBlockInfo.insert(scopeId); + } else { + inputBlockInfo.erase(scopeId); + } + } + } + // Skip successor propagation if the output state is unchanged. + if (outputBlockInfoMap.count(virtualBlock) && + inputBlockInfo == outputBlockInfoMap[virtualBlock]) { + continue; + } + // Update the current block. + outputBlockInfoMap[virtualBlock].join(inputBlockInfo); + // Propagate the new facts to successors. + for (VirtualBlock &successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[virtualBlock]); + virtualBlockList.emplace_back(successor); + } + } + + // Validate the reachability analysis results for each block. + for (auto iter : inputBlockInfoMap) { + auto &virtualBlock = iter.first; + auto inputBlockInfo = iter.second; + Block::iterator startIt = virtualBlock.second.isValid() + ? std::next(virtualBlock.second) + : virtualBlock.first->begin(); + for (Operation &op : llvm::make_range(startIt, virtualBlock.first->end())) { + if (auto recordOp = dyn_cast(&op)) { + auto scopeId = opToIdMap.lookup(recordOp); + auto name = idToNameMap.lookup(scopeId); + if (recordOp.getIsStart()) { + if (inputBlockInfo.contains(scopeId)) { + mlir::emitError(recordOp.getLoc(), "The scope name '") + << name << "' is started without being closed"; + } + inputBlockInfo.insert(scopeId); + } else { + if (inputBlockInfo.contains(scopeId)) { + inputBlockInfo.erase(scopeId); + } else { + mlir::emitError(recordOp.getLoc(), "The scope name '") + << name << "' is closed without being opened"; + } + } + } + } + } +} + +void ScopeIdAllocation::dominance() { + // Stage 3: derive scope parentage and verify dominance constraints. + mlir::DominanceInfo domInfo(funcOp); + mlir::PostDominanceInfo postDomInfo(funcOp); + llvm::DenseMap startRecordMap; + llvm::DenseMap endRecordMap; + funcOp->walk([&](RecordOp recordOp) { + auto scopeId = opToIdMap.lookup(recordOp); + if (recordOp.getIsStart()) + startRecordMap[scopeId] = recordOp.getOperation(); + else + endRecordMap[scopeId] = recordOp.getOperation(); + }); + + for (auto &[scopeId, startOp] : startRecordMap) { + auto *endOp = endRecordMap.lookup(scopeId); + if (!endOp) + continue; + if (domInfo.dominates(endOp, startOp)) { + auto name = idToNameMap.lookup(scopeId); + mlir::emitError(endOp->getLoc(), "The scope name '") + << name << "' has end record that dominates its start record"; + } + } + + llvm::SetVector startRecordOps; + for (auto &[scopeId, startOp] : startRecordMap) { + startRecordOps.insert(startOp); + } + auto sortedStartRecordOps = mlir::topologicalSort(startRecordOps); + for (int i = 0; i < sortedStartRecordOps.size(); ++i) { + auto *startOp = sortedStartRecordOps[i]; + auto scopeId = opToIdMap.lookup(startOp); + auto endOp = endRecordMap.lookup(scopeId); + for (int j = i - 1; j >= 0; --j) { + auto *parentStartOp = sortedStartRecordOps[j]; + auto parentScopeId = opToIdMap.lookup(parentStartOp); + auto parentEndOp = endRecordMap.lookup(parentScopeId); + if (domInfo.dominates(parentStartOp, startOp) && + postDomInfo.postDominates(parentEndOp, endOp)) { + auto parentId = opToIdMap.lookup(parentStartOp); + auto childId = opToIdMap.lookup(startOp); + scopeParentIds.push_back({childId, parentId}); + break; + } + } + } +} + +void ScopeIdAllocation::visitTerminator(Operation *op, + SmallVector &successors) { + if (isa(op)) { + // Collect the block successors of the branch. + for (Block *successor : op->getSuccessors()) + successors.emplace_back(successor, Block::iterator()); + return; + } + + if (auto br = dyn_cast(op)) { + // Query successors of an op-with-regions. The op can branch to region entry + // blocks or to the continuation after itself. + SmallVector regions; + br.getSuccessorRegions(RegionBranchPoint::parent(), regions); + for (RegionSuccessor ®ion : regions) { + if (region.isParent()) { + successors.emplace_back(br->getBlock(), br->getIterator()); + } else { + Block &block = region.getSuccessor()->front(); + successors.emplace_back(&block, Block::iterator()); + } + } + return; + } + + // FIXME: `ReturnLike` adds `RegionBranchTerminatorOpInterface` for some + // reason. Check that the parent is actually a `RegionBranchOpInterface`. + auto br = dyn_cast(op); + if (br && isa(br->getParentOp())) { + // Region branch terminators can jump to another region belonging to the + // parent operation or to the parent continuation. + SmallVector operands(br->getNumOperands()); + SmallVector regions; + br.getSuccessorRegions(operands, regions); + for (RegionSuccessor ®ion : regions) { + if (region.isParent()) { + Operation *parent = br->getParentOp(); + successors.emplace_back(parent->getBlock(), parent->getIterator()); + } else { + Block &block = region.getSuccessor()->front(); + successors.emplace_back(&block, Block::iterator()); + } + } + return; + } + + // Otherwise, it could be a return-like op. + if (op->hasTrait()) + return; + llvm_unreachable("Unknown terminator encountered in membar analysis"); +} + +ModuleScopeIdAllocation::ModuleScopeIdAllocation(ModuleOp moduleOp) + : CallGraph(moduleOp) { + ScopeIdAllocation::ScopeId funcScopeId = 0; + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + if (funcMap.contains(funcOp)) { + return; + } + auto iter = funcMap.try_emplace(funcOp, ScopeIdAllocation(funcOp)); + funcScopeIdMap[funcOp] = funcScopeId; + funcScopeId += iter.first->second.getNumScopes(); + }); + // Precompute per-function scope id mappings + for (auto [funcOp, offset] : funcScopeIdMap) { + // Names + auto names = funcMap.lookup(funcOp).getScopeIdNames(); + for (auto &p : names) + p.first += offset; + scopeIdNames[funcOp] = std::move(names); + // Parents + auto parents = funcMap.lookup(funcOp).getScopeIdParents(); + for (auto &p : parents) { + p.first += offset; + p.second += offset; + } + scopeIdParents[funcOp] = std::move(parents); + } +} + +ScopeIdAllocation::ScopeId +ModuleScopeIdAllocation::getOpScopeId(Operation *op) const { + auto funcOp = op->getParentOfType(); + auto funcOffset = funcScopeIdMap.lookup(funcOp); + return funcMap.lookup(funcOp).getOpScopeId(op) + funcOffset; +} + +ScopeIdAllocation::ScopeIdName +ModuleScopeIdAllocation::getScopeIdNames(triton::FuncOp funcOp) const { + return scopeIdNames.lookup(funcOp); +} + +ScopeIdAllocation::ScopeIdName +ModuleScopeIdAllocation::getScopeIdNames() const { + ScopeIdAllocation::ScopeIdName combined; + for (auto &entry : scopeIdNames) + combined.insert(combined.end(), entry.second.begin(), entry.second.end()); + return combined; +} + +ScopeIdAllocation::ScopeIdParent +ModuleScopeIdAllocation::getScopeIdParents(triton::FuncOp funcOp) const { + return scopeIdParents.lookup(funcOp); +} + +ScopeIdAllocation::ScopeIdParent +ModuleScopeIdAllocation::getScopeIdParents() const { + ScopeIdAllocation::ScopeIdParent combined; + for (auto &entry : scopeIdParents) + combined.insert(combined.end(), entry.second.begin(), entry.second.end()); + return combined; +} + +} // namespace triton::proton +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/CMakeLists.txt new file mode 100644 index 0000000000..d90f2889d3 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(ProtonToProtonGPU) +add_subdirectory(ProtonGPUToLLVM) +# add_subdirectory(compat) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..adc6d78575 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Proton) +add_subdirectory(ProtonGPU) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 0000000000..7fa2ff43cd --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(ProtonIR + Dialect.cpp + Ops.cpp + + DEPENDS + ProtonTableGen + ProtonAttrDefsIncGen +) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/Dialect.cpp b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/Dialect.cpp new file mode 100644 index 0000000000..b94464fd20 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/Dialect.cpp @@ -0,0 +1,34 @@ +#include "Dialect/Proton/IR/Dialect.h" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/InliningUtils.h" + +#include "Dialect/Proton/IR/Dialect.cpp.inc" + +namespace mlir::triton::proton { +struct ProtonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } +}; + +void ProtonDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "Dialect/Proton/IR/Ops.cpp.inc" + >(); + addInterfaces(); +} +} // namespace mlir::triton::proton diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/Ops.cpp b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/Ops.cpp new file mode 100644 index 0000000000..fda836a1c4 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/Proton/IR/Ops.cpp @@ -0,0 +1,12 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +#define GET_OP_CLASSES +#include "Dialect/Proton/IR/Ops.cpp.inc" + +#include "Dialect/Proton/IR/OpsEnums.cpp.inc" diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/CMakeLists.txt new file mode 100644 index 0000000000..d7b29312a0 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(ProtonGPUIR + Dialect.cpp + Ops.cpp + Types.cpp + + DEPENDS + ProtonGPUTableGen + ProtonGPUAttrDefsIncGen + ProtonGPUTypesIncGen + + LINK_LIBS PUBLIC + TritonGPUIR + ProtonIR +) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Dialect.cpp b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Dialect.cpp new file mode 100644 index 0000000000..356709cdb9 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Dialect.cpp @@ -0,0 +1,33 @@ +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "Dialect/ProtonGPU/IR/Dialect.cpp.inc" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_ATTRDEF_CLASSES +#include "proton/Dialect/include/Dialect/ProtonGPU/IR/AttrDefs.cpp.inc" + +using namespace mlir; + +const int mlir::triton::proton::gpu::getBytesPerClockEntry() { return 8; } +const int mlir::triton::proton::gpu::getCircularHeaderSize() { return 40; } + +void mlir::triton::proton::gpu::ProtonGPUDialect::initialize() { + registerTypes(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/ProtonGPU/IR/AttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "Dialect/ProtonGPU/IR/Ops.cpp.inc" + >(); +} + +const int mlir::triton::proton::gpu::getTotalNumWarps(ModuleOp mod) { + int numWarps = mlir::triton::gpu::lookupNumWarps(mod); + if (auto totalNumWarps = + mod->getAttrOfType("ttg.total-num-warps")) + numWarps = totalNumWarps.getInt(); + return numWarps; +} diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Ops.cpp b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Ops.cpp new file mode 100644 index 0000000000..a8e6d8650e --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Ops.cpp @@ -0,0 +1,65 @@ +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "Dialect/ProtonGPU/IR/Ops.cpp.inc" + +#include "Dialect/ProtonGPU/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace proton { +namespace gpu { + +// -- CircularRecordOp -- +LogicalResult CircularStoreOp::verify() { + auto scopeId = getScopeId(); + auto segmentType = getSegment().getType(); + auto granularity = segmentType.getGranularity(); + auto selectedIds = segmentType.getSelectIds(); + auto bufferSizeInBytes = segmentType.getNBytes(); + auto mod = getOperation()->getParentOfType(); + + int numWarps = getTotalNumWarps(mod); + + int segmentNum = selectedIds.empty() ? numWarps : selectedIds.size(); + if (!llvm::isPowerOf2_32(bufferSizeInBytes / segmentNum)) + return emitOpError("profiling buffer segment size must be power of 2"); + + if (scopeId < 0 || scopeId > 255) + return emitOpError("scope id must be in [0, 255]"); + + return success(); +} + +// -- SegmentAllocOp -- +LogicalResult SegmentAllocOp::verify() { + auto segmentType = getSegment().getType(); + auto granularity = segmentType.getGranularity(); + auto selectIds = segmentType.getSelectIds(); + if (granularity != Granularity::WARP && selectIds.size()) { + return emitOpError( + "only warp granularity supports non-empty selectIds for now"); + } + return success(); +} + +// -- InitCtxOp -- +LogicalResult InitCtxOp::verify() { + if (getOperation()->getParentOfType()) + return emitOpError( + "can't initialize proton context in a warp specialized op"); + return success(); +} + +} // namespace gpu +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Types.cpp b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Types.cpp new file mode 100644 index 0000000000..236537eb35 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/IR/Types.cpp @@ -0,0 +1,23 @@ +#include "Dialect/ProtonGPU/IR/Types.h" + +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::proton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "Dialect/ProtonGPU/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// ProtonGPU Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::proton::gpu::ProtonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "Dialect/ProtonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/Transforms/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..e3280ba738 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(ProtonGPUTransforms + ProtonGPUTransformsPass.cpp + + DEPENDS + ProtonGPUTransformsIncGen + LINK_LIBS PUBLIC + ProtonGPUIR +) diff --git a/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/Transforms/ProtonGPUTransformsPass.cpp b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/Transforms/ProtonGPUTransformsPass.cpp new file mode 100644 index 0000000000..ed3323e5ad --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/Dialect/ProtonGPU/Transforms/ProtonGPUTransformsPass.cpp @@ -0,0 +1,52 @@ +#include "Dialect/ProtonGPU/Transforms/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +#include "Dialect/ProtonGPU/IR/Dialect.h" + +namespace mlir::triton::proton::gpu { + +#define GEN_PASS_DEF_SCHEDULEBUFFERSTOREPASS +#include "Dialect/ProtonGPU/Transforms/Passes.h.inc" + +struct ScheduleBufferStorePass + : public impl::ScheduleBufferStorePassBase { + + using impl::ScheduleBufferStorePassBase< + ScheduleBufferStorePass>::ScheduleBufferStorePassBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + MLIRContext *context = m.getContext(); + OpBuilder builder(context); + + // TODO(srir): Add support for non-inline kernels + FuncOp func = *m.getOps().begin(); + auto startStoreList = llvm::SmallVector(); + auto endStoreMap = llvm::SmallDenseMap(); + + func.walk([&](CircularStoreOp store) { + if (store.getIsStart()) + startStoreList.push_back(store); + else + endStoreMap[store.getScopeId()] = store; + }); + + for (auto store : startStoreList) { + int scopeId = store.getScopeId(); + auto endStore = endStoreMap[scopeId]; + if (!endStore) { + mlir::emitError(func.getLoc(), "proton end store not found"); + signalPassFailure(); + return; + } + builder.setInsertionPoint(endStore); + builder.clone(*store); + store->erase(); + } + } +}; + +} // namespace mlir::triton::proton::gpu diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonGlobalScratchBuffer.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonGlobalScratchBuffer.cpp new file mode 100644 index 0000000000..f40832054b --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonGlobalScratchBuffer.cpp @@ -0,0 +1,51 @@ +#include "Conversion/ProtonGPUToLLVM/Passes.h" +#include "Conversion/ProtonGPUToLLVM/Utility.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::proton::gpu { + +#define GEN_PASS_DEF_ALLOCATEPROTONGLOBALSCRATCHBUFFERPASS +#include "Conversion/ProtonGPUToLLVM/Passes.h.inc" + +struct AllocateProtonGlobalScratchBufferPass + : public impl::AllocateProtonGlobalScratchBufferPassBase< + AllocateProtonGlobalScratchBufferPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + auto funcOps = triton::proton::gpu::getTritonFunctions(mod); + assert(funcOps.size() == 1 && "Expected exactly one funcOp"); + + int32_t cumulativeMemorySize = 0; // bytes + std::vector alignments; + + funcOps[0].walk([&](triton::gpu::GlobalScratchAllocOp op) { + if (op.getBackend() != "proton") + return; + int offset = llvm::alignTo(cumulativeMemorySize, + proton::gpu::getBytesPerClockEntry()); + op->setAttr("offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + cumulativeMemorySize += op.getNbytes(); + alignments.push_back(op.getAlignment()); + }); + if (alignments.empty()) + return; + + bool allAlignmentsEqual = std::equal(alignments.begin() + 1, + alignments.end(), alignments.begin()); + assert(allAlignmentsEqual && + "all global scratch buffer alignment values must be the same"); + mod->setAttr("ttg.profile_scratch_memory_size", + builder.getI32IntegerAttr(cumulativeMemorySize)); + mod->setAttr("ttg.profile_scratch_memory_alignment", + builder.getI32IntegerAttr(alignments.front())); + } +}; + +} // namespace mlir::triton::proton::gpu diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonSharedMemory.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonSharedMemory.cpp new file mode 100644 index 0000000000..87dfe22aaa --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/AllocateProtonSharedMemory.cpp @@ -0,0 +1,60 @@ +#include "Conversion/ProtonGPUToLLVM/Passes.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::proton::gpu { + +#define GEN_PASS_DEF_ALLOCATEPROTONSHAREDMEMORYPASS +#include "Conversion/ProtonGPUToLLVM/Passes.h.inc" + +struct AllocateProtonSharedMemoryPass + : public impl::AllocateProtonSharedMemoryPassBase< + AllocateProtonSharedMemoryPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + int sharedMemUsed = 0; + if (mod->hasAttr("ttg.shared")) + sharedMemUsed = + mod->getAttrOfType("ttg.shared").getInt(); + + assert(llvm::range_size(mod.getOps()) == 1); + FuncOp func = *mod.getOps().begin(); + + int totalSharedMemSize = 0; + int count = 0; + func.walk([&](triton::gpu::LocalAllocOp alloc) { + // We ignore the shared memory allocations that have been allocated by the + // triton conversion pass. + if (!alloc->hasAttr("allocation.offset")) { + int offset = + llvm::alignTo(sharedMemUsed, proton::gpu::getBytesPerClockEntry()); + alloc->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + // Compute the proton buffer size in bytes. + auto memDescTy = + mlir::cast(alloc.getResult().getType()); + int bufferSizeInBytes = + mlir::ShapedType::getNumElements(memDescTy.getShape()) * + memDescTy.getElementType().getIntOrFloatBitWidth() / 8; + + totalSharedMemSize = offset + bufferSizeInBytes; + count++; + } + }); + + if (count == 0) { + totalSharedMemSize = sharedMemUsed; + } + + mod->setAttr("ttg.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + totalSharedMemSize)); + } +}; + +} // namespace mlir::triton::proton::gpu diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..bf8e37415f --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(ProtonGPUToLLVM + AllocateProtonGlobalScratchBuffer.cpp + AllocateProtonSharedMemory.cpp + PatternProtonGPUOpToLLVM.cpp + Utility.cpp + + DEPENDS + ProtonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + ProtonIR + ProtonGPUIR + ProtonAnalysis +) + +add_subdirectory(ProtonNvidiaGPUToLLVM) +if(TRITON_ENABLE_AMD) + add_subdirectory(ProtonAMDGPUToLLVM) +endif() diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.cpp new file mode 100644 index 0000000000..e689fe70a3 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.cpp @@ -0,0 +1,827 @@ +#include "Conversion/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/TargetInfoBase.h" +#include "Conversion/ProtonGPUToLLVM/Utility.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "compat/PTXAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton { +namespace proton::gpu { + +namespace { + +Value getLinearId(Location loc, ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // Note: + // 1. We compute use i64 data type to compute and then truncate to i32 + // to support various backend intrinsics (e.g. amd). + // 2. We avoid using the targetInfo's programId() because of its coupling + // with cluster id in Nvidia TritonGPU's llvm lowering. + Value pidX = arith::IndexCastOp::create( + rewriter, loc, i64_ty, + mlir::gpu::BlockIdOp::create(rewriter, loc, mlir::gpu::Dimension::x)); + Value pidY = arith::IndexCastOp::create( + rewriter, loc, i64_ty, + mlir::gpu::BlockIdOp::create(rewriter, loc, mlir::gpu::Dimension::y)); + Value pidZ = arith::IndexCastOp::create( + rewriter, loc, i64_ty, + mlir::gpu::BlockIdOp::create(rewriter, loc, mlir::gpu::Dimension::z)); + + Value gridDimX = arith::IndexCastOp::create( + rewriter, loc, i64_ty, + ::mlir::gpu::GridDimOp::create(rewriter, loc, mlir::gpu::Dimension::x)); + Value gridDimY = arith::IndexCastOp::create( + rewriter, loc, i64_ty, + ::mlir::gpu::GridDimOp::create(rewriter, loc, mlir::gpu::Dimension::y)); + Value linearId = + b.trunc(i32_ty, b.add(b.add(pidX, b.mul(pidY, gridDimX)), + b.mul(pidZ, b.mul(gridDimX, gridDimY)))); + return linearId; +} + +struct ReadCounterOpConversion + : public ConvertOpToLLVMPattern { + explicit ReadCounterOpConversion( + LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::ReadCounterOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool isClock64 = false; + auto intType = mlir::cast(op.getResult().getType()); + isClock64 = intType.getWidth() == 64; + Value clock = targetInfo.clock(rewriter, op.getLoc(), isClock64); + rewriter.replaceOp(op, clock); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct InitializeOpConversion + : public ConvertOpToLLVMPattern { + explicit InitializeOpConversion(LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::InitializeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Value scratchPtr = adaptor.getScratchPtr(); + auto scratchPtrTy = mlir::cast(scratchPtr.getType()); + + // Header layout (total: circularHeaderSize bytes) + // +-------------------------------+ 0 + // | preamble (1 word) | + // +-------------------------------+ 1 + // | program id (1 word) | + // +-------------------------------+ 2 + // | hw id (1 word) | + // +-------------------------------+ 3 + // | buffer size (1 word) | + // +-------------------------------+ 4 + // | init time | + // | (2 words) | + // +-------------------------------+ 6 + // | pre-final time | + // | (2 words) | + // +-------------------------------+ 8 + // | post-final time | + // | (2 words) | + // +-------------------------------+ 10 + + Value threadId = getThreadId(rewriter, loc); + Value isFirstThread = b.icmp_eq(threadId, b.i32_val(0)); + + Block *prevBlock = op->getBlock(); + + // Add the 'if' block. + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + + // Write back 'preamble'. + Value preamble = b.i32_val(0xdeadbeef); + Value gmemPreambleOffset = b.i32_val(0); + Value gmemPreamblePtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemPreambleOffset); + b.store(preamble, gmemPreamblePtr); + + // Write back 'program id'. + Value gmemPidOffset = b.i32_val(1); + Value gmemPidPtr = b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemPidOffset); + Value pid = getLinearId(loc, rewriter); + b.store(pid, gmemPidPtr); + + // Write back 'hw id'. + Value gmemHwidOffset = b.i32_val(2); + Value gmemHwidPtr = b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemHwidOffset); + Value hwid = targetInfo.processorId(rewriter, loc); + b.store(hwid, gmemHwidPtr); + + // Write back 'init time'. + Value gmemInitTimeOffset = b.i32_val(4); + Value gmemInitTimePtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemInitTimeOffset); + Value initTime = targetInfo.globalTime(rewriter, loc); + b.store(initTime, gmemInitTimePtr); + + // Add the 'else' block and the condition. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(rewriter, loc, isFirstThread, ifBlock, thenBlock); + rewriter.setInsertionPointToEnd(ifBlock); + cf::BranchOp::create(rewriter, loc, thenBlock); + + rewriter.eraseOp(op); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct FinalizeOpConversion + : public ConvertOpToLLVMPattern { + explicit FinalizeOpConversion(LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::FinalizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto segmentObj = + LLVM::SegmentObject::fromStruct(loc, adaptor.getSegment(), rewriter); + Value scratchPtr = adaptor.getScratchPtr(); + + auto mod = op.getOperation()->getParentOfType(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + const int bytesPerEntry = proton::gpu::getBytesPerClockEntry(); + const int wordsPerEntry = bytesPerEntry / 4; // 1 word = 4 bytes + + int numWarps = getTotalNumWarps(mod); + + Value threadId = getRawThreadId(rewriter, loc); + Value threadsPerWarp = + b.i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value warpId = b.udiv(threadId, threadsPerWarp); + Value laneId = b.urem(threadId, threadsPerWarp); + Value isWarpFirstThread = b.icmp_eq(laneId, b.i32_val(0)); + Value isBlockFirstThread = b.icmp_eq(threadId, b.i32_val(0)); + auto segmentType = op.getSegment().getType(); + const int bufferSizeInWords = segmentType.getNBytes() / 4; + const int circularHeaderWordSize = proton::gpu::getCircularHeaderSize() / 4; + + // Circular strategy memory layout (total: allocprofileScratchSize bytes) + // +---------------------------------------+ + // | header (circularHeaderSize bytes) | + // +---------------------------------------+ + // | warp index (4 bytes x numWarps) | + // +---------------------------------------+ + // | profiled data (allocBufferSize bytes) | + // +---------------------------------------+ + const int metadataWordSize = circularHeaderWordSize + numWarps; + auto selectIds = segmentType.getSelectIds(); + bool hasSelectIds = !selectIds.empty(); + int activeWarpCount = hasSelectIds ? selectIds.size() : numWarps; + const int segmentWordSize = bufferSizeInWords / activeWarpCount; + auto scratchPtrTy = mlir::cast(scratchPtr.getType()); + auto segmentBaseTy = + mlir::cast(segmentObj.base.getType()); + + // Control-flow outline: + // prevBlock + // └─ condbr (block leader?) -> leaderBlock / continuation + // leaderBlock + // └─ ...body... + // └─ br continuation + // continuation + // └─ condbr (warp leader?) -> storeBlock / afterStore + // storeBlock + // └─ ...store warp index... + // └─ br afterStore + // afterStore + // └─ (optional shared mem copy) + Block *continuation = + emitBlockLeaderPrologue(op, isBlockFirstThread, scratchPtr, + scratchPtrTy, bufferSizeInWords, rewriter); + continuation = emitWarpIndexWriteback( + op, continuation, isWarpFirstThread, warpId, scratchPtr, scratchPtrTy, + segmentObj, circularHeaderWordSize, rewriter); + if (segmentBaseTy.getAddressSpace() == 3) { + // shared memory + continuation = emitWarpCopySection( + op, continuation, laneId, threadsPerWarp, scratchPtr, scratchPtrTy, + segmentObj, metadataWordSize, wordsPerEntry, segmentWordSize, + circularHeaderWordSize, segmentType.getMemorySpace(), rewriter); + } + emitBlockLeaderEpilogue(op, continuation, isBlockFirstThread, scratchPtr, + scratchPtrTy, rewriter); + rewriter.eraseOp(op); + return success(); + } + +private: + Block *emitBlockLeaderPrologue(mlir::triton::proton::gpu::FinalizeOp op, + Value isBlockFirstThread, Value scratchPtr, + LLVM::LLVMPointerType scratchPtrTy, + int bufferSizeInWords, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // Control-flow outline: + // prevBlock + // └─ condbr (block leader?) -> leaderBlock / continuation + // leaderBlock + // └─ ...body... + // └─ br continuation + // continuation + Block *prevBlock = op->getBlock(); + Block *continuation = rewriter.splitBlock(prevBlock, op->getIterator()); + Block *leaderBlock = rewriter.createBlock(prevBlock->getParent(), + Region::iterator(continuation)); + rewriter.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(rewriter, loc, isBlockFirstThread, leaderBlock, + continuation); + rewriter.setInsertionPointToStart(leaderBlock); + + Value gmemBufSizeOffset = b.i32_val(3); + Value gmemBufSizePtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemBufSizeOffset); + Value bufferCapacityInBytes = b.i32_val(bufferSizeInWords * 4); + b.store(bufferCapacityInBytes, gmemBufSizePtr); + + Value gmemPreFinalTimeOffset = b.i32_val(6); + Value gmemPreFinalTimePtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemPreFinalTimeOffset); + Value preFinalTime = targetInfo.globalTime(rewriter, loc); + b.store(preFinalTime, gmemPreFinalTimePtr); + + cf::BranchOp::create(rewriter, loc, continuation); + rewriter.setInsertionPointToStart(continuation); + return continuation; + } + + Block *emitWarpIndexWriteback(mlir::triton::proton::gpu::FinalizeOp op, + Block *continuation, Value isWarpFirstThread, + Value warpId, Value scratchPtr, + LLVM::LLVMPointerType scratchPtrTy, + const LLVM::SegmentObject &segmentObj, + int circularHeaderWordSize, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Block *afterStore = rewriter.splitBlock(continuation, op->getIterator()); + Block *storeBlock = rewriter.createBlock(op->getParentRegion(), + Region::iterator(afterStore)); + + rewriter.setInsertionPointToEnd(continuation); + cf::CondBranchOp::create(rewriter, loc, isWarpFirstThread, storeBlock, + afterStore); + + rewriter.setInsertionPointToStart(storeBlock); + Value warpIndexOffset = b.add(warpId, b.i32_val(circularHeaderWordSize)); + Value gmemWarpIndexPtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, warpIndexOffset); + Value indexForStore = b.load(i32_ty, segmentObj.indexPtr); + b.store(indexForStore, gmemWarpIndexPtr); + cf::BranchOp::create(rewriter, loc, afterStore); + + rewriter.setInsertionPointToStart(afterStore); + return afterStore; + } + + Block *emitWarpCopySection(mlir::triton::proton::gpu::FinalizeOp op, + Block *continuation, Value laneId, + Value threadsPerWarp, Value scratchPtr, + LLVM::LLVMPointerType scratchPtrTy, + const LLVM::SegmentObject &segmentObj, + int metadataWordSize, int wordsPerEntry, + int segmentWordSize, int circularHeaderWordSize, + Attribute memSpace, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // Control-flow outline: + // continuation + // └─ br copyBlock + // copyBlock + // └─ condbr (thread can copy?) -> loopHeader / exitBlock + // loopHeader + // └─ condbr (idx < loopLimit) -> loopBody / exitBlock + // loopBody + // └─ br loopHeader (idx += threadStride) + // exitBlock + Block *copyBlock = rewriter.splitBlock(continuation, op->getIterator()); + Block *exitBlock = rewriter.splitBlock(copyBlock, op->getIterator()); + Block *loopHeader = rewriter.createBlock( + op->getParentRegion(), Region::iterator(exitBlock), {i32_ty}, {loc}); + Block *loopBody = rewriter.createBlock( + op->getParentRegion(), Region::iterator(exitBlock), {i32_ty}, {loc}); + + rewriter.setInsertionPointToEnd(continuation); + cf::BranchOp::create(rewriter, loc, copyBlock); + + rewriter.setInsertionPointToStart(copyBlock); + Value segmentBase = segmentObj.segmentBase; + Value index = b.load(i32_ty, segmentObj.indexPtr); + auto bufferBaseType = segmentObj.base.getType(); + Value maxBufferWords = b.i32_val(segmentWordSize); + Value effectiveBufferWords = + b.select(b.icmp_slt(index, maxBufferWords), index, maxBufferWords); + Value hasSegment = b.icmp_sge(segmentBase, b.i32_val(0)); + Value hasData = b.icmp_sge(effectiveBufferWords, b.i32_val(wordsPerEntry)); + Value shouldCopy = b.and_(hasSegment, hasData); + Value threadStride = b.mul(threadsPerWarp, b.i32_val(wordsPerEntry)); + Value loopUpperBound = + b.sub(effectiveBufferWords, b.i32_val(wordsPerEntry)); + // Each lane copies records in a warp-strided pattern. + Value laneInitIdx = b.mul(laneId, b.i32_val(wordsPerEntry)); + Value laneWithinBounds = b.icmp_sle(laneInitIdx, loopUpperBound); + Value threadShouldCopy = b.and_(shouldCopy, laneWithinBounds); + + auto &tritonTargetInfo = targetInfo.getTritonTargetInfo(); + auto copyWord = [&](Value bufOffset, Value gmemOffset, Attribute memory) { + // Load the value from buffer and store it to global memory. + Value ptr = b.gep(bufferBaseType, i32_ty, segmentObj.base, bufOffset); + Value load; + if (mlir::isa(memory)) { + load = tritonTargetInfo.loadShared(rewriter, loc, ptr, i32_ty, + b.true_val()); + } else { + llvm::report_fatal_error( + "unsupported memory space buffer in finalize copy"); + } + + Value gmemPtr = b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemOffset); + b.store(load, gmemPtr); + }; + + // Write back the data. + cf::CondBranchOp::create(rewriter, loc, threadShouldCopy, loopHeader, + ValueRange{laneInitIdx}, exitBlock, ValueRange{}); + + rewriter.setInsertionPointToStart(loopHeader); + BlockArgument headerIdx = loopHeader->getArgument(0); + Value continueLoop = b.icmp_sle(headerIdx, loopUpperBound); + cf::CondBranchOp::create(rewriter, loc, continueLoop, loopBody, + ValueRange{headerIdx}, exitBlock, ValueRange{}); + + rewriter.setInsertionPointToStart(loopBody); + BlockArgument bodyIdx = loopBody->getArgument(0); + Value bufTagOffset = b.add(segmentBase, bodyIdx); + Value bufCounterOffset = b.add(bufTagOffset, b.i32_val(1)); + Value gmemBaseOffset = b.add(b.i32_val(metadataWordSize), segmentBase); + Value gmemWbTagOffset = b.add(gmemBaseOffset, bodyIdx); + Value gmemWbCounterOffset = b.add(gmemWbTagOffset, b.i32_val(1)); + copyWord(bufTagOffset, gmemWbTagOffset, memSpace); + copyWord(bufCounterOffset, gmemWbCounterOffset, memSpace); + Value nextIdx = b.add(bodyIdx, threadStride); + cf::BranchOp::create(rewriter, loc, loopHeader, ValueRange{nextIdx}); + + rewriter.setInsertionPointToStart(exitBlock); + return exitBlock; + } + + void emitBlockLeaderEpilogue(mlir::triton::proton::gpu::FinalizeOp op, + Block *thenBlock, Value isBlockFirstThread, + Value scratchPtr, + LLVM::LLVMPointerType scratchPtrTy, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // Control-flow outline: + // thenBlock + // └─ condbr (block leader?) -> leaderBlock / continuation + // leaderBlock + // └─ ...body... + // └─ br continuation + // continuation + Block *continuation = rewriter.splitBlock(thenBlock, op->getIterator()); + Block *leaderBlock = rewriter.createBlock(thenBlock->getParent(), + Region::iterator(continuation)); + rewriter.setInsertionPointToEnd(thenBlock); + cf::CondBranchOp::create(rewriter, loc, isBlockFirstThread, leaderBlock, + continuation); + rewriter.setInsertionPointToStart(leaderBlock); + + Value gmemPostFinalTimeOffset = b.i32_val(8); + Value gmemPostFinalTimePtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, gmemPostFinalTimeOffset); + Value postFinalTime = targetInfo.globalTime(rewriter, loc); + b.store(postFinalTime, gmemPostFinalTimePtr); + cf::BranchOp::create(rewriter, loc, continuation); + rewriter.setInsertionPointToStart(continuation); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct SegmentAllocOpConversion + : public ConvertOpToLLVMPattern { + explicit SegmentAllocOpConversion( + LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::SegmentAllocOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto mod = op.getOperation()->getParentOfType(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + int numWarps = getTotalNumWarps(mod); + + auto segmentType = op.getResult().getType(); + auto granularity = segmentType.getGranularity(); + auto selectIds = segmentType.getSelectIds(); + bool isAllIds = selectIds.empty() ? true : false; + + if (granularity != proton::gpu::Granularity::WARP) { + mlir::emitError(loc, "granularity must be warp for now"); + return failure(); + } + + Value curThreadId = getRawThreadId(rewriter, loc); + + Value threadsPerWarp = + b.i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value curWarpId = b.udiv(curThreadId, threadsPerWarp); + const int bufferSizeInBytes = op.getSegment().getType().getNBytes(); + + // Specialize the segment base address calculation might bring a few cycles + // saving per record measurement overhead. + Value segmentBase; + if (isAllIds) { + if (granularity == proton::gpu::Granularity::WARP) + segmentBase = + allWarpSegmentAlloc(b, curWarpId, numWarps, bufferSizeInBytes); + else + llvm::report_fatal_error( + "segment address specialization not implemented yet"); + } else { + segmentBase = + defaultSegmentAlloc(b, curWarpId, selectIds, bufferSizeInBytes); + } + + Value buffer = adaptor.getBuffer(); + Value bufferBase; + if (isa(buffer.getType())) { + bufferBase = buffer; + } else { + Type bufferBaseTy = + mlir::cast(buffer.getType()).getBody()[0]; + bufferBase = b.extract_val(bufferBaseTy, buffer, 0); + } + auto indexPtrTy = + ptr_ty(rewriter.getContext(), targetInfo.getIndexPtrAddrSpace()); + auto indexPtr = LLVM::AllocaOp::create(rewriter, loc, indexPtrTy, i32_ty, + b.i32_val(1), /*alignment=*/0); + b.store(b.i32_val(0), indexPtr); + + auto segmentObj = LLVM::SegmentObject(bufferBase, segmentBase, indexPtr); + auto llvmStruct = segmentObj.getStruct(loc, rewriter); + rewriter.replaceOp(op, llvmStruct); + return success(); + } + +private: + Value defaultSegmentAlloc(TritonLLVMOpBuilder &b, Value curWarpId, + llvm::ArrayRef selectedIds, + int bufferSize) const { + const int segmentWordSize = bufferSize / selectedIds.size() / 4; + int warpSegmentAlloc = 0; + Value segmentAlloc = b.i32_val(-1); + for (int warpId : selectedIds) { + segmentAlloc = b.select(b.icmp_eq(curWarpId, b.i32_val(warpId)), + b.i32_val(warpSegmentAlloc), segmentAlloc); + warpSegmentAlloc += segmentWordSize; + } + return segmentAlloc; + } + + Value allWarpSegmentAlloc(TritonLLVMOpBuilder &b, Value curWarpId, + int numWarps, int bufferSize) const { + const int segmentWordSize = bufferSize / numWarps / 4; + return b.mul(curWarpId, b.i32_val(segmentWordSize)); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct GlobalScratchAllocOpConversion + : public ConvertOpToLLVMPattern { + explicit GlobalScratchAllocOpConversion( + LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = rewriter.getContext(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) { + return failure(); + } + + ModuleOp mod = funcOp.getOperation()->getParentOfType(); + auto ptrTy = mlir::LLVM::LLVMPointerType::get(ctx, 1); + assert(op->hasAttr("offset")); + size_t offset = + cast(op->getAttr("offset")).getValue().getZExtValue(); + + Value allocOffset = b.i32_val(offset); + + // See NOTE: [Additional Function Arguments] + if (!LLVM::isKernel(funcOp)) { + // Base for this function + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() + + kProfileScratchBufferOffset); + + Value ptr = b.gep(ptrTy, i8_ty, gmemBase, allocOffset); + rewriter.replaceOp(op, ptr); + return success(); + } + + // Base for entire kernel + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() + + kProfileScratchBufferOffset); + auto allocSizeAttr = mod.getOperation()->getAttrOfType( + "ttg.profile_scratch_memory_size"); + assert(allocSizeAttr); + + Value linearId = getLinearId(loc, rewriter); + + auto allocSize = allocSizeAttr.getValue().getZExtValue(); + Value gmemOffset = + b.add(allocOffset, b.mul(linearId, b.i32_val(allocSize))); + + auto ptr = b.gep(ptrTy, i8_ty, gmemBase, gmemOffset); + + rewriter.replaceOp(op, ptr); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct InitCtxOpConversion + : public ConvertOpToLLVMPattern { + explicit InitCtxOpConversion(LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::InitCtxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value scratchPtr = adaptor.getScratchPtr(); + auto scratchPtrTy = mlir::cast(scratchPtr.getType()); + + auto mod = op.getOperation()->getParentOfType(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + int numWarps = getTotalNumWarps(mod); + + // InitCtxOp can only be called in the master warps, so using `getThreadId` + // is fine. + Value threadId = getThreadId(rewriter, loc); + Value isFirstThread = b.icmp_eq(threadId, b.i32_val(0)); + const int circularHeaderWordSize = proton::gpu::getCircularHeaderSize() / 4; + + Block *prevBlock = op->getBlock(); + + // Add the 'if' block. + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + + // Initialize the `warp_index` section. + for (int warpId = 0; warpId < numWarps; warpId++) { + Value warpIndexOffset = b.i32_val(warpId + circularHeaderWordSize); + Value gmemWarpIndexPtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, warpIndexOffset); + b.store(b.i32_val(0), gmemWarpIndexPtr); + } + + // Add the 'else' block and the condition. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(rewriter, loc, isFirstThread, ifBlock, thenBlock); + rewriter.setInsertionPointToEnd(ifBlock); + cf::BranchOp::create(rewriter, loc, thenBlock); + + rewriter.eraseOp(op); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct RestoreCtxOpConversion + : public ConvertOpToLLVMPattern { + explicit RestoreCtxOpConversion(LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::RestoreCtxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto segmentObj = + LLVM::SegmentObject::fromStruct(loc, adaptor.getSegment(), rewriter); + Value scratchPtr = adaptor.getScratchPtr(); + + auto mod = op.getOperation()->getParentOfType(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + int numWarps = getTotalNumWarps(mod); + + // We need to use the absolute warp id in case warp specialization is used. + Value threadId = getRawThreadId(rewriter, loc); + + Value warpId = b.udiv( + threadId, + b.i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod))); + const int circularHeaderWordSize = proton::gpu::getCircularHeaderSize() / 4; + + auto scratchPtrTy = mlir::cast(scratchPtr.getType()); + + // Get the `warp_index` and store it into indexPtr. + Value warpIndexOffset = b.add(warpId, b.i32_val(circularHeaderWordSize)); + Value gmemWarpIndexPtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, warpIndexOffset); + Value index = b.load(i32_ty, gmemWarpIndexPtr); + b.store(index, segmentObj.indexPtr); + + rewriter.eraseOp(op); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +struct SaveCtxOpConversion + : public ConvertOpToLLVMPattern { + explicit SaveCtxOpConversion(LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::SaveCtxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value scratchPtr = adaptor.getScratchPtr(); + auto scratchPtrTy = mlir::cast(scratchPtr.getType()); + auto segmentObj = + LLVM::SegmentObject::fromStruct(loc, adaptor.getSegment(), rewriter); + + auto mod = op.getOperation()->getParentOfType(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + int numWarps = getTotalNumWarps(mod); + + int numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = b.i32_val(numLanes); + + // We need to use the absolute warp id in case warp specialization is used. + Value threadId = getRawThreadId(rewriter, loc); + + Value warpId = b.udiv(threadId, warpSize); + Value laneId = b.urem(threadId, warpSize); + Value isWarpMaster = b.icmp_eq(laneId, b.i32_val(0)); + const int circularHeaderWordSize = proton::gpu::getCircularHeaderSize() / 4; + + Block *prevBlock = op->getBlock(); + + // Add the 'if' block. + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + + // Update the `warp_index` section. + Value warpIndexOffset = b.add(warpId, b.i32_val(circularHeaderWordSize)); + Value gmemWarpIndexPtr = + b.gep(scratchPtrTy, i32_ty, scratchPtr, warpIndexOffset); + Value index = b.load(i32_ty, segmentObj.indexPtr); + b.store(index, gmemWarpIndexPtr); + + // Add the 'else' block and the condition. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(prevBlock); + cf::CondBranchOp::create(rewriter, loc, isWarpMaster, ifBlock, thenBlock); + rewriter.setInsertionPointToEnd(ifBlock); + cf::BranchOp::create(rewriter, loc, thenBlock); + + rewriter.eraseOp(op); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +Type convertProtonGPUMemDescType(triton::gpu::MemDescType type, + const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + // base ptr + auto ptrType = LLVM::LLVMPointerType::get( + ctx, targetInfo.getAddressSpace(type.getMemorySpace())); + + SmallVector types; + types.push_back(ptrType); + auto rank = type.getRank(); + // offsets + for (auto i = 0; i < rank; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type convertProtonGPUSegmentType(SegmentType type, + const TargetInfoBase &targetInfo) { + auto memorySpace = targetInfo.getAddressSpace(type.getMemorySpace()); + return LLVM::SegmentObject::getStructType(type.getContext(), memorySpace, + targetInfo.getIndexPtrAddrSpace()); +} + +} // namespace + +void populateProtonGPUOpPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, + benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} + +void populateTypeConversions(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo) { + typeConverter.addConversion( + [&](triton::gpu::MemDescType type) -> std::optional { + return convertProtonGPUMemDescType(type, targetInfo); + }); + typeConverter.addConversion( + [&](proton::gpu::SegmentType type) -> std::optional { + return convertProtonGPUSegmentType(type, targetInfo); + }); + typeConverter.addConversion( + [&](triton::PointerType type) -> std::optional { + auto ctx = type.getContext(); + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); + }); +} + +} // namespace proton::gpu +} // namespace mlir::triton diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp new file mode 100644 index 0000000000..11a21f9f50 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp @@ -0,0 +1,68 @@ +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h" +#include "Conversion/ProtonGPUToLLVM/Utility.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct CircularStoreOpConversion + : public ConvertOpToLLVMPattern< + mlir::triton::proton::gpu::CircularStoreOp> { + explicit CircularStoreOpConversion( + LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern< + mlir::triton::proton::gpu::CircularStoreOp>(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::CircularStoreOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto dataPack = + lowerCircularStoreOpHelper(op, adaptor.getSegment(), rewriter); + + uint32_t addrSpace = dataPack.addrSpace; + if (addrSpace == 1) { + // TODO(crobeck): see what buffer ops performance looks like here for + // global mem (address space 1) compared to predicated ops to shared + // memory + mlir::LLVM::AMD::llStore(rewriter, loc, dataPack.ptr, dataPack.record, + dataPack.isWriter); + } else if (addrSpace == 3) { + targetInfo.getTritonTargetInfo().storeDShared( + rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record, + dataPack.isWriter); + } else { + llvm::report_fatal_error("unsupported address space in circular store"); + } + rewriter.eraseOp(op); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +} // namespace + +namespace mlir::triton::proton::gpu::AMD { +void populateProtonGPUOpAMDPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} +} // namespace mlir::triton::proton::gpu::AMD diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AddSchedBarriers.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AddSchedBarriers.cpp new file mode 100644 index 0000000000..d90c35f35e --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AddSchedBarriers.cpp @@ -0,0 +1,64 @@ +#include "Conversion/ProtonGPUToLLVM/Passes.h" +#include "Conversion/ProtonGPUToLLVM/Utility.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton::proton::gpu { +#define GEN_PASS_DEF_ADDSCHEDBARRIERS +#include "Conversion/ProtonGPUToLLVM/Passes.h.inc" +} // namespace triton::proton::gpu +} // namespace mlir + +namespace { + +struct AddSchedBarriers + : public mlir::triton::proton::gpu::impl::AddSchedBarriersBase< + AddSchedBarriers> { + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + auto funcOps = triton::proton::gpu::getTritonFunctions(mod); + assert(funcOps.size() == 1 && "Expected exactly one funcOp"); + + IntegerAttr zeroAttrValue = + builder.getI32IntegerAttr(static_cast(0)); + + funcOps[0].walk([&](mlir::triton::proton::gpu::ReadCounterOp op) { + auto loc = op.getLoc(); + if (!isa_and_nonnull(op->getPrevNode())) { + builder.setInsertionPoint(op); + ROCDL::SchedBarrier::create(builder, loc, zeroAttrValue); + } + }); + + funcOps[0].walk([&](mlir::triton::proton::gpu::CircularStoreOp op) { + auto loc = op.getLoc(); + if (!isa_and_nonnull(op->getNextNode())) { + builder.setInsertionPointAfter(op); + ROCDL::SchedBarrier::create(builder, loc, zeroAttrValue); + } + }); + } +}; + +} // namespace + +namespace mlir::triton::proton::gpu { + +std::unique_ptr> createAddSchedBarriersPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::proton::gpu diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..80b387ab2f --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(${PROJECT_SOURCE_DIR}/third_party/amd/include) + +add_triton_library(ProtonAMDGPUToLLVM + TargetInfo.cpp + AMDPatternProtonGPUOpToLLVM.cpp + AddSchedBarriers.cpp + ConvertProtonGPUToLLVM.cpp + + DEPENDS + ProtonAMDGPUConversionPassIncGen + + LINK_LIBS PUBLIC + ProtonGPUToLLVM + TritonAMDGPUToLLVM +) diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/ConvertProtonGPUToLLVM.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/ConvertProtonGPUToLLVM.cpp new file mode 100644 index 0000000000..46fcd6f95a --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/ConvertProtonGPUToLLVM.cpp @@ -0,0 +1,104 @@ +#include "Conversion/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h" +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton::proton::gpu { +#define GEN_PASS_DEF_CONVERTPROTONAMDGPUTOLLVM +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h.inc" +} // namespace triton::proton::gpu +} // namespace mlir + +namespace { + +class ProtonLLVMConversionTarget : public ConversionTarget { +public: + explicit ProtonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + addDynamicallyLegalOp( + [](triton::gpu::GlobalScratchAllocOp op) { + return op.getBackend() != "proton"; + }); + } +}; + +struct ConvertProtonAMDGPUToLLVM + : public mlir::triton::proton::gpu::impl::ConvertProtonAMDGPUToLLVMBase< + ConvertProtonAMDGPUToLLVM> { + explicit ConvertProtonAMDGPUToLLVM(std::string arch) { this->arch = arch; } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp mod = getOperation(); + auto tritonTargetInfo = mlir::triton::AMD::TargetInfo(arch); + auto protonTargetInfo = + mlir::triton::proton::gpu::AMD::TargetInfo(tritonTargetInfo, arch); + mlir::LowerToLLVMOptions option(context); + TritonGPUToLLVMTypeConverter typeConverter(context, option, + tritonTargetInfo); + populateTypeConversions(typeConverter, protonTargetInfo); + mlir::triton::proton::gpu::populateProtonGPUOpPatterns( + typeConverter, patterns, protonTargetInfo, 1); + mlir::triton::proton::gpu::AMD::populateProtonGPUOpAMDPatterns( + typeConverter, patterns, protonTargetInfo, 1); + mlir::triton::AMD::populateMaskedOpsToLLVMPatterns(patterns, + tritonTargetInfo); + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + + FailureOr maybeChipset = + mlir::amdgpu::Chipset::parse(this->arch); + if (failed(maybeChipset)) { + emitError(UnknownLoc::get(&getContext()), + "Invalid AMDGPU chipset name: " + this->arch); + return signalPassFailure(); + } + mlir::populateGpuToROCDLConversionPatterns( + typeConverter, patterns, mlir::gpu::amd::HIP, *maybeChipset); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + auto convTarget = ProtonLLVMConversionTarget(*context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { + +namespace triton::proton { + +namespace gpu { + +std::unique_ptr> +createConvertProtonAMDGPUToLLVMPass(std::string arch) { + return std::make_unique(arch); +} + +} // namespace gpu + +} // namespace triton::proton + +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.cpp new file mode 100644 index 0000000000..b6c021a334 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.cpp @@ -0,0 +1,160 @@ +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::proton::gpu::AMD { + +Value TargetInfo::clock(ConversionPatternRewriter &rewriter, Location loc, + bool isClock64) const { + // NV has both a 32 bit and 64 bit clock intrinsic. On AMD we only have + // s_memtime which is 64 bit. However truncating the 64 bit version + // in cases of requesting 32 bit should be fine, since in 64 bits, + // after 0x0000.0000.ffff.ffff comes 0x0000.0001.0000.0000, and + // truncating that to 32 bits gives zero, effectively wrapping from + // 0xffff.ffff to 0x0000.0000. + auto b = TritonLLVMOpBuilder(loc, rewriter); + StringRef clock64IntrinsicName = "llvm.amdgcn.s.memtime"; + Value clockVal = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, clock64IntrinsicName, i64_ty, {}) + .getResult(0); + if (!isClock64) + clockVal = LLVM::TruncOp::create(rewriter, loc, i32_ty, clockVal); + + return clockVal; +} + +Value TargetInfo::globalTime(ConversionPatternRewriter &rewriter, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + StringRef globalTimeIntrinsicName = "llvm.amdgcn.s.memrealtime"; + Value globalTimeVal = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, globalTimeIntrinsicName, i64_ty, {}) + .getResult(0); + // The clock-generator runs at 100 MHz ==> 10 ns per clock. + // Reference: Section 3.4.11 in the RDNA4 ISA manual + // https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf + return b.mul(globalTimeVal, b.i64_val(10)); +} + +// https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h#L898 +// XCC_ID Register bit structure for gfx940-942, gfx950 +// XCC_ID 3:0 XCC the wave is assigned to. +static Value getXCCID(ConversionPatternRewriter &rewriter, Location loc) { + GCNBuilder builder; + auto &gethwid = *builder.create("s_getreg_b32"); + auto xcc_id = builder.newOperand("=s"); + // HW_REG_XCC_ID_OFFSET=0, HW_REG_XCC_ID_SIZE=4 + auto xcc_reg = builder.newConstantOperand("hwreg(HW_REG_XCC_ID, 0, 4)"); + gethwid(xcc_id, xcc_reg); + return builder.launch(rewriter, loc, i32_ty, false); +} + +// HW_ID Register bit structure for GCN and CDNA +// CU_ID 11:8 Compute Unit the wave is assigned to. +static Value getCUID(ConversionPatternRewriter &rewriter, Location loc) { + GCNBuilder builder; + auto &gethwid = *builder.create("s_getreg_b32"); + auto cu_id = builder.newOperand("=s"); + // HW_ID_CU_ID_OFFSET=8, HW_ID_CU_ID_SIZE=4 + auto hwreg = builder.newConstantOperand("hwreg(HW_REG_HW_ID, 8, 4)"); + gethwid(cu_id, hwreg); + return builder.launch(rewriter, loc, i32_ty, false); +} +// SE_ID 15:13 Shader Engine the wave is assigned to for gfx940-942, +// gfx950 +static Value getSEID(ConversionPatternRewriter &rewriter, Location loc) { + GCNBuilder builder; + auto &gethwid = *builder.create("s_getreg_b32"); + auto se_id = builder.newOperand("=s"); + // HW_ID_SE_ID_OFFSET=13, HW_ID_SE_ID_SIZE=3 + auto hwreg = builder.newConstantOperand("hwreg(HW_REG_HW_ID, 13, 3)"); + gethwid(se_id, hwreg); + return builder.launch(rewriter, loc, i32_ty, false); +} + +// gfx942 has 8 XCDs, each XCD contains 40 CUs per XCD but only 38/40 are active +// (total of 304 CUs) gfx950 has 8 XCDs, each XCD contains 36 CUs per XCD but +// only 32/36 active CUs (total 256 CUs) +static uint32_t getCU_PER_XCD(llvm::AMDGPU::GPUKind GPUKind) { + switch (GPUKind) { + case llvm::AMDGPU::GK_GFX942: + return 38; + case llvm::AMDGPU::GK_GFX950: + return 32; + default: + llvm_unreachable("unsupported arch"); + } +} + +static uint32_t getCU_PER_SE(llvm::AMDGPU::GPUKind GPUKind) { + switch (GPUKind) { + case llvm::AMDGPU::GK_GFX942: + return 10; + case llvm::AMDGPU::GK_GFX950: + return 10; + default: + llvm_unreachable("unsupported arch"); + } +} + +Value TargetInfo::processorId(ConversionPatternRewriter &rewriter, + Location loc) const { + GCNBuilder builder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto &gethwid = *builder.create("s_getreg_b32"); + + Value xcc_id = b.i32_val(0); + llvm::AMDGPU::GPUKind GPUKind = llvm::AMDGPU::parseArchAMDGCN(this->arch); + // For now only support gfx942, and gfx950 + switch (GPUKind) { + case llvm::AMDGPU::GK_GFX942: + case llvm::AMDGPU::GK_GFX950: + xcc_id = getXCCID(rewriter, loc); + break; + default: + llvm::report_fatal_error("unsupported arch"); + } + + Value cu_id = getCUID(rewriter, loc); // local CU ID + Value se_id = getSEID(rewriter, loc); + builder.create<>("s_waitcnt lgkmcnt(0)")->operator()(); + + // For XCC based architectures to get a unique CU id for a wave: + // global_cu_id = xcc_id * CU_PER_XCD + se_id * CU_PER_SE + cu_id (local) + if (GPUKind == llvm::AMDGPU::GK_GFX942 || + GPUKind == llvm::AMDGPU::GK_GFX950) { + uint32_t CU_PER_XCD = getCU_PER_XCD(GPUKind); + uint32_t CU_PER_SE = getCU_PER_SE(GPUKind); + cu_id = b.add(b.add(b.mul(xcc_id, b.i32_val(CU_PER_XCD)), + b.mul(se_id, b.i32_val(CU_PER_SE))), + cu_id); + } + + return cu_id; +} + +int TargetInfo::getAddressSpace(Attribute addressSpace) const { + int spaceId = 0; + if (mlir::isa(addressSpace)) { + spaceId = 3; + } else if (mlir::isa(addressSpace)) { + spaceId = 1; + } else { + llvm::report_fatal_error("Only support SharedMemorySpace, " + "and GlobalMemorySpace for now"); + } + return spaceId; +} + +int TargetInfo::getIndexPtrAddrSpace() const { + // Internal buffer index is private to each thread, we use thread local + // address space for AMD GPUs. See detail discussion: + // https://llvm.org/docs/AMDGPUUsage.html#address-spaces + return 5; +} + +} // namespace mlir::triton::proton::gpu::AMD diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..523e15d2d5 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/CMakeLists.txt @@ -0,0 +1,38 @@ +# include_directories(${PROJECT_SOURCE_DIR}/third_party/nvidia/include) + +# add_triton_library(ProtonNVIDIAGPUToLLVM +# TargetInfo.cpp +# NvidiaPatternProtonGPUOpToLLVM.cpp +# ConvertProtonGPUToLLVM.cpp + +# DEPENDS +# ProtonNvidiaGPUConversionPassIncGen + +# LINK_LIBS PUBLIC +# ProtonGPUToLLVM +# TritonNVIDIAGPUToLLVM +# ) + +# target_include_directories(ProtonNVIDIAGPUToLLVM PRIVATE +# ${PROJECT_SOURCE_DIR}/third_party/mthreads/proton/Dialect/include +# ) + +add_triton_library(ProtonNVIDIAGPUToLLVM + TargetInfo.cpp + NvidiaPatternProtonGPUOpToLLVM.cpp + ConvertProtonGPUToLLVM.cpp + ../../compat/TargetInfo.cpp + ../../compat/PTXAsmFormat.cpp + ../../compat/Utility.cpp + + DEPENDS + ProtonNvidiaGPUConversionPassIncGen + + LINK_LIBS PUBLIC + ProtonGPUToLLVM + # TritonNVIDIAGPUToLLVM +) + +target_include_directories(ProtonNVIDIAGPUToLLVM PRIVATE + ${PROJECT_SOURCE_DIR}/third_party/mthreads/proton/Dialect/include +) diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/ConvertProtonGPUToLLVM.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/ConvertProtonGPUToLLVM.cpp new file mode 100644 index 0000000000..b9c72cce31 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/ConvertProtonGPUToLLVM.cpp @@ -0,0 +1,106 @@ +#include "Conversion/ProtonGPUToLLVM/PatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h" +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton::proton::gpu { +#define GEN_PASS_DEF_CONVERTPROTONNVIDIAGPUTOLLVM +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h.inc" +} // namespace triton::proton::gpu +} // namespace mlir + +namespace { + +class ProtonLLVMConversionTarget : public ConversionTarget { +public: + explicit ProtonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + addDynamicallyLegalOp( + [](triton::gpu::GlobalScratchAllocOp op) { + return op.getBackend() != "proton"; + }); + } +}; + +struct ConvertProtonNvidiaGPUToLLVM + : public mlir::triton::proton::gpu::impl::ConvertProtonNvidiaGPUToLLVMBase< + ConvertProtonNvidiaGPUToLLVM> { + explicit ConvertProtonNvidiaGPUToLLVM(int32_t computeCapability, + int32_t ptxVersion) { + this->computeCapability = computeCapability; + this->ptxVersion = ptxVersion; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp mod = getOperation(); + + auto tritonTargetInfo = + mlir::triton::NVIDIA::TargetInfo(computeCapability, ptxVersion); + auto protonTargetInfo = + mlir::triton::proton::gpu::NVIDIA::TargetInfo(tritonTargetInfo); + mlir::LowerToLLVMOptions option(context); + TritonGPUToLLVMTypeConverter typeConverter(context, option, + tritonTargetInfo); + populateTypeConversions(typeConverter, protonTargetInfo); + mlir::triton::proton::gpu::populateProtonGPUOpPatterns( + typeConverter, patterns, protonTargetInfo, 1); + mlir::triton::proton::gpu::NVIDIA::populateProtonGPUOpNvidiaPatterns( + typeConverter, patterns, protonTargetInfo, 1); + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + auto convTarget = ProtonLLVMConversionTarget(*context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + + OpPassManager pm; + pm.addPass(createReconcileUnrealizedCastsPass()); + if (failed(runPipeline(pm, mod))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { + +namespace triton::proton { + +namespace gpu { + +std::unique_ptr> +createConvertProtonNvidiaGPUToLLVMPass(int32_t computeCapability, + int32_t ptxVersion) { + return std::make_unique(computeCapability, + ptxVersion); +} + +} // namespace gpu + +} // namespace triton::proton + +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp new file mode 100644 index 0000000000..911ff139fb --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp @@ -0,0 +1,106 @@ +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.h" +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.h" +#include "Conversion/ProtonGPUToLLVM/Utility.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "compat/PTXAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +// Circular strategy memory layout of profiled data (total: N bytes). +// Assuming we record data from warp 0, 2, 7 so buffer looks like: +// +-----------------------------------------------+ +// | warp 0 data (N/3 bytes) | +// +-----------------------------------------------+ +// | warp 2 data (N/3 bytes) | +// +-----------------------------------------------+ +// | warp 7 data (N/3 bytes) | +// +-----------------------------------------------+ + +struct CircularStoreOpConversion + : public ConvertOpToLLVMPattern< + mlir::triton::proton::gpu::CircularStoreOp> { + explicit CircularStoreOpConversion( + LLVMTypeConverter &typeConverter, + const proton::gpu::TargetInfoBase &targetInfo, PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern< + mlir::triton::proton::gpu::CircularStoreOp>(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::gpu::CircularStoreOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + auto dataPack = + lowerCircularStoreOpHelper(op, adaptor.getSegment(), rewriter); + + uint32_t addrSpace = dataPack.addrSpace; + if (addrSpace == 1) { + auto mod = op.getOperation()->getParentOfType(); + int numWarps = proton::gpu::getTotalNumWarps(mod); + PTXBuilder builder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (numWarps > 1) { + auto stInst = builder.create<>("st")->o("global").o("cg").v(2).b(32); + auto *ptrOpr = builder.newAddrOperand(dataPack.ptr, "l"); + + PTXBuilder::Operand *valOpr; + SmallVector> vecVals; + auto unPackedVals = unpackLLVector(loc, dataPack.record, rewriter); + vecVals.push_back({unPackedVals[0], "r"}); + vecVals.push_back({unPackedVals[1], "r"}); + valOpr = builder.newListOperand(vecVals); + stInst(ptrOpr, valOpr).predicate(dataPack.isWriter, "b"); + builder.launch(rewriter, loc, void_ty(rewriter.getContext())); + } else { + // Non-vectorized version for num_warps=1 to handle potential + // misalignment + auto stInst = builder.create<>("st")->o("global").o("cg").b(32); + + auto unPackedVals = unpackLLVector(loc, dataPack.record, rewriter); + + // First store: write first 32-bit value at base address + auto *ptrOpr0 = builder.newAddrOperand(dataPack.ptr, "l", 0); + auto *valOpr0 = builder.newOperand(unPackedVals[0], "r"); + stInst(ptrOpr0, valOpr0).predicate(dataPack.isWriter, "b"); + + // Second store: write second 32-bit value at offset +4 bytes + auto *ptrOpr1 = builder.newAddrOperand(dataPack.ptr, "l", 4); + auto *valOpr1 = builder.newOperand(unPackedVals[1], "r"); + stInst(ptrOpr1, valOpr1).predicate(dataPack.isWriter, "b"); + + builder.launch(rewriter, loc, void_ty(rewriter.getContext())); + } + } else if (addrSpace == 3) { + targetInfo.getTritonTargetInfo().storeDShared( + rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record, + /*pred=*/dataPack.isWriter); + } else { + llvm::report_fatal_error("unsupported address space in circular store"); + } + rewriter.eraseOp(op); + return success(); + } + +protected: + const proton::gpu::TargetInfoBase &targetInfo; +}; + +} // namespace + +namespace mlir::triton::proton::gpu::NVIDIA { +void populateProtonGPUOpNvidiaPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} +} // namespace mlir::triton::proton::gpu::NVIDIA diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.cpp new file mode 100644 index 0000000000..527053177d --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.cpp @@ -0,0 +1,81 @@ +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/TargetInfo.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "compat/PTXAsmFormat.h" +#include "compat/Utility.h" // TODO(fywkevin): move Utility.h to include/ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::proton::gpu::NVIDIA { + +Value TargetInfo::clock(ConversionPatternRewriter &rewriter, Location loc, + bool isClock64) const { + + auto getClockReg = [&](const std::string &clkName) { + PTXBuilder builder; + auto &movLow = builder.create("mov")->o("u32"); + auto *destLowOpr = builder.newOperand("=r"); + auto *sRegLowOpr = builder.newConstantOperand(clkName); + movLow(destLowOpr, sRegLowOpr); + Value clkLow32 = + builder.launch(rewriter, loc, rewriter.getIntegerType(32), true); + return clkLow32; + }; + + Value clkLow32 = getClockReg("%clock"); + + if (!isClock64) + return clkLow32; + + Value clkHigh32 = getClockReg("%clock_hi"); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value clkLow64 = b.zext(i64_ty, clkLow32); + Value clkHigh64 = b.zext(i64_ty, clkHigh32); + Value clock64 = b.or_(b.shl(clkHigh64, b.i64_val(32)), clkLow64); + return clock64; +} + +Value TargetInfo::globalTime(ConversionPatternRewriter &rewriter, + Location loc) const { + // globaltimer is a 64-bit global clock counter in nanoseconds. + // Reference: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-globaltimer + auto b = TritonLLVMOpBuilder(loc, rewriter); + StringRef globalTimeIntrinsicName = "llvm.nvvm.read.ptx.sreg.globaltimer"; + Value globalTimeVal = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, globalTimeIntrinsicName, i64_ty, {}) + .getResult(0); + return globalTimeVal; +} + +Value TargetInfo::processorId(ConversionPatternRewriter &rewriter, + Location loc) const { + return NVVM::SmIdOp::create(rewriter, loc, i32_ty); +} + +int TargetInfo::getAddressSpace(Attribute addressSpace) const { + int spaceId = 0; + if (mlir::isa(addressSpace)) { + spaceId = 3; + } else if (mlir::isa(addressSpace)) { + spaceId = 1; + } else { + llvm::report_fatal_error("Only support SharedMemorySpace, " + "and GlobalMemorySpace for now"); + } + return spaceId; +} + +int TargetInfo::getIndexPtrAddrSpace() const { + // Internal buffer index is private to each thread, we use generic address + // space for NV GPUs. See detail discussion: + // https://llvm.org/docs/NVPTXUsage.html#address-spaces + // The reason we don't use address space 5 is due to the downstream compiler + // generates incorrect `cvta` instruction for %SP/%SPL register that causes + // IMA when we perform thread-private memory access like `ld.local`. + return 0; +} + +} // namespace mlir::triton::proton::gpu::NVIDIA diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/Utility.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/Utility.cpp new file mode 100644 index 0000000000..7435428294 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonGPUToLLVM/Utility.cpp @@ -0,0 +1,179 @@ +#include "Conversion/ProtonGPUToLLVM/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { + +Value getRawThreadId(OpBuilder &rewriter, Location loc) { + Value tid = + ::mlir::gpu::ThreadIdOp::create(rewriter, loc, ::mlir::gpu::Dimension::x); + Value threadId = arith::IndexCastOp::create(rewriter, loc, i32_ty, tid); + return threadId; +} + +namespace LLVM { + +LLVMStructType SegmentObject::getStructType(MLIRContext *ctx, int memorySpace, + int indexPtrAddrSpace) { + SmallVector types; + // ------------ + // Memory descriptor + // ------------ + auto ptrType = LLVM::LLVMPointerType::get(ctx, memorySpace); + types.push_back(ptrType); + // ------------ + // Segment base + // ------------ + auto SegmentAllocType = IntegerType::get(ctx, 32); + types.push_back(SegmentAllocType); + // ------------ + // Index ptr + // ------------ + auto indexPtrType = LLVM::LLVMPointerType::get(ctx, indexPtrAddrSpace); + types.push_back(indexPtrType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Value SegmentObject::getStruct(Location loc, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + int memorySpace = + mlir::cast(base.getType()).getAddressSpace(); + int indexPtrAddrSpace = + mlir::cast(indexPtr.getType()).getAddressSpace(); + auto structTy = + getStructType(loc.getContext(), memorySpace, indexPtrAddrSpace); + Value segmentStruct = LLVM::UndefOp::create(rewriter, loc, structTy); + segmentStruct = b.insert_val(structTy, segmentStruct, base, 0); + segmentStruct = b.insert_val(structTy, segmentStruct, segmentBase, 1); + segmentStruct = b.insert_val(structTy, segmentStruct, indexPtr, 2); + return segmentStruct; +} + +SegmentObject SegmentObject::fromStruct(Location loc, Value segmentStruct, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto structTy = mlir::cast(segmentStruct.getType()); + Value memoryDescriptorPtr = + b.extract_val(structTy.getBody()[0], segmentStruct, 0); + Value segmentBase = b.extract_val(structTy.getBody()[1], segmentStruct, 1); + Value indexPtr = b.extract_val(structTy.getBody()[2], segmentStruct, 2); + return SegmentObject(memoryDescriptorPtr, segmentBase, indexPtr); +} + +} // namespace LLVM + +namespace triton { +namespace proton::gpu { + +CircularStoreDataPack +lowerCircularStoreOpHelper(CircularStoreOp op, Value segmentStruct, + ConversionPatternRewriter &rewriter) { + auto loc = op.getLoc(); + auto mod = op.getOperation()->getParentOfType(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + const int bytesPerEntry = proton::gpu::getBytesPerClockEntry(); + const int wordsPerEntry = bytesPerEntry / 4; // 1 word = 4 bytes + + auto segmentObj = + LLVM::SegmentObject::fromStruct(loc, segmentStruct, rewriter); + Value indexPtr = segmentObj.indexPtr; + Value bufferBase = segmentObj.base; + Value segmentBase = segmentObj.segmentBase; + + // Update the index (could be register promoted). + Value curIdx = b.load(i32_ty, indexPtr); + Value newIdx = b.add(curIdx, b.i32_val(wordsPerEntry)); + + // Compute the segment size in word (4 bytes). + int selectedWarpNum = getTotalNumWarps(mod); + auto segmentType = op.getSegment().getType(); + auto selectedIds = segmentType.getSelectIds(); + if (!selectedIds.empty()) + selectedWarpNum = selectedIds.size(); + const int bufferSizeInBytes = segmentType.getNBytes(); + const int segmentWordSize = bufferSizeInBytes / selectedWarpNum / 4; + + // Compute the actual base offset (with urem as circular buffer). + Value tagOffset = + b.add(segmentBase, b.urem(curIdx, b.i32_val(segmentWordSize))); + + // Store the counter into buffer. + auto bufferBaseType = bufferBase.getType(); + Value vecPtr = b.gep(bufferBaseType, i32_ty, bufferBase, tagOffset); + + // Constructing the tag and clock (8 byte) + // ======================================= + // tag and upper clock (4 bytes): + // 31: start or end (1 bit) + // 30:23 scope id (8 bits) + // 22:11 reserved (12 bits) + // 10:0 64-bit clock bit 32:42 (11 bits) + // ======================================= + // lower clock (4 bytes): + // 31:0 64-bit clock bit 0:31 + // ======================================= + Value clock = op.getCounter(); + auto clkTy = mlir::cast(clock.getType()); + uint32_t maskedScopeId = op.getScopeId() & 0xff; + Value tag = op.getIsStart() ? b.i32_val(maskedScopeId << 23) + : b.i32_val(1 << 31 | maskedScopeId << 23); + Value valsVec; + if (clkTy.getWidth() == 64) { + auto clkVecTy = vec_ty(i32_ty, 2); + auto clkVec = b.bitcast(clock, clkVecTy); + Value clkLower = b.extract_element(i32_ty, clkVec, b.i32_val(0)); + Value clkUpper = b.extract_element(i32_ty, clkVec, b.i32_val(1)); + Value tagClkUpper = b.or_(tag, b.and_(clkUpper, b.i32_val(0x7ff))); + valsVec = packLLVector(loc, {tagClkUpper, clkLower}, rewriter); + } else { + valsVec = packLLVector(loc, {tag, clock}, rewriter); + } + + // Compute the predicate for the writer. + const int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value curThreadId = getThreadId(rewriter, loc); + Value isWarpMaster = + b.icmp_eq(b.urem(curThreadId, b.i32_val(warpSize)), b.i32_val(0)); + Value isWriter; + + Value idxToStore = newIdx; + auto granularity = segmentType.getGranularity(); + if (selectedIds.empty()) { + if (granularity == proton::gpu::Granularity::WARP) { + isWriter = isWarpMaster; + } else { + llvm::report_fatal_error( + "segment address specialization not implemented yet"); + } + } else { + Value isCurWarpEnabled = b.icmp_ne(segmentBase, b.i32_val(-1)); + isWriter = b.and_(isCurWarpEnabled, isWarpMaster); + idxToStore = b.select(isCurWarpEnabled, newIdx, curIdx); + } + + b.store(idxToStore, indexPtr); + + uint32_t addrSpace = + cast(bufferBaseType).getAddressSpace(); + + return {isWriter, valsVec, vecPtr, addrSpace}; +} + +SmallVector getTritonFunctions(ModuleOp mod) { + SmallVector funcOps; + mod.walk([&](FunctionOpInterface funcOp) { + // Ignore any intrinsic functions which have an empty body. + // For example, on AMD the predicate load/store ops are currently pseudo + // instructions at this point and may get picked up here and trigger the + // FunctionOpInterface range based assert below. + if (funcOp.empty()) + return; + funcOps.push_back(funcOp); + }); + return funcOps; +} + +} // namespace proton::gpu +} // namespace triton + +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonToProtonGPU/CMakeLists.txt b/third_party/mthreads/proton/Dialect/lib/ProtonToProtonGPU/CMakeLists.txt new file mode 100644 index 0000000000..f0a56252a8 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonToProtonGPU/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(ProtonToProtonGPU + ProtonToProtonGPUPass.cpp + + DEPENDS + ProtonToProtonGPUIncGen + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + ProtonIR + ProtonGPUIR +) diff --git a/third_party/mthreads/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp b/third_party/mthreads/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp new file mode 100644 index 0000000000..f3f7049d62 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp @@ -0,0 +1,416 @@ +#include "Analysis/ScopeIdAllocation.h" +#include "Conversion/ProtonToProtonGPU/Passes.h" +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include + +namespace mlir { +namespace triton { +namespace proton { + +#define GEN_PASS_DEF_CONVERTPROTONTOPROTONGPU +#include "Conversion/ProtonToProtonGPU/Passes.h.inc" + +#define DEBUG_TYPE "proton-to-proton-gpu" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +constexpr float maxSharedMemRatio = 0.04; // 4 percent of max shared mem + +namespace { + +void parseSelectIds(llvm::StringRef selectIds, + llvm::SmallVectorImpl &selectIdVec) { + auto rest = selectIds; + while (!rest.empty()) { + llvm::StringRef id; + std::tie(id, rest) = rest.split(','); + if (id.trim().size() > 0) { + selectIdVec.push_back(std::stoi(id.str())); + } + if (rest.trim().size() == 0) + break; + } + llvm::sort(selectIdVec); + selectIdVec.erase(llvm::unique(selectIdVec), selectIdVec.end()); +} + +template bool hasOperator(T *o) { + bool exist = false; + o->walk([&](OP op) { + exist = true; + return WalkResult::interrupt(); + }); + return exist; +} + +void instrumentWarpSpecializeOps(FuncOp func, Value buffer, Value profileMem) { + for (auto wsOp : func.getOps()) { + auto loc = wsOp.getLoc(); + if (hasOperator(wsOp.getOperation())) { + auto partOp = wsOp.getPartitionOp(); + partOp->insertOperands(partOp->getNumOperands(), {buffer, profileMem}); + for (Region *region : wsOp.getPartitionRegions()) { + region->addArgument(buffer.getType(), loc); + region->addArgument(profileMem.getType(), loc); + } + } + } +} + +LogicalResult replaceProtonRecordOp(OpBuilder &builder, FuncOp func, + Value segment, MetricType metricType, + ModuleScopeIdAllocation &scopeInfo, + bool clockExtension) { + mlir::IntegerType clkType = + clockExtension ? mlir::IntegerType::get(builder.getContext(), 64) + : mlir::IntegerType::get(builder.getContext(), 32); + + // Replace all proton::RecordOp in the worker warps. + func->walk([&](triton::gpu::WarpSpecializePartitionsOp partitions) { + for (auto &partition : partitions.getPartitionRegions()) { + auto loc = partitions.getLoc(); + if (hasOperator(&partition)) { + Block &block = partition.front(); + builder.setInsertionPointToStart(&block); + int argNum = block.getNumArguments(); + auto bufferArg = block.getArgument(argNum - 2); + auto profileMemArg = block.getArgument(argNum - 1); + + // Create a new segment for the worker warp. + Value newSegment = gpu::SegmentAllocOp::create( + builder, loc, segment.getType(), bufferArg); + + // Restore warp-level context before profiling. + gpu::RestoreCtxOp::create(builder, loc, newSegment, profileMemArg); + + // Replace all proton::RecordOp. + partition.walk([&](proton::RecordOp record) { + builder.setInsertionPoint(record); + + Value counter = gpu::ReadCounterOp::create(builder, record.getLoc(), + clkType, metricType); + int scopeId = scopeInfo.getOpScopeId(record); + gpu::CircularStoreOp::create(builder, record.getLoc(), newSegment, + counter, record.getIsStart(), scopeId); + record.erase(); + }); + + // Finalize and save warp-level context before each warp returns. + partition.walk([&](triton::gpu::WarpReturnOp ret) { + builder.setInsertionPoint(ret); + // TODO(Keren): This is not ideal if we have multiple warp specialize + // ops in a program. In that case, we should use SaveCtxOp here at + // warp return and only write back data in FinalizeOp at the end of + // kernel. Active warps in the default warp group can write data on + // behalf of inactive warps in other warp groups. + gpu::FinalizeOp::create(builder, loc, newSegment, profileMemArg); + }); + } + } + }); + + // Replace all proton::RecordOp in the master warps. For the master warps, we + // don't need to restore warp-level context and we save the context in the end + // of kernel (right before FinalizeOp). + func->walk([&](proton::RecordOp record) { + builder.setInsertionPoint(record); + Value counter = gpu::ReadCounterOp::create(builder, record.getLoc(), + clkType, metricType); + int scopeId = scopeInfo.getOpScopeId(record); + gpu::CircularStoreOp::create(builder, record.getLoc(), segment, counter, + record.getIsStart(), scopeId); + record.erase(); + }); + + return success(); +} + +int getAllocSharedMemSize(int maxSharedMemSize, int sharedMemUsed, + int segmentNum) { + const int bytesPerEntry = gpu::getBytesPerClockEntry(); + const int wordsPerEntry = bytesPerEntry / 4; // 1 word = 4 bytes + const int circularHeaderSize = gpu::getCircularHeaderSize(); // byte size + sharedMemUsed = llvm::alignTo(sharedMemUsed, bytesPerEntry); + if (sharedMemUsed >= maxSharedMemSize) { + // We just assume there's enough shared memory and error out if not during + // execution. + maxSharedMemSize += sharedMemUsed; + } + + int segmentByteSizeShared = + llvm::NextPowerOf2((maxSharedMemSize - sharedMemUsed) / segmentNum) / 2; + int numSharedEntries = segmentByteSizeShared * segmentNum / bytesPerEntry; + int allocSharedMemSize = numSharedEntries * bytesPerEntry; + + int estimatedOccupancy = maxSharedMemSize / std::max(1, sharedMemUsed); + if (estimatedOccupancy <= 1) + return allocSharedMemSize; + + int maxAllocSharedMemSize = maxSharedMemSize * maxSharedMemRatio; + while (allocSharedMemSize > maxAllocSharedMemSize) + allocSharedMemSize /= 2; + + return allocSharedMemSize; +} +} // namespace + +class ConvertProtonToProtonGPUPass + : public impl::ConvertProtonToProtonGPUBase { +public: + ConvertProtonToProtonGPUPass( + MetricType metricType, SamplingStrategy samplingStrategy, + llvm::StringRef samplingOptions, gpu::Granularity granularity, + gpu::BufferStrategy bufferStrategy, gpu::BufferType bufferType, + int32_t bufferSize, int32_t maxSharedMemSize, int64_t profileScratchSize, + int32_t profileScratchAlignment, bool clockExtension) + : ConvertProtonToProtonGPUBase() { + this->metricType = metricType; + this->samplingStrategy = samplingStrategy; + this->granularity = granularity; + this->samplingOptions = samplingOptions.str(); + this->bufferStrategy = bufferStrategy; + this->bufferType = bufferType; + this->bufferSize = bufferSize; + this->maxSharedMemSize = maxSharedMemSize; + this->profileScratchSize = profileScratchSize; + this->profileScratchAlignment = profileScratchAlignment; + this->clockExtension = clockExtension; + } + + LogicalResult circularRecordStrategyLowering(FuncOp func) { + MLIRContext *context = func.getContext(); + Location loc = func->getLoc(); + ModuleOp mod = llvm::cast(func->getParentOp()); + + OpBuilder builder(context); + builder.setInsertionPointToStart(&func.getBody().front()); + + int numWarps = gpu::getTotalNumWarps(mod); + + llvm::SmallVector selectIdVec; + int segmentNum = numWarps; + if (!samplingOptions.empty() && + samplingStrategy == SamplingStrategy::SELECTIVE) { + parseSelectIds(samplingOptions, selectIdVec); + segmentNum = selectIdVec.size(); + if (segmentNum && granularity != gpu::Granularity::WARP) { + mlir::emitError( + loc, "only warp granularity supports selective ids for now."); + return failure(); + } + } + + int sharedMemUsed = 0; + if (mod->hasAttr("ttg.shared")) + sharedMemUsed = + mod->getAttrOfType("ttg.shared").getInt(); + + int numCTAs = triton::gpu::lookupNumCTAs(func); + auto maxSharedMemSizePerCTA = maxSharedMemSize / numCTAs; + + int allocSharedMemSize = getAllocSharedMemSize(maxSharedMemSizePerCTA, + sharedMemUsed, segmentNum); + + const int bytesPerEntry = gpu::getBytesPerClockEntry(); + + if (bufferSize != 0) + bufferSize = llvm::alignTo(bufferSize, bytesPerEntry); + // Validate buffer size + if (bufferSize != 0 && !llvm::isPowerOf2_32(bufferSize / segmentNum)) { + mlir::emitError(loc, "buffer-size per segment(" + + llvm::Twine(segmentNum) + + ") must be power of 2"); + return failure(); + } + + int allocBufferSize; + if (bufferType == gpu::BufferType::SHARED) { + if (bufferSize > 0) + allocBufferSize = std::min(allocSharedMemSize, bufferSize.getValue()); + else + allocBufferSize = allocSharedMemSize; + } else if (bufferType == gpu::BufferType::GLOBAL) { + if (bufferSize > 0) + allocBufferSize = bufferSize.getValue(); + else + allocBufferSize = 16384 * segmentNum; // 16KB per profiling unit + } else { + mlir::emitError(loc, "buffer-type not supported"); + return failure(); + } + + if (allocBufferSize <= 0) { + mlir::emitError(loc, "profiling buffer size should be greater than 0"); + return failure(); + } + + // Circular strategy memory layout (total: allocProfileScratchSize bytes) + // +-----------------------------------------------+ + // | header (circularHeaderSize bytes) | + // +-----------------------------------------------+ + // | contexts for all warps (4 bytes x numWarps) | + // +-----------------------------------------------+ + // | profiled data (allocBufferSize bytes) | + // +-----------------------------------------------+ + const int circularHeaderSize = gpu::getCircularHeaderSize(); // byte size + + int allocProfileScratchSize = + llvm::alignTo(allocBufferSize + circularHeaderSize + numWarps * 4, + profileScratchAlignment); + + if (profileScratchSize < allocProfileScratchSize) { + LDBG("Global scratch memory for proton profiling is not large " + "enough, we allocate the scratch size as " + + llvm::Twine(allocProfileScratchSize) + " bytes."); + } + + Value profileMem = triton::gpu::GlobalScratchAllocOp::create( + builder, loc, triton::getPointerType(builder.getI32Type()), + allocProfileScratchSize, profileScratchAlignment, "proton"); + gpu::InitializeOp::create(builder, loc, profileMem); + + Value segment; + Value buffer; + if (bufferType == gpu::BufferType::SHARED) { + auto cgaLayout = + triton::gpu::CGAEncodingAttr::get1DLayout(context, numCTAs); + auto encoding = triton::gpu::SwizzledSharedEncodingAttr::get( + context, 1, 1, 1, {0}, cgaLayout); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto sharedBufferType = triton::gpu::MemDescType::get( + {allocBufferSize / 4}, builder.getI32Type(), encoding, + sharedMemorySpace, /*mutable_memory=*/true); + buffer = + triton::gpu::LocalAllocOp::create(builder, loc, sharedBufferType); + Attribute memorySpace = + mlir::cast(buffer.getType()) + .getMemorySpace(); + + auto segmentType = gpu::SegmentType::get( + context, allocBufferSize, memorySpace, granularity, selectIdVec); + segment = gpu::SegmentAllocOp::create(builder, loc, segmentType, buffer); + } else if (bufferType == gpu::BufferType::GLOBAL) { + Attribute memorySpace = gpu::GlobalMemorySpaceAttr::get(context); + auto segmentType = gpu::SegmentType::get( + context, allocBufferSize, memorySpace, granularity, selectIdVec); + int offset = (circularHeaderSize + numWarps * 4) / 4; + Type offsetType = builder.getI32Type(); + Value offsetVal = arith::ConstantOp::create( + builder, loc, offsetType, builder.getIntegerAttr(offsetType, offset)); + buffer = triton::AddPtrOp::create(builder, loc, profileMem.getType(), + profileMem, offsetVal); + segment = gpu::SegmentAllocOp::create(builder, loc, segmentType, buffer); + } else { + mlir::emitError(loc, "buffer-type not supported"); + return failure(); + } + + ModuleScopeIdAllocation &scopeInfo = getAnalysis(); + + if (hasOperator( + func.getOperation())) + gpu::InitCtxOp::create(builder, loc, profileMem); + + instrumentWarpSpecializeOps(func, buffer, profileMem); + + if (failed(replaceProtonRecordOp(builder, func, segment, metricType, + scopeInfo, clockExtension))) + return failure(); + + func.walk([&](triton::ReturnOp ret) { + builder.setInsertionPoint(ret); + mlir::triton::gpu::BarrierOp::create( + builder, loc, + triton::gpu::AddrSpace::Local | triton::gpu::AddrSpace::GlobalRead | + triton::gpu::AddrSpace::GlobalWrite); + + gpu::FinalizeOp::create(builder, loc, segment, profileMem); + }); + + return success(); + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + Location loc = m->getLoc(); + + // Validate metric type at runtime instead of using assert + if (metricType != MetricType::CYCLE) { + mlir::emitError(loc, "only CYCLE metric type is supported currently"); + signalPassFailure(); + return; + } + + // Check if there are any functions in the module + int numFuncs = llvm::range_size(m.getOps()); + if (numFuncs == 0) { + return; // No functions to process, silently return + } else if (numFuncs > 1) { + // We currently only support one function in the module + mlir::emitError(loc, "only one function per module is supported"); + signalPassFailure(); + return; + } + + FuncOp func = *m.getOps().begin(); + + // Check if there are any proton records to process + if (!hasOperator(func.getOperation())) { + return; // No proton records to process, silently return + } + + // Validate profile scratch alignment + if (!llvm::isPowerOf2_32(profileScratchAlignment)) { + mlir::emitError(loc, "profileScratchAlignment must be power of 2"); + signalPassFailure(); + return; + } + + // Process based on buffer strategy + if (bufferStrategy == gpu::BufferStrategy::CIRCULAR) { + if (failed(circularRecordStrategyLowering(func))) { + // No need to call signalPassFailure() here as it's already called in + // circularRecordStrategyLowering + signalPassFailure(); + } + } else { + mlir::emitError( + loc, "buffer-strategy '" + + std::to_string(static_cast( + static_cast(bufferStrategy))) + + "' is not supported"); + signalPassFailure(); + } + } +}; + +std::unique_ptr> createConvertProtonToProtonGPUPass( + MetricType metricType, SamplingStrategy samplingStrategy, + llvm::StringRef samplingOptions, gpu::Granularity granularity, + gpu::BufferStrategy bufferStrategy, gpu::BufferType bufferType, + int32_t bufferSize, int32_t maxSharedMemSize, int64_t profileScratchSize, + int32_t profileScratchAlignment, bool clkExt) { + return std::make_unique( + metricType, samplingStrategy, samplingOptions, granularity, + bufferStrategy, bufferType, bufferSize, maxSharedMemSize, + profileScratchSize, profileScratchAlignment, clkExt); +} + +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/compat/PTXAsmFormat.cpp b/third_party/mthreads/proton/Dialect/lib/compat/PTXAsmFormat.cpp new file mode 100644 index 0000000000..23c72655e9 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/compat/PTXAsmFormat.cpp @@ -0,0 +1,237 @@ +#include "compat/PTXAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h" +#include "llvm/Support/raw_ostream.h" +// TODO(Superjomn): unify to llvm::raw_string_ostream +#include + +namespace mlir { +namespace triton { + +PTXInstr::Operand * +PTXBuilder::newOperand(mlir::Value value, StringRef constraint, + std::function formatter) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formatter; + opr->idx = oprCounter++; + return opr; +} + +void PTXBuilder::initOperand(Operand *opr) { + auto numBits = 0; + // Derive numBits from the constraint. + if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h') + numBits = 16; + else if (opr->constraint[1] == 'r') + numBits = 32; + else if (opr->constraint[1] == 'l') + numBits = 64; + else + llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str()); + // If numBits is less than 16, we use 16 as default because PTX does not + // support 8-bit mov. + numBits = numBits < 16 ? 16 : numBits; + auto *zero = newConstantOperand(0); + auto &init = create<>("mov")->o("u" + std::to_string(numBits)); + init(opr, zero); +} + +PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) { + // Constraint should be something like "=r" + assert(constraint.size() == 2 && constraint[0] == '='); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = constraint; + if (init) { + initOperand(opr); + } + return opr; +} + +PTXBuilder::Operand *PTXBuilder::newOperand(unsigned operandIndex) { + assert(operandIndex < oprCounter && "operand index out of range"); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = std::to_string(operandIndex); + return opr; +} + +PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) { + argArchive.emplace_back(std::make_unique()); + argArchive.back()->repr = [v](int idx) { return v; }; + return argArchive.back().get(); +} + +PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) { + std::stringstream ss; + ss << "0x" << std::hex << v; + return newConstantOperand(ss.str()); +} + +std::string PTXBuilder::getConstraints() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); +} + +llvm::SmallVector PTXBuilder::getAllMLIRArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList() && arg->value) + res.push_back(arg->value); + } + return res; +} + +SmallVector PTXBuilder::getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; +} + +mlir::Value PTXBuilder::launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect, bool isAlignStack, + ArrayRef attrs) const { + auto *ctx = rewriter.getContext(); + auto inlineAsm = LLVM::InlineAsmOp::create( + rewriter, loc, resTy, getAllMLIRArgs(), // operands + dump(), // asm_string + getConstraints(), // constraints + hasSideEffect, // has_side_effects + isAlignStack, // is_align_stack + LLVM::TailCallKind::None, + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, attrs) // operand_attrs + ); + + return inlineAsm.getRes(); +} + +std::string PTXInstr::Operand::dump() const { + if (repr) + return repr(idx); + if (!isList()) + return "$" + std::to_string(idx); + + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return "{ " + strJoin(oprs, ", ") + " }"; +} + +PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr, + StringRef constraint, int off) { + auto *opr = newOperand(addr, constraint); + opr->repr = [off](int idx) -> std::string { + std::stringstream ss; + ss << "[ $" << idx << " + " << off << " ]"; + return ss.str(); + }; + + return opr; +} + +std::string PTXBuilder::dump() const { + llvm::SmallVector lines; + for (auto &exec : executions) { + lines.push_back(exec->dump()); + } + + return strJoin(lines, "\n\t"); +} + +PTXInstrExecution &PTXInstrCommon::call(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + if (onlyAttachMLIRArgs) { + // Nearly impossible to make the $0,$1 in two PTX code snippets to point to + // the same MLIR values in onlyAttachMLIRArgs mode. + assert(builder->executions.empty() && + "builder can only hold a single execution when onlyAttachMIIRArgs " + "is true."); + builder->reorderArgArchive(oprs); + } + + builder->executions.emplace_back( + std::make_unique(this, oprs, onlyAttachMLIRArgs)); + + return *builder->executions.back(); +} + +PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + return call(oprs, onlyAttachMLIRArgs); +} + +std::string PTXInstrExecution::dump() const { + std::string osStr; + llvm::raw_string_ostream os(osStr); + + if (pred) { + if (!pred->repr) + os << "@" << pred->dump() << " "; + else + os << pred->repr(pred->idx) << " "; + } + + std::string instrRepr = strJoin(instr->instrParts, "."); + if (onlyAttachMLIRArgs) { + os << instrRepr; + os.flush(); + return osStr; + } + + llvm::SmallVector argReprs; + for (auto *arg : argsInOrder) { + argReprs.push_back(arg->dump()); + } + + std::string argsRepr = strJoin(argReprs, ", "); + + os << instrRepr << " " << argsRepr << ";"; + os.flush(); + return osStr; +} + +SmallVector +PTXInstrExecution::getArgList() const { + SmallVector args; + for (auto *arg : argsInOrder) { + if (arg->isList()) + args.insert(args.end(), arg->list.begin(), arg->list.end()); + else + args.push_back(arg); + } + return args; +} + +PTXInstr &PTXInstr::global() { + o("global"); + return *this; +} + +PTXInstr &PTXInstr::shared() { + o("shared"); + return *this; +} + +PTXInstr &PTXInstr::v(int vecWidth, bool predicate) { + if (vecWidth > 1) { + o("v" + std::to_string(vecWidth), predicate); + } + return *this; +} + +PTXInstr &PTXInstr::b(int width) { + o("b" + std::to_string(width)); + return *this; +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/lib/compat/TargetInfo.cpp b/third_party/mthreads/proton/Dialect/lib/compat/TargetInfo.cpp new file mode 100644 index 0000000000..2f39619c11 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/compat/TargetInfo.cpp @@ -0,0 +1,617 @@ +#include "compat/TargetInfo.h" +#include "compat/PTXAsmFormat.h" +#include "compat/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "third_party/mthreads/include/triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; + +using ::mlir::LLVM::linearize; +namespace { +// declare vprintf(i8*, i8*) as external function +LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("vprintf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + SmallVector argsType{ptr_ty(context), ptr_ty(context)}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return LLVM::LLVMFuncOp::create(rewriter, UnknownLoc::get(context), funcName, + funcType); +} + +// extend integer to int32, extend float to float64 +// this comes from vprintf alignment requirements. +std::pair printfPromoteValue(RewriterBase &rewriter, Value value, + bool isSigned) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + Value newOp = value; + Type newType = type; + auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + newType = i32_ty; + if (isSigned) { + newOp = b.sext(newType, value); + } else { + newOp = b.zext(newType, value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + newType = f64_ty; + newOp = b.fpext(newType, value); + } + + return {newType, newOp}; +} + +LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("__assertfail"); + { + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + } + // void __assert_fail(const char * assertion, const char * file, unsigned + // int line, const char * function); + auto *ctx = rewriter.getContext(); + SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), + rewriter.getIntegerType(sizeof(size_t) * 8)}; + auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto funcOp = LLVM::LLVMFuncOp::create(rewriter, UnknownLoc::get(ctx), + funcName, funcType); + + funcOp.setPassthroughAttr( + ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn"))); + return funcOp; +} +} // namespace + +namespace mlir::triton::NVIDIA { + +// Check if the reduction can use a redux op and return the kind. +static std::optional matchReduxKind(triton::ReduceOp op, + int computeCapability, + bool &useNanQualifier) { + useNanQualifier = false; + if (computeCapability < 80) + return std::nullopt; + Operation *reduceOp = op.getSingleCombiner(); + if (!reduceOp) + return std::nullopt; + if (computeCapability == 100 && reduceOp->getResultTypes()[0].isF32()) { + if (isa(reduceOp)) + useNanQualifier = true; + if (isa(reduceOp)) + return NVVM::ReduxKind::FMAX; + if (isa(reduceOp)) + return NVVM::ReduxKind::FMIN; + } + auto intType = dyn_cast(reduceOp->getResultTypes()[0]); + if (!intType || intType.getWidth() > 32) + return std::nullopt; + if (isa(reduceOp)) + return NVVM::ReduxKind::ADD; + if (isa(reduceOp)) + return NVVM::ReduxKind::AND; + if (isa(reduceOp)) + return NVVM::ReduxKind::OR; + if (isa(reduceOp)) + return NVVM::ReduxKind::XOR; + if (isa(reduceOp)) + return NVVM::ReduxKind::MIN; + if (isa(reduceOp)) + return NVVM::ReduxKind::UMIN; + if (isa(reduceOp)) + return NVVM::ReduxKind::MAX; + if (isa(reduceOp)) + return NVVM::ReduxKind::UMAX; + return std::nullopt; +} + +bool TargetInfo::supportMaximumMinimum() const { + return computeCapability >= 80; +} + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + if (triton::gpu::lookupNumCTAs(&rewriter.getInsertionBlock()->front()) == 1) + return arith::ConstantIntOp::create(rewriter, loc, 0, 32); + + return triton::nvgpu::ClusterCTAIdOp::create(rewriter, loc, + rewriter.getI32Type()); +} + +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadMask = b.int_val(type.getIntOrFloatBitWidth(), -1); + return NVVM::VoteSyncOp::create(rewriter, loc, type, threadMask, cmp, + NVVM::VoteSyncKind::ballot); +} + +void TargetInfo::barrier(Location loc, RewriterBase &rewriter, + triton::gpu::AddrSpace targets) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + b.barrier(targets); +} + +void TargetInfo::warpSync(Location loc, RewriterBase &rewriter) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + NVVM::SyncWarpOp::create(rewriter, loc, b.i32_val(0xffffffff)); +} + +static Value mapa(RewriterBase &rewriter, Location loc, Value ptr, Value ctaid, + Value pred) { + return NVVM::MapaOp::create(rewriter, loc, ptr.getType(), ptr, ctaid); +} + +static std::string getConstraintForBitwidth(unsigned bitwidth) { + switch (bitwidth) { + case 8: + case 16: + return "h"; + case 32: + return "r"; + case 64: + return "l"; + default: + llvm_unreachable("unsupported bitwidth"); + } +} + +static bool isConstantTruePred(Value pred) { + if (auto constOp = pred.getDefiningOp()) { + return cast(constOp.getValue()).getInt() == -1; + } + return false; +} + +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + if (!isa(val.getType())) { + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, {val}, rewriter), + pred); + return; + } + + auto vecTy = cast(val.getType()); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + assert(llvm::isPowerOf2_32(vec)); + + if (elemBitwidth < 8) { + assert(vec == 1 && + "don't know how to load/store vectors of sub-byte elems"); + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (Value &v : vals) { + v = b.zext(int_ty(8), b.bitcast(v, int_ty(elemBitwidth))); + } + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); + return; + } + + if (!elemTy.isInteger()) { + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (Value &v : vals) { + if (isa(v.getType())) { + v = b.ptrtoint(int_ty(elemBitwidth), v); + } else { + v = b.bitcast(v, int_ty(elemBitwidth)); + } + } + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); + return; + } + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, we have two strategies for dealing with it. + // 1. If the element type is smaller than b32, store b32's instead. + // 2. Otherwise, split the store into multiple stores. + if (vec > 4 && elemBitwidth < 32) { + assert(llvm::isPowerOf2_32(vec)); + int elemsPerPack = 32 / elemBitwidth; + SmallVector oldVals = unpackLLVector(loc, val, rewriter); + + SmallVector newVals; + for (int i = 0; i < vec / elemsPerPack; i++) { + Value v = packLLVector( + loc, ArrayRef(oldVals).slice(i * elemsPerPack, elemsPerPack), + rewriter); + newVals.push_back(b.bitcast(v, i32_ty)); + } + storeDShared(rewriter, loc, ptr, ctaId, + packLLVector(loc, newVals, rewriter), pred); + return; + } + + if (vec * elemBitwidth > 128) { + assert(llvm::isPowerOf2_32(vec)); + assert(elemBitwidth == 32 || elemBitwidth == 64); + int maxVec = 128 / elemBitwidth; + + auto newVecTy = vec_ty(elemTy, maxVec); + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (int i = 0; i < vec / maxVec; i++) { + auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), + LLVM::GEPNoWrapFlags::inbounds); + storeDShared( + rewriter, loc, newPtr, ctaId, + packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), + pred); + } + return; + } + + // At this point we're committed to doing the store! + assert(elemBitwidth >= 8); + assert(elemTy.isInteger()); + assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth <= 128); + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } + + PTXBuilder builder; + auto st = builder.create("st") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + + if (isConstantTruePred(pred)) { + b.store(val, ptr, /*align=*/vec * elemBitwidth / 8); + } else { + PTXBuilder::Operand *valOpr; + std::string constraint = getConstraintForBitwidth(elemBitwidth); + if (vec > 1) { + SmallVector> vecVals; + for (int i = 0; i < vec; i++) { + vecVals.push_back({b.extract_element(val, b.i32_val(i)), constraint}); + } + valOpr = builder.newListOperand(vecVals); + } else { + valOpr = builder.newOperand(val, constraint); + } + st(ptrOpr, valOpr).predicate(pred, "b"); + builder.launch(rewriter, loc, void_ty(ctx)); + } +} + +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type loadTy, + Value pred, Operation *localLoadOp) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + if (!isa(loadTy)) { + SmallVector values = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), + rewriter); + assert(values.size() == 1); + return values[0]; + } + + auto vecTy = cast(loadTy); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); + assert(llvm::isPowerOf2_32(vec)); + + if (elemBitwidth < 8) { + assert(vec == 1 && + "don't know how to load/store vectors of sub-byte elems"); + SmallVector vals = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, int_ty(8), pred), rewriter); + assert(vals.size() == 1); + return b.bitcast(b.trunc(int_ty(elemBitwidth), vals[0]), elemTy); + } + + // We only know how to load integers. + if (!elemTy.isInteger()) { + Type newLoadTy = vec_ty(int_ty(elemBitwidth), vec); + SmallVector vals = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, newLoadTy, pred), rewriter); + for (Value &v : vals) { + v = b.bitcast(v, elemTy); + } + return packLLVector(loc, vals, rewriter); + } + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, we have two strategies for dealing with it. + // 1. If the element type is smaller than b32, load b32's instead. + // 2. Otherwise, split the load into multiple loads. + if (vec > 4 && elemBitwidth < 32) { + int newVec = vec / (32 / elemBitwidth); + auto newVecTy = vec_ty(i32_ty, newVec); + auto res = loadDShared(rewriter, loc, ptr, ctaId, newVecTy, pred); + + // Unpack the b32's into the original vector type. + SmallVector vals; + for (Value v : unpackLLVector(loc, res, rewriter)) { + Value vv = b.bitcast(v, vec_ty(elemTy, 32 / elemBitwidth)); + for (Value vvv : unpackLLVector(loc, vv, rewriter)) { + vals.push_back(vvv); + } + } + return packLLVector(loc, vals, rewriter); + } + + if (vec * elemBitwidth > 128) { + assert(elemBitwidth == 32 || elemBitwidth == 64); + assert(llvm::isPowerOf2_32(vec)); + int maxVec = 128 / elemBitwidth; + + SmallVector vals; + for (int i = 0; i < vec / maxVec; i++) { + auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), + LLVM::GEPNoWrapFlags::inbounds); + auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, + vec_ty(elemTy, maxVec), pred); + for (Value v : unpackLLVector(loc, newVal, rewriter)) { + vals.push_back(v); + } + } + return packLLVector(loc, vals, rewriter); + } + + // At this point we're committed to actually do the load! + assert(elemBitwidth >= 8); + assert(elemTy.isInteger()); + assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth <= 128); + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } + + PTXBuilder builder; + auto ld = builder.create("ld") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + + Value load; + if (isConstantTruePred(pred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = b.load(resultTy, ptr, /*align=*/vec * elemBitwidth / 8); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = b.undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = b.insert_val(structTy, structValue, + b.extract_element(load, b.i32_val(i)), i); + } + load = structValue; + } + } else { + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); + ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); + + Type resultTy = + vec == 1 + ? Type(int_ty(elemBitwidth)) + : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); + load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); + } + SmallVector resultVals = unpackLLElements(loc, load, rewriter); + return packLLVector(loc, resultVals, rewriter); +} + +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); +} + +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::NVIDIA::shuffleUp(loc, rewriter, val, i); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { + return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); +} + +Value TargetInfo::permute(RewriterBase &rewriter, Location loc, Value a, + Value b, Value selector) const { + return LLVM::NVIDIA::permute(loc, rewriter, a, b, selector); +} + +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, + ModuleOp moduleOp, ProgramIDDim axis) const { + return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis); +} +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, + unsigned interleave) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + bool useNanQualifier = false; + if (auto kind = matchReduxKind(op, computeCapability, useNanQualifier)) { + // Based on benchmarking on A100 redux op gives a speed up only when doing + // a single reduction (not partitioned) and when the mask is static. + // Therefore we currently only enable it to reduce across all the lanes. + if (numLaneToReduce == 32) { + assert(acc.size() == 1); + Value mask = b.i32_val(0xFFFFFFFF); + // Even though we currently don't use redux for partitioned reduction + // the code below supports it in case we want to tweak the heuristic. + if (numLaneToReduce < 32) { + // For partitioned reduction we need to calculate the mask so that + // each group of numLaneToReduce threads has the correct mask. + unsigned bitmask = (1 << numLaneToReduce) - 1; + Value laneId = getLaneId(rewriter, loc); + mask = b.shl(b.i32_val(bitmask), + b.and_(laneId, b.i32_val(~(numLaneToReduce - 1)))); + } + for (unsigned i = 0; i < acc.size(); ++i) { + unsigned bitwidth = acc[i].getType().getIntOrFloatBitWidth(); + if (acc[i].getType().isInteger()) { + if (bitwidth < 32) { + if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX) + acc[i] = b.sext(i32_ty, acc[i]); + else + acc[i] = b.zext(i32_ty, acc[i]); + } + } + acc[i] = NVVM::ReduxOp::create(rewriter, loc, acc[i].getType(), acc[0], + *kind, mask, /*abs=*/false, + /*nan=*/useNanQualifier); + if (acc[i].getType().isInteger()) { + if (bitwidth < 32) + acc[i] = b.trunc(int_ty(bitwidth), acc[i]); + } + } + return true; + } + } + return false; +} + +std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { + std::string funcName = + resultElementTy.isInteger(32) ? "__nv_umulhi" : "__nv_umul64hi"; + return funcName; +} + +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args, + ArrayRef isSigned) const { + auto *ctx = rewriter.getContext(); + Type ptr = ptr_ty(ctx); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto funcOp = getVprintfDeclaration(rewriter); + auto loc = UnknownLoc::get(ctx); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Value one = b.i32_val(1); + Value zero = b.i32_val(0); + + Value bufferPtr = b.null(ptr); + + SmallVector newArgs; + if (args.size() >= 1) { + SmallVector argTypes; + for (auto [i, arg] : llvm::enumerate(args)) { + Type newType; + Value newArg; + std::tie(newType, newArg) = printfPromoteValue( + rewriter, arg, isSigned.empty() ? true : isSigned[i]); + argTypes.push_back(newType); + newArgs.push_back(newArg); + } + + Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes); + auto allocated = + LLVM::AllocaOp::create(rewriter, loc, ptr_ty(ctx), structTy, one, + /*alignment=*/0); + + for (const auto &entry : llvm::enumerate(newArgs)) { + auto index = b.i32_val(entry.index()); + auto fieldPtr = + b.gep(ptr_ty(ctx), structTy, allocated, ArrayRef{zero, index}); + b.store(entry.value(), fieldPtr); + } + bufferPtr = b.bitcast(allocated, ptr); + } + + SmallVector operands{formatStrStart, bufferPtr}; + b.call(funcOp, operands); +} + +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, ValueRange args, + ArrayRef isSigned) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, isSigned); +} + +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto funcOp = getAssertfailDeclaration(rewriter); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + llvm::SmallString<64> messageString(message), fileString(file), + funcString(func); + messageString.push_back('\0'); + fileString.push_back('\0'); + funcString.push_back('\0'); + Value messageStringVal = + LLVM::addStringToModule(loc, rewriter, "assertMessage_", messageString); + Value fileStringVal = + LLVM::addStringToModule(loc, rewriter, "assertFile_", fileString); + Value funcStringVal = + LLVM::addStringToModule(loc, rewriter, "assertFunc_", funcString); + Value lineNumber = b.i32_val(line); + Value charSize = b.int_val(sizeof(size_t) * 8, sizeof(char)); + SmallVector operands = {messageStringVal, fileStringVal, lineNumber, + funcStringVal, charSize}; + b.call(funcOp, operands); +} + +int TargetInfo::getSharedAddressSpace() const { return 3; } + +int TargetInfo::getAddressSpace(Attribute addressSpace) const { + int spaceId = 0; + if (isa(addressSpace)) { + spaceId = 3; + } else { + llvm::report_fatal_error( + "Only support SharedMemorySpace, TensorMemorySpace for now"); + } + return spaceId; +} + +bool TargetInfo::supportVectorizedAtomics() const { + return computeCapability >= 90 && ptxVersion >= 81; +} + +} // namespace mlir::triton::NVIDIA diff --git a/third_party/mthreads/proton/Dialect/lib/compat/Utility.cpp b/third_party/mthreads/proton/Dialect/lib/compat/Utility.cpp new file mode 100644 index 0000000000..95c2e179ee --- /dev/null +++ b/third_party/mthreads/proton/Dialect/lib/compat/Utility.cpp @@ -0,0 +1,439 @@ +#include "compat/Utility.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "third_party/mthreads/include/triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace LLVM { +namespace NVIDIA { +using namespace mlir::triton; + +static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned bits = val.getType().getIntOrFloatBitWidth(); + + if (bits == 64) { + Type vecTy = vec_ty(f32_ty, 2); + Value vec = b.bitcast(val, vecTy); + Value val0 = b.extract_element(f32_ty, vec, b.i32_val(0)); + Value val1 = b.extract_element(f32_ty, vec, b.i32_val(1)); + val0 = shuffleCommonImpl(loc, rewriter, val0, i, mode, clamp); + val1 = shuffleCommonImpl(loc, rewriter, val1, i, mode, clamp); + vec = b.undef(vecTy); + vec = b.insert_element(vecTy, vec, val0, b.i32_val(0)); + vec = b.insert_element(vecTy, vec, val1, b.i32_val(1)); + return b.bitcast(vec, val.getType()); + } + Type type = val.getType(); + if (type != i32_ty) { + val = b.bitcast(val, int_ty(bits)); + if (bits < 32) + val = b.zext(i32_ty, val); + } + Value mask = b.i32_val(0xFFFFFFFF); + Value result = NVVM::ShflOp::create(rewriter, loc, i32_ty, mask, val, i, + clamp, mode, UnitAttr()); + if (type != i32_ty) { + if (bits < 32) + result = b.trunc(int_ty(bits), result); + result = b.bitcast(result, type); + } + return result; +} + +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // To shuffle pointers, convert them to i64. + Type valTy = val.getType(); + if (isa(valTy)) + val = b.ptrtoint(i64_ty, val); + Value result = shuffleCommonImpl(loc, rewriter, val, i, mode, clamp); + if (isa(valTy)) + result = b.inttoptr(valTy, result); + return result; +} + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), NVVM::ShflKind::bfly, + b.i32_val(0x1f)); +} + +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), NVVM::ShflKind::up, + b.i32_val(0x0)); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleIdx(loc, rewriter, val, b.i32_val(i)); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, i, NVVM::ShflKind::idx, + b.i32_val(0x1f)); +} + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + ProgramIDDim axis) { + assert(moduleOp); + + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetProgramIdOp. If numCTAs = 1, then + // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to + // "%clusterid". + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + if (numCTAs == 1) { + switch (axis) { + case ProgramIDDim::X: + return NVVM::BlockIdXOp::create(rewriter, loc, i32_ty); + case ProgramIDDim::Y: + return NVVM::BlockIdYOp::create(rewriter, loc, i32_ty); + case ProgramIDDim::Z: + return NVVM::BlockIdZOp::create(rewriter, loc, i32_ty); + } + } else { + switch (axis) { + case ProgramIDDim::X: + return NVVM::ClusterIdXOp::create(rewriter, loc, i32_ty); + case ProgramIDDim::Y: + return NVVM::ClusterIdYOp::create(rewriter, loc, i32_ty); + case ProgramIDDim::Z: + return NVVM::ClusterIdZOp::create(rewriter, loc, i32_ty); + } + } + llvm_unreachable("invalid axis"); +} + +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value selector) { + Value args[] = {a, b, selector}; + auto op = + createLLVMIntrinsicCallOp(rewriter, loc, "llvm.nvvm.prmt", i32_ty, args); + return op.getResult(0); +} + +/// Create a predicate with just single active thread. +Value createElectPredicate(Location loc, OpBuilder &rewriter) { + return NVVM::ElectSyncOp::create(rewriter, loc, i1_ty, + /*membermask=*/Value()); +} + +void createSyncWarp(Location loc, OpBuilder &rewriter) { + TritonLLVMOpBuilder b(loc, rewriter); + NVVM::SyncWarpOp::create(rewriter, loc, b.i32_val(0xffffffff)); +} + +Value createElectPredicateWarp0(Location loc, OpBuilder &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value warpId = getLaneAndWarpId(rewriter, loc).second; + Value warp0 = b.icmp_eq(warpId, b.i32_val(0)); + return b.and_(warp0, createElectPredicate(loc, rewriter)); +} + +Value createTMAMulticastMask(Location loc, ConversionPatternRewriter &rewriter, + uint16_t broadcastBits) { + int numCTAs = triton::gpu::lookupNumCTAs(rewriter); + int blockBits = llvm::Log2_32(numCTAs); + uint32_t fixedBits = (~broadcastBits) & (numCTAs - 1); + uint32_t pattern = 1; + for (int i = 0; i < blockBits; ++i) { + if ((fixedBits & (1u << i)) == 0) + pattern |= (pattern << (1u << i)); + } + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc); + Value base = b.and_(ctaId, b.i32_val(fixedBits)); + return b.shl(b.i32_val(pattern), base); +} + +LogicalResult lowerLdStMatrix( + Location loc, LinearLayout cvt, bool transpose, + SmallVector &vals, // Input for stmatrix, output for ldmatrix + Value smemBase, Value affineOffset, uint64_t maskSpanAffineOffset, + Type llvmElemTy, ConversionPatternRewriter &rewriter, + const ::triton::NVIDIA::TargetInfo &targetInfo) { + // Lower load via ldmatrix, store via stmatrix + + bool isStore = !vals.empty(); + if (isStore && !targetInfo.supportStMatrix()) + return failure(); + if (!isStore && !targetInfo.supportLdMatrix()) + return failure(); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = rewriter.getContext(); + + auto S = [ctx](StringRef v) { return StringAttr::get(ctx, v); }; + auto kReg = S("register"); + auto kLane = S("lane"); + auto kWarp = S("warp"); + auto kBlock = S("block"); + auto kOffset = S("offset"); + auto kAddr = S("addr"); + auto smemPtrTy = ptr_ty(ctx, 3); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + // In the contiguous case we can pack elements <= 32 bits + // In the transpose case we just have the b8 and b16 cases + if ((!transpose && bitwidth > 32) || + (transpose && !(bitwidth == 16 || + (bitwidth == 8 && targetInfo.supportLdStMatrixB8())))) + return failure(); + // Inter block stmatrix is not supported + if (cvt.hasInDim(kBlock)) + return failure(); + + // Map onto offsets (contiguous part) and addr (non-contiguous part) + LinearLayout fullTile; + // Contiguous tile + LinearLayout tile; + // Just used in the transpose case + ColumnAction permLanes, permReg; + // Accumulate the permutations to apply the inverse for loads + ColumnAction accPermReg = + ColumnAction::identity(kReg, cvt.getInDimSizeLog2(kReg)); + if (!transpose) { + tile = LinearLayout::identity1D(32 / bitwidth, kReg, kOffset) * + LinearLayout::identity1D(4, kLane, kOffset); + fullTile = tile * LinearLayout::identity1D(8, kLane, kAddr); + } else { + // We permute the lanes and registers of the layout to the front as to be + // able to divideLeft by the relevant tile + + // Thank you PTX + auto contigRegs = (isStore && bitwidth == 8 ? 16 : 32) / bitwidth; + fullTile = LinearLayout::identity1D(contigRegs, kReg, kAddr) * + LinearLayout::identity1D(4, kLane, kAddr) * + LinearLayout::identity1D(8, kLane, kOffset) * + LinearLayout::identity1D(16 / bitwidth, kReg, kOffset); + // Not enough registers to cover the full tile + if (cvt.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) { + return failure(); + } + // Move offset to the front + std::vector regBases, laneBases; + auto bases = fullTile.invert().getBases().lookup(kOffset); + for (const auto &basis : bases) { + assert(basis.size() == 2); + if (basis[0] != 0) { + regBases.push_back(llvm::Log2_32(basis[0])); + } else { + laneBases.push_back(llvm::Log2_32(basis[1])); + } + } + // quadratic but who cares + for (int i = 0; i < cvt.getInDimSizeLog2(kReg); i++) { + if (!llvm::is_contained(regBases, i)) { + regBases.push_back(i); + } + } + for (int i = 0; i < cvt.getInDimSizeLog2(kLane); i++) { + if (!llvm::is_contained(laneBases, i)) { + laneBases.push_back(i); + } + } + assert(laneBases == std::vector({2, 3, 4, 0, 1})); + // Register depends on our beloved contigRegs + permReg = ColumnAction(regBases, kReg, cvt.getInDimSizeLog2(kReg)); + permLanes = ColumnAction(laneBases, kLane, cvt.getInDimSizeLog2(kLane)); + cvt = permReg.apply(cvt); + cvt = permLanes.apply(cvt); + if (isStore) { + vals = permReg.apply(vals); + } else { + accPermReg = accPermReg.leftCompose(permReg); + } + + // This is the same as permuting the lanes and registers to the front in + // fullTile and taking the kOffset sublayout. + tile = (LinearLayout::identity1D(8, kLane, kOffset) * + LinearLayout::identity1D(16 / bitwidth, kReg, kOffset)) + .transposeIns({kReg, kLane}); + } + + // Find if there is a register permutation that allows us to divideLeft + ColumnAction permDivide; + if (auto maybePermutation = regPermForDivide(cvt, tile, /*left=*/true)) { + permDivide = maybePermutation.value(); + } else { + return failure(); + } + + cvt = permDivide.apply(cvt); + if (isStore) { + vals = permDivide.apply(vals); + } else { + accPermReg = accPermReg.leftCompose(permDivide); + } + auto maybeQuot = divideLeft(cvt, tile); + if (!maybeQuot.has_value()) { + return failure(); + } + + // From here on we perform the lowering + auto reps = zerosLike(tile) * maybeQuot.value(); + + // We revert all the permutations that we performed to be able to divideLeft + if (transpose) { + reps = permLanes.inverse().apply(reps); + reps = permReg.inverse().apply(reps); + if (isStore) { + vals = permReg.inverse().apply(vals); + } else { + accPermReg = accPermReg.leftCompose(permReg.inverse()); + } + } + // Sanity check (of the asymmetry between ldmatrix.b8 and stmatrix.b8): + // All the instructions move 32 bytes of data on .x1 but ldmatrix.b8 which + // moves 64 bytes... + auto regsPerCoreTile = fullTile.getInDimSize(kReg); + assert(regsPerCoreTile * bitwidth == + ((!isStore && bitwidth == 8 && transpose) ? 64 : 32)); + + // If we are lowering a subslice, the subslice offsets shall not touch the + // contiguous part of the tile + if (maskSpanAffineOffset & (tile.getOutDimSizeLog2(kOffset) - 1)) { + return failure(); + } + + // Choose the vectorisation factor + // We want to send at most 128 bits of data per thread as that's the maximum + // vectorisation for all the instructions (even the weird ldmatrix.b8) + auto vec = std::min(128 / bitwidth, reps.getInDimSize(kReg)) / + regsPerCoreTile; + assert(vec == 1 || vec == 2 || vec == 4); + auto fullTileVec = fullTile * LinearLayout::identity1D(vec, kReg, kAddr); + // just add warps as compose belowe requires the dimensions of both layouts to + // agree + fullTileVec *= LinearLayout::identity1D(1, kWarp, kAddr); + // fullTile.invert() is a map from kOffset, kAddr into kReg, kLane, kWarp + // addrToOffset gives us a map from kAddr into kOffset, which is the map of + // the addresses each lane should hold + auto addrToOffset = fullTileVec.invert().compose(reps); + // sanity check + assert(addrToOffset.getInDimSizeLog2(kAddr) >= 3 && + addrToOffset.getInDimSizeLog2(kAddr) <= 5); + + LinearLayout addrLayout = + LinearLayout({{kLane, addrToOffset.getBases().lookup(kAddr)}, + {kWarp, reps.getBases().lookup(kWarp)}}, + {{kOffset, reps.getOutDimSize(kOffset)}}, false); + // Compute the bits that are moved by one instruction + // Compute elements for which we can swap the xor by an add + auto [nAdditive, permStrides] = + actionAdditiveStrides(reps, addrLayout, maskSpanAffineOffset); + reps = permStrides.apply(reps); + if (isStore) { + vals = permStrides.apply(vals); + } else { + accPermReg = accPermReg.leftCompose(permStrides); + } + + // PTX expects the address increments to be done in bytes + // If we don't perform the computations in i8, the compiler would + // have to divide the computation by bitwdith / 8 and then lift this + // shl, which often it's not able to do. + // Adding a kReg dimension is a convenient hack. + // We should just multiply all the bases by bitwidth / 8 + // and then remove the kReg dimension. + assert(bitwidth >= 8); + auto i8Tile = + zerosLike(LinearLayout::identity1D(bitwidth / 8, kReg, kOffset)); + auto i8AddrLayout = i8Tile * addrLayout; + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + auto regBase = + applyLinearLayout( + loc, rewriter, i8AddrLayout, + {{kReg, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}})[0] + .second; + + // It's fine that we don't compute the offset in bytes as affineOffset + // will be folded into a constant + auto affineOffsetI8 = b.mul(affineOffset, b.i32_val(bitwidth / 8)); + regBase = b.xor_(regBase, affineOffsetI8); + + // Instruction params + auto layout = transpose ? NVVM::MMALayout::col : NVVM::MMALayout::row; + auto eltType = transpose && bitwidth == 8 ? NVVM::LdStMatrixEltType::B8 + : NVVM::LdStMatrixEltType::B16; + int m = fullTile.getOutDimSize(kAddr); + int n = fullTile.getOutDimSize(kOffset) * bitwidth / + (eltType == NVVM::LdStMatrixEltType::B8 ? 8 : 16); + if (transpose) { + std::swap(m, n); + } + auto shape = NVVM::LdStMatrixShapeAttr::get(ctx, m, n); + + // Elements per op + auto elemsPerInstr = fullTileVec.getInDimSize(kReg); + auto elemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(llvmElemTy, elemsPerVec); + for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) { + auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second; + auto regIdxI8 = regIdx * (bitwidth / 8); + Value offset = b.xor_(regBase, b.i32_val(regIdxI8)); + for (int i2 = 0; i2 < nAdditive; i2 += elemsPerInstr) { + // all these constants will go as immediate values to LDSM/STSM + auto regIdxAdd = + reps.apply({{kReg, i2}, {kLane, 0}, {kWarp, 0}})[0].second; + auto regIdxAddI8 = regIdxAdd * (bitwidth / 8); + Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8)); + auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset, + LLVM::GEPNoWrapFlags::inbounds); + if (isStore) { + // Pack into vector of i32 + SmallVector inputs; + for (int j = 0; j < elemsPerInstr; j += elemsPerVec) { + Value input = b.undef(vecTy); + for (int k = 0; k < elemsPerVec; k++) { + input = b.insert_element(vecTy, input, vals[i + i2 + j + k], + b.i32_val(k)); + } + inputs.push_back(b.bitcast(input, i32_ty)); + } + NVVM::StMatrixOp::create(rewriter, loc, vecAddr, inputs, layout, shape, + eltType); + } else { + unsigned numLdMatrix = elemsPerInstr / elemsPerVec; + assert(numLdMatrix > 0 && + "ldmatrix must load at least one 8x8 tile per instruction"); + Type ldResultTy = + elemsPerInstr == elemsPerVec + ? i32_ty + : static_cast(LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(numLdMatrix, i32_ty))); + auto res = NVVM::LdMatrixOp::create(rewriter, loc, ldResultTy, vecAddr, + vec, layout, shape, eltType) + .getResult(); + // Extract result into srcVals + for (int j = 0; j < elemsPerInstr / elemsPerVec; j++) { + Value output = elemsPerInstr == elemsPerVec + ? res + : b.extract_val(i32_ty, res, j); + output = b.bitcast(output, vecTy); + for (int k = 0; k < elemsPerVec; k++) { + vals.push_back(b.extract_element(llvmElemTy, output, b.i32_val(k))); + } + } + } + } + } + if (!isStore) { + // apply all the inverse permutations in the reverse order + assert(vals.size() == cvt.getInDimSize(kReg)); + vals = accPermReg.inverse().apply(vals); + } + return success(); +} +} // namespace NVIDIA +} // namespace LLVM +} // namespace mlir diff --git a/third_party/mthreads/proton/Dialect/triton_proton.cc b/third_party/mthreads/proton/Dialect/triton_proton.cc new file mode 100644 index 0000000000..580be2dfb8 --- /dev/null +++ b/third_party/mthreads/proton/Dialect/triton_proton.cc @@ -0,0 +1,118 @@ +#include "Analysis/ScopeIdAllocation.h" +#include "Conversion/ProtonGPUToLLVM/Passes.h" +#if TRITON_ENABLE_AMD +#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h" +#endif +#include "Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h" +#include "Conversion/ProtonToProtonGPU/Passes.h" +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/ProtonGPU/IR/Dialect.h" +#include "Dialect/ProtonGPU/Transforms/Passes.h" +#include "ir.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include +#include +#include + +namespace py = pybind11; +using namespace mlir::triton; + +void init_triton_proton(py::module &&m) { + m.doc() = "Python bindings to the Proton backend"; + + // Proton enums + py::enum_(m, "METRIC_TYPE", py::module_local()) + .value("CYCLE", proton::MetricType::CYCLE) + .export_values(); + + py::enum_(m, "SAMPLING_STRATEGY", + py::module_local()) + .value("NONE", proton::SamplingStrategy::NONE) + .value("SELECTIVE", proton::SamplingStrategy::SELECTIVE) + .export_values(); + + // ProtonGPU enums + py::enum_(m, "GRANULARITY", py::module_local()) + .value("CTA", proton::gpu::Granularity::CTA) + .value("WARP", proton::gpu::Granularity::WARP) + .value("WARP_2", proton::gpu::Granularity::WARP_2) + .value("WARP_4", proton::gpu::Granularity::WARP_4) + .value("WARP_8", proton::gpu::Granularity::WARP_8) + .value("WARP_GROUP", proton::gpu::Granularity::WARP_GROUP) + .value("WARP_GROUP_2", proton::gpu::Granularity::WARP_GROUP_2) + .value("WARP_GROUP_4", proton::gpu::Granularity::WARP_GROUP_4) + .value("WARP_GROUP_8", proton::gpu::Granularity::WARP_GROUP_8) + .export_values(); + + py::enum_(m, "BUFFER_STRATEGY", + py::module_local()) + .value("CIRCULAR", proton::gpu::BufferStrategy::CIRCULAR) + .value("FLUSH", proton::gpu::BufferStrategy::FLUSH) + .export_values(); + + py::enum_(m, "BUFFER_TYPE", py::module_local()) + .value("SHARED", proton::gpu::BufferType::SHARED) + .value("GLOBAL", proton::gpu::BufferType::GLOBAL) + .export_values(); + + // Load proton dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + m.def("get_scope_id_names", [](mlir::ModuleOp &module) { + return proton::ModuleScopeIdAllocation(module).getScopeIdNames(); + }); + + m.def("get_scope_id_parents", [](mlir::ModuleOp &module) { + return proton::ModuleScopeIdAllocation(module).getScopeIdParents(); + }); + + // Proton operations + m.def("create_proton_record", + [](TritonOpBuilder &opBuilder, bool isStart, + const std::string &name) -> void { + auto nameAttr = mlir::StringAttr::get(opBuilder.getContext(), + llvm::StringRef(name)); + opBuilder.create(isStart, nameAttr); + }); + + m.def("add_convert_proton_to_protongpu", + [](mlir::PassManager &pm, proton::MetricType &metricType, + proton::SamplingStrategy samplingStrategy, + const std::string &samplingOptions, + proton::gpu::Granularity granularity, + proton::gpu::BufferStrategy bufferStrategy, + proton::gpu::BufferType bufferType, int32_t bufferSize, + int32_t maxSharedMemSize, int64_t profileScratchSize, + int32_t profileScratchAlignment, bool clkExt) { + pm.addPass(proton::createConvertProtonToProtonGPUPass( + metricType, samplingStrategy, samplingOptions, granularity, + bufferStrategy, bufferType, bufferSize, maxSharedMemSize, + profileScratchSize, profileScratchAlignment, clkExt)); + }); + + ADD_PASS_WRAPPER_0("add_convert_proton_nvidia_gpu_to_llvm", + proton::gpu::createConvertProtonNvidiaGPUToLLVMPass); +#if TRITON_ENABLE_AMD + ADD_PASS_WRAPPER_1("add_convert_proton_amd_gpu_to_llvm", + proton::gpu::createConvertProtonAMDGPUToLLVMPass, + const std::string &); +#endif + ADD_PASS_WRAPPER_0("add_allocate_proton_shared_memory", + proton::gpu::createAllocateProtonSharedMemoryPass); + ADD_PASS_WRAPPER_0("add_allocate_proton_global_scratch_buffer", + proton::gpu::createAllocateProtonGlobalScratchBufferPass); + ADD_PASS_WRAPPER_0("add_schedule_buffer_store", + proton::gpu::createScheduleBufferStorePass); +#if TRITON_ENABLE_AMD + ADD_PASS_WRAPPER_0("add_sched_barriers", + proton::gpu::createAddSchedBarriersPass); +#endif +} diff --git a/third_party/mthreads/proton/README.md b/third_party/mthreads/proton/README.md new file mode 100644 index 0000000000..047f37dbe6 --- /dev/null +++ b/third_party/mthreads/proton/README.md @@ -0,0 +1,424 @@ +# Proton - A Profiler for Triton + +## Introduction + +Proton is a lightweight profiler for Triton that captures rich information about program context, metadata, and GPU kernel performance metrics, while keeping both runtime overhead and profile size minimal. + +## Installation + +The following command installs the latest version of Proton. + +```bash +git clone https://github.com/triton-lang/triton +cd triton/python +pip install . +``` + +To **not build** Proton, you can set the `TRITON_BUILD_PROTON` environment variable to `OFF`: + +```bash +TRITON_BUILD_PROTON=OFF pip install . +``` + +## Usage + +### Basic usage + +More examples can be found in the [tutorials](tutorials) directory. + +Proton can be used to profile *functions* and *regions* in Python code. + +- The following examples demonstrate how to use Proton to profile a simple Python function. + +```python +import triton.profiler as proton + +# name: The path to the profile data +# context: The method used to annotate the context of each GPU kernel. Currently, "shadow" and "python" are supported. +session_id = proton.profile(func, name="profile_name", context="python")(args) +``` + +- The following examples demonstrate how to use Proton to profile a region in Python code. + +```python +session_id = proton.start(name="profile_name", context="python") +... +# Skip a region +proton.deactivate(session_id) +... +# Restart profiling +proton.activate(session_id) +... +# Write out the profile data and finalize the profiler +proton.finalize() +``` + +### Scope + +Unlike the *python* context that provide users with files, functions, and lines where the GPU kernels are invoked, the *shadow* context provides users with the annotated regions in the code. The following example demonstrates how to use the *shadow* context. + +```python +import triton.profiler as proton + + +session_id = proton.start(name="profile_name", context="shadow") + +with proton.scope("test0"): + with proton.scope("test1"): + foo[1,](x, y) +with proton.scope("test2"): + foo[1,](x, y) + +... +proton.finalize() +``` + +The *scope* utility also accepts flexible metrics, provided with a dictionary that maps from a string (metric name) to a value (int, float, or a scalar (0-d) tensor). +Proton will aggregate the metrics for each scope and write them to the profile data. +It is useful for users to understand the performance of the model at a high level. + +```python +with proton.scope("test0", {"bytes": 1000}): + with proton.scope("test1", {"bytes": 2000}): + foo[1,](x, y) +with proton.scope("test2", {"bytes": 3000}): + foo[1,](x, y) +``` + +#### NVTX compatibility + +Proton scopes coexist with NVTX ranges. +NVTX pushes and pops (for example, `torch.cuda.nvtx.range_push`) appear as nested scopes in the Proton profile, letting you correlate custom NVTX annotations with Proton's aggregated metrics. + +### Backend and mode + +Proton supports three profiling backends: `cupti`, `roctracer`, and `instrumentation`. + +- **`cupti`**: Used for NVIDIA GPUs. It supports both the default profiling mode and `pcsampling` (instruction sampling). +- **`roctracer`**: Used for AMD GPUs. It supports only the default profiling mode. +- **`instrumentation`**: Available on both NVIDIA and AMD GPUs, this backend enables collection of custom metrics and advanced instrumentation. + +By default, Proton automatically selects either `cupti` or `roctracer` as the backend based on your GPU driver. The `instrumentation` backend offers a wide range of mode options for fine-grained profiling, as detailed in the `mode.py` file. + +#### Instruction sampling + +Proton supports instruction sampling on NVIDIA GPUs. +You may experience ~20x end-to-end overhead when using instruction sampling, although the overhead for each individual GPU kernel is negligible. +The overhead is mostly caused by data transfer and processing on the CPU. +Additionally, the proton-viewer options `-i -d -t ` can be helpful for filtering out GPU kernels that are not of interest. +The following example demonstrates how to use instruction sampling: + +```python +import triton.profiler as proton + +proton.start(name="profile_name", context="shadow", backend="cupti", mode="pcsampling") +``` + +#### Instrumentation + +The instrumentation backend allows for detailed, fine-grained profiling of intra-kernel behavior, generating trace or tree views similar to those produced by coarse-grained profiling. +By default, if no `mode` is specified, Proton profiles kernel cycles, which may require shared memory or global memory (depends on `buffer-type`). If there is insufficient profiling memory capacity, profiling will abort and a warning will be displayed. Future releases will introduce additional instrumentation modes. See the [tutorial](tutorials/intra_kernel) for more detailed information and examples. + +**Host-side usage:** + +```python +import triton.profiler as proton + +proton.start( + name="profile_name", + backend="instrumentation", + mode="=:=:..." +) + +# or + +import triton.profiler.mode as pmode + +proton.start( + name="profile_name", + backend="instrumentation", + mode=pmode.Default() # collect metrics from every warp +) +``` + +**Kernel-side usage:** + +**Caution**: For DSL level instrumentation, **only Gluon** semantic is enabled by default. +Instrumenting kernels written in Triton DSL is disable because Triton's higher-level IR undergoes +aggressive compiler rewrites (loop pipelining, instruction re-ordering, IR duplication, etc.). +These transformations can invalidate naïve instrumentation and lead to misleading results. +To enable instrumentation for Triton DSL, call `pl.enable_semantic("triton")` before `proton.start`. + +```python +from triton.experimental import gluon +from triton.experimental.gluon import language as gl + +import triton.profiler.language as pl + +@gluon.jit +def kernel(...): + pl.enter_scope("scope0") + for i in range(iters): + gl.load(...) + pl.exit_scope("scope0") + with pl.scope("scope1"): + for i in range(iters): + gl.load(...) +``` + +Advanced users can instrument either the `ttir` or `ttgir` intermediate representations for even finer-grained measurement. The relevant IR instructions are `proton.record start` and `proton.record end`. This can be combined with the environment variable `TRITON_KERNEL_OVERRIDE=1` for custom kernel overrides. For detailed steps, refer to the Triton [documentation](https://github.com/triton-lang/triton?tab=readme-ov-file#tips-for-hacking) under the **Kernel Override Steps** section. We have also assembled a [tutorial](tutorials/intra_kernel) that demonstrates how to use the IR-based instrumentation approach and the proton DSL approach. + +### Hook + +```python +import triton.profiler as proton +from typing import NamedTuple + +# hook: When hook="triton", it enables proton to invoke launch_metadata function before launching the GPU kernel +proton.start("profile_name", hook="triton") + +def metadata_fn( + grid: tuple, + metadata: NamedTuple, + args: dict +): + return {"name": "", "flops8": 1.0} + +@triton.jit(launch_metadata=metadata_fn) +def foo(x, y): + tl.store(y, tl.load(x)) +``` + +The `metadata_fn` function is called before launching the GPU kernel to provide metadata for the GPU kernel, which returns a dictionary that maps from a string (metadata name) to a value (int or float). + +Currently, **only the launch hook is supported**. In the dictionary returned by the `metadata_fn` function, we can supply the following keys: + +```python +name: str # The name of the kernel +flops8: float # The number of 8-bit floating-point operations +flops16: float # The number of 16-bit floating-point operations +flops32: float # The number of 32-bit floating-point operations +flops64: float # The number of 64-bit floating-point operations +bytes: int # The number of bytes expected to be transferred +``` + +### CUDA graph + +Proton supports profiling graph launched kernels on NVIDIA GPUs. + +It uniquely offers two features. +First, it captures and concatenates the call path where the kernel is captured with the call path where it is launched. +Second, it supports aggregating flexible metrics the same way as individually launched kernels without requiring users to change their code. +The only requirement is to initialize profiling before capturing a CUDA graph. +Users can deactivate it after graph capturing if they want to skip some kernels. + +For example: + +```python +import triton.profiler as proton + +proton.start(name="profile_name", context="shadow") +# Capture the CUDA graph +graph = torch.cuda.CUDAGraph() +with torch.cuda.graph(graph): + with proton.scope("graph"): + ... + +proton.deactivate() + +# Launch the CUDA graph +proton.activate() +with proton.scope("graph_launch"): + graph.replay() +proton.finalize() +``` + +We will see call the call path of the kernels launched by the CUDA graph will be like `graph_launch->->graph->kernel_name`. `` is a special scope added by Proton to indicate the boundary between graph capturing and graph launching. + +### Command line + +Proton can be used as a command-line tool to profile Python scripts and Pytest tests. +The following examples demonstrate how to use Proton command-line. +Detailed options can be found by running `proton -h`. + +```bash +proton [options] script.py [script_args] [script_options] +proton [options] pytest [pytest_args] [script_options] +python -m triton.profiler.proton [options] script.py [script_args] [script_options] +proton --instrument=[instrumentation pass] script.py +``` + +When profiling in the command line mode, the `proton.start` and `proton.finalize` functions are automatically called before and after the script execution. Any `proton.start` and `proton.finalize` functions in the script are ignored. Also, in the command line mode, only a single *session* is supported. +Therefore, `proton.deactivate(session_id=1)` is invalid, while `proton.deactivate(session_id=0)` is valid. + +### Visualizing the profile data + +By default, proton profiles are in the *json* format and can be read by *Hatchet*. The following command visualizes the profile data on terminal. + +```bash +pip install llnl-hatchet +proton-viewer -m time/s +``` + +NOTE: `pip install hatchet` does not work because the API is slightly different. + +If you want to dump the entire trace but not just the aggregated data, you should set the data option to `trace` when starting the profiler. + +```python +import triton.profiler as proton + +proton.start(name="profile_name", data="trace") +``` + +The dumped trace will be in the chrome trace format and can be visualized using the `chrome://tracing` tool in Chrome or the [perfetto](https://perfetto.dev) tool. + +In addition visualizing the profile data on terminal through Hatchet. A sorted list of the kernels by the first metric can be done using the --print-sorted flag with proton-viewer + +```bash +proton-viewer -m time/ns,time/% --print-sorted +``` + +More options can be found by running the following command. + +```bash +proton-viewer -h +``` + +## Knobs + +Triton's runtime has a centralized configuration system called *knobs* that controls various features and behaviors, including the following knobs are defined for Proton: + +- `triton.knobs.proton.enable_nvtx` or `TRITON_ENABLE_NVTX` (default: `True`): Whether to enable NVTX ranges in Proton. + +- `triton.knobs.proton.cupti_lib_dir` or `TRITON_CUPTI_LIB_DIR` (default: `/backends/nvidia/lib/cupti`): The directory of the CUPTI library. + +## Advanced features and knowledge + +### Thread management + +We guarantee that any call to `libproton.so`, such as `enter_scope`, is synchronized using explicit locks. +For operations that do not trigger calls to libproton.so—including callbacks to CUDA/HIP APIs—we use separated locks to protect data structures that may be accessed concurrently by multiple threads. +For example, the `enter_op` method in `OpInterface` can be invoked by the main thread that involves triton operators, as well as by helper threads that invoke torch operators. + +### `cpu_timed_scope` + +`cpu_timed_scope` is a utility that wraps `scope` to measure the CPU time of a scope along with other metrics. +The following example demonstrates how to use `cpu_timed_scope`: + +```python +import triton.profiler as proton + +with proton.cpu_timed_scope("test"): + foo[1,](x, y) +``` + +The `cpu_timed_scope` output metric is referred to as `cpu_time`, while `time` represents accelerator (e.g., GPU) time. +The key distinction between `cpu_time` and `time` lies in their inclusivity: `cpu_time` is exclusive, whereas `time` is inclusive. +This difference arises because the time spent on individual kernels represents the smallest measurable time granularity, and each kernel is mutually exclusive. +This exclusivity allows time to be accurately accumulated across parent scopes for `time`. +In contrast, `cpu_time` measures the time within a specific scope. +Since a parent scope encompasses the time spent in its child scopes, summing `cpu_time` from child scope into parent scope would result in double counting. +To visualize both the CPU and GPU time, we can use the following command: + +```bash +proton-viewer -m time/ns,cpu_time/ns +``` + +### Metrics naming + +Custom metrics should follow this format: `metric_name (unit) (type)`. +We prefer no space within the metric name. +`unit` and `type` are optional fields. + +There are three types of metrics in proton: inclusive, exclusive, and property metrics. +By default, a metric is inclusive. +The metric types are distinguished by the suffix of their names. +The following table shows the suffix for each type and its meaning: + +| Suffix | Name | Meaning | +| --- | --- | --- | +| (inc) or "" | Inclusive metric | The metric is accumulated at a scope and can be propagated to the parent scope. | +| (exc) | Exclusive metric | The metric is accumulated at a scope and cannot be propagated to the parent scope. | +| (pty) | Property metric | The metric is a property of the scope and cannot be accumulated or propagated. | + +### State annotation + +In addition to `proton.scope`, we can also customize the call path of each GPU operation using `proton.state`. + +`state` is different from `scope` in several ways: + +1. State is not recursive; each operation can have only a single state. Inner most state will overwrite the outer most state. +2. A states is a suffix, meaning that the original call path will append a state above the name of each kernel. +3. State is compatible with both Python and shadow contexts. + +The following example demonstrates a basic use of state: + +```python +with proton.scope("test"): + with proton.state("state0"): + with proton.scope("test0"): + foo0[1,](x, y) + with proton.scope("test1"): + foo1[1,](x, y) +``` + +The call path of `foo1` will be `test->test1->state0`. + +## Proton *vs* Nsight tools + +| Aspect | Proton | Nsight Systems | Nsight Compute | +| --- | --- | --- | --- | +| Runtime overhead | Lower overhead | Higher overhead | Higher overhead | +| Profile size | Compact profiles and traces | Large traces | Large traces | +| Portability | Multi vendor | Nvidia only | Nvidia only | +| Triton insights | Metadata hooks | No hooks | No hooks | +| Metric depth | Lightweight metrics | Timeline metrics | Detailed metrics | + +**Runtime overhead.** Proton typically keeps slowdown below roughly 1.5×, even for workloads with many short-lived kernels, because it collects fewer metrics and registers fewer callbacks. Nsight Systems and Nsight Compute both impose higher overhead, though they behave similarly to Proton on purely GPU-bound workloads. + +**Profile size.** Proton aggregates kernels that share a calling context, so profile files stay compact—sometimes thousands of times smaller than Nsight traces. Both Nsight tools record each GPU kernel individually, which grows traces quickly during long runs. + +**Portability.** Proton already runs on AMD and NVIDIA GPUs and has a roadmap to extend instruction sampling to AMD hardware. Nsight Systems and Nsight Compute target NVIDIA GPUs exclusively. + +**Triton insights.** Proton can register Triton-specific hooks that surface kernel metadata for richer analysis, at the cost of a small extra overhead. Neither Nsight tool offers comparable Triton integration. + +**Metric depth.** Proton emphasizes lightweight metrics and instruction sampling for portability and fast iteration. Nsight Systems focuses on timeline-oriented metrics for NVIDIA GPUs, while Nsight Compute dives deeper into instruction-level details such as memory transactions and access patterns. + +## Known issues + +- Instruction sampling + +If you encounter permission related problems when using instruction sampling, you can lookup this [page](https://developer.nvidia.com/nvidia-development-tools-solutions-err_nvgpuctrperm-permission-issue-performance-counters) for help. + +The overhead of instruction sampling on NVIDIA GPUs is about 20x using Proton because we haven't enabled continuous sampling yet. +Continuous sampling can allow for more runtime optimizations, but it makes it more challenging to attribute performance data back to the GPU kernels because: (1) it enables profiling of concurrent kernels, (2) it doesn't allow profiling of time and instruction samples simultaneously, and (3) it works best if we have a separate thread dedicated to attributing instruction samples to the GPU kernels + +- Visible devices on AMD GPUs + +Environment variables such as `HIP_VISIBLE_DEVICES`, and `CUDA_VISIBLE_DEVICES` are not supported on AMD GPUs. Once it's set, we cannot find a valid mapping between the device ID returned by RocTracer and the physical device ID. Instead, `ROCR_VISIBLE_DEVICES` is recommended to be used. + +## Experimental features + +### Get profile data in memory + +Proton provides APIs to get profile data without dumping to files in the `data` module. These APIs are experimental and may change in the future. + +```python +import triton.profiler as proton + +session_id = proton.start(name="profile_name") +... + +# data.get_* APIs do not synchronize the device, so make sure all kernels are finished before calling them +# Usage 1: flush the profile data from the device eagerly and access all data +proton.deactivate(session_id, flushing=True) # with flushing=False, it's not guaranteed that all kernels are finished +# Get a json dictionary +data = proton.data.get_json(session_id) +# Get a msgpack bytes +data_msgpack = proton.data.get_msgpack(session_id) + +# Usage 2: query the phase completion status and access data in the completed phases +if proton.data.is_phase_complete(session_id, phase_id): + data_phase = proton.data.get_json(session_id, phase_id) + proton.data.clear(session_id, phase_id) +``` diff --git a/third_party/mthreads/proton/common/CMakeLists.txt b/third_party/mthreads/proton/common/CMakeLists.txt new file mode 100644 index 0000000000..3ea7a4199b --- /dev/null +++ b/third_party/mthreads/proton/common/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(lib) diff --git a/third_party/mthreads/proton/common/include/Device.h b/third_party/mthreads/proton/common/include/Device.h new file mode 100644 index 0000000000..2c02f2b9bd --- /dev/null +++ b/third_party/mthreads/proton/common/include/Device.h @@ -0,0 +1,48 @@ +#ifndef PROTON_COMMON_DEVICE_H_ +#define PROTON_COMMON_DEVICE_H_ + +#include +#include + +namespace proton { + +enum class DeviceType { HIP, CUDA, COUNT }; + +template struct DeviceTraits; + +template <> struct DeviceTraits { + constexpr static DeviceType type = DeviceType::CUDA; + constexpr static const char *name = "CUDA"; +}; + +template <> struct DeviceTraits { + constexpr static DeviceType type = DeviceType::HIP; + constexpr static const char *name = "HIP"; +}; + +struct Device { + DeviceType type; + uint64_t id; + uint64_t clockRate; // khz + uint64_t memoryClockRate; // khz + uint64_t busWidth; + uint64_t numSms; + std::string arch; + + Device() = default; + + Device(DeviceType type, uint64_t id, uint64_t clockRate, + uint64_t memoryClockRate, uint64_t busWidth, uint64_t numSms, + std::string arch) + : type(type), id(id), clockRate(clockRate), + memoryClockRate(memoryClockRate), busWidth(busWidth), numSms(numSms), + arch(arch) {} +}; + +Device getDevice(DeviceType type, uint64_t index); + +const std::string getDeviceTypeString(DeviceType type); + +}; // namespace proton + +#endif // PROTON_COMMON_DEVICE_H_ diff --git a/third_party/mthreads/proton/common/include/TraceDataIO/ByteSpan.h b/third_party/mthreads/proton/common/include/TraceDataIO/ByteSpan.h new file mode 100644 index 0000000000..220347e34f --- /dev/null +++ b/third_party/mthreads/proton/common/include/TraceDataIO/ByteSpan.h @@ -0,0 +1,53 @@ +#ifndef PROTON_COMMON_BYTE_SPAN_H_ +#define PROTON_COMMON_BYTE_SPAN_H_ + +#include +#include +#include +#include + +namespace proton { + +class BufferException : public std::runtime_error { +public: + explicit BufferException(const std::string &message); +}; + +class ByteSpan { +public: + ByteSpan(const uint8_t *data, size_t size); + + // Read methods + uint8_t readUInt8(); + int8_t readInt8(); + uint16_t readUInt16(); + int16_t readInt16(); + uint32_t readUInt32(); + int32_t readInt32(); + uint64_t readUInt64(); + int64_t readInt64(); + + // Buffer navigation + void skip(size_t count); + void seek(size_t position); + size_t position() const { return pos; } + size_t size() const { return dataSize; } + size_t remaining() const { return dataSize - pos; } + bool hasRemaining(size_t count = 0) const { return remaining() >= count; } + + // Data access + const uint8_t *data() const { return dataPtr; } + const uint8_t *currentData() const { return dataPtr + pos; } + +private: + const uint8_t *dataPtr; // Pointer to the underlying data + size_t dataSize; // Total size of the data + size_t pos; // Current read position + + // Helper method to check remaining bytes + void checkRemaining(size_t required) const; +}; + +} // namespace proton + +#endif // PROTON_COMMON_BYTE_SPAN_H_ diff --git a/third_party/mthreads/proton/common/include/TraceDataIO/CircularLayoutParser.h b/third_party/mthreads/proton/common/include/TraceDataIO/CircularLayoutParser.h new file mode 100644 index 0000000000..d15b75f79a --- /dev/null +++ b/third_party/mthreads/proton/common/include/TraceDataIO/CircularLayoutParser.h @@ -0,0 +1,98 @@ +#ifndef PROTON_COMMON_CIRCULAR_LAYOUT_PARSER_H_ +#define PROTON_COMMON_CIRCULAR_LAYOUT_PARSER_H_ + +#include "Parser.h" +#include + +namespace proton { + +constexpr uint32_t kPreamble = 0xdeadbeef; +constexpr uint32_t kHeaderSize = 16; +constexpr uint32_t kWordSize = 4; +constexpr uint32_t kWordsPerEntry = 2; + +enum class ParseState { START, END, INIT }; + +struct CircularLayoutParserConfig : public ParserConfig { + // The total number of unit (e.g., num of warps) in CTA + size_t totalUnits = 0; + // Scratch memory size in bytes per CTA (scratchMemSize = metadata_size + + // bufSize) + size_t scratchMemSize = 0; + // The number of blocks in the grid + size_t numBlocks = 0; + // A vector of trace's uids + std::vector uidVec = {}; +}; + +struct CircularLayoutParserResult { + // start cycle entry and end cycle entry + using ProfileEvent = + std::pair, std::shared_ptr>; + + struct Trace { + uint32_t uid = 0; + + // Total count of words (i32) if we don't drop events. + uint32_t count = 0; + + std::vector profileEvents; + }; + + struct BlockTrace { + uint32_t blockId = 0; + uint32_t procId = 0; + uint32_t bufSize = 0; + uint64_t initTime = 0; + uint64_t preFinalTime = 0; + uint64_t postFinalTime = 0; + std::vector traces; + }; + + std::vector blockTraces; +}; + +class CircularLayoutParser : public ParserBase { +public: + explicit CircularLayoutParser(ByteSpan &buffer, + const CircularLayoutParserConfig &config); + + void parse() final; + + const CircularLayoutParserConfig &getConfig() const override; + + std::shared_ptr getResult(); + +private: + void parseMetadata(); + void parseProfileEvents(); + void parseSegment(int byteSize, CircularLayoutParserResult::Trace &trace); + void parseBlock(); + + std::shared_ptr result = nullptr; + EntryDecoder decoder; +}; + +struct PreambleException : public ParserException { + PreambleException(const std::string &msg); +}; + +struct ScopeMisMatchException : public ParserException { + ScopeMisMatchException(const std::string &msg); +}; + +struct ClockOverflowException : public ParserException { + ClockOverflowException(const std::string &msg); +}; + +std::shared_ptr +readCircularLayoutTrace(ByteSpan &buffer, bool applyTimeShift = false); + +uint64_t getTimeShiftCost(const CircularLayoutParserConfig &config); + +void timeShift(const uint64_t cost, + std::shared_ptr result); + +} // namespace proton + +#endif // PROTON_COMMON_CIRCULAR_LAYOUT_PARSER_H_ diff --git a/third_party/mthreads/proton/common/include/TraceDataIO/EntryDecoder.h b/third_party/mthreads/proton/common/include/TraceDataIO/EntryDecoder.h new file mode 100644 index 0000000000..ae3fe5e92a --- /dev/null +++ b/third_party/mthreads/proton/common/include/TraceDataIO/EntryDecoder.h @@ -0,0 +1,77 @@ +#ifndef PROTON_COMMON_ENTRY_DECODER_H_ +#define PROTON_COMMON_ENTRY_DECODER_H_ + +#include "ByteSpan.h" +#include +#include +#include + +namespace proton { + +class EntryBase; + +template void decodeFn(ByteSpan &buffer, EntryT &entry) { + throw std::runtime_error("No decoder function is implemented"); +} + +class EntryDecoder { +private: + ByteSpan &buf; + +public: + explicit EntryDecoder(ByteSpan &buffer) : buf(buffer) {} + + template std::shared_ptr decode() { + auto entry = std::make_shared(); + decodeFn(buffer(), *entry); + return entry; + } + +protected: + // Protected accessor for the buffer + ByteSpan &buffer() { return buf; } +}; + +struct EntryBase { + virtual ~EntryBase() = default; + + virtual void print(std::ostream &os) const = 0; +}; + +std::ostream &operator<<(std::ostream &os, const EntryBase &obj); + +struct I32Entry : public EntryBase { + I32Entry() = default; + + void print(std::ostream &os) const override; + + int32_t value = 0; +}; + +template <> void decodeFn(ByteSpan &buffer, I32Entry &entry); + +struct I64Entry : public EntryBase { + I64Entry() = default; + + void print(std::ostream &os) const override; + + int64_t value = 0; +}; + +template <> void decodeFn(ByteSpan &buffer, I64Entry &entry); + +struct CycleEntry : public EntryBase { + CycleEntry() = default; + + void print(std::ostream &os) const override; + + uint64_t cycle = 0; + bool isStart = true; + int32_t scopeId = 0; +}; + +template <> void decodeFn(ByteSpan &buffer, CycleEntry &entry); + +} // namespace proton + +#endif // PROTON_COMMON_ENTRY_DECODER_H_ diff --git a/third_party/mthreads/proton/common/include/TraceDataIO/Parser.h b/third_party/mthreads/proton/common/include/TraceDataIO/Parser.h new file mode 100644 index 0000000000..c774e5f41d --- /dev/null +++ b/third_party/mthreads/proton/common/include/TraceDataIO/Parser.h @@ -0,0 +1,58 @@ +#ifndef PROTON_COMMON_PARSER_H_ +#define PROTON_COMMON_PARSER_H_ + +#include "ByteSpan.h" +#include "Device.h" +#include "EntryDecoder.h" +#include +#include + +namespace proton { + +struct ParserConfig { + enum class PrintMode { + SILENT, // Don't print anything + ALL // Print all messages + }; + + // Configure exception message visibility + PrintMode printLevel = PrintMode::SILENT; + + // Device type that generated the trace + Device device; + + virtual ~ParserConfig() = default; +}; + +// Define exception severity levels +enum class ExceptionSeverity { + WARNING, // Continue parsing + ERROR // Stop parsing +}; + +struct ParserException : public std::runtime_error { + ExceptionSeverity severity; + + ParserException(const std::string &msg, ExceptionSeverity sev); +}; + +class ParserBase { +public: + explicit ParserBase(ByteSpan &buffer, const ParserConfig &config); + + virtual ~ParserBase() = default; + + virtual void parse() = 0; + + virtual const ParserConfig &getConfig() const; + +protected: + void reportException(const ParserException &e, size_t pos); + + const ParserConfig &config; + ByteSpan &buffer; +}; + +} // namespace proton + +#endif // PROTON_COMMON_PARSER_H_ diff --git a/third_party/mthreads/proton/common/include/TraceDataIO/TraceWriter.h b/third_party/mthreads/proton/common/include/TraceDataIO/TraceWriter.h new file mode 100644 index 0000000000..9558661f8a --- /dev/null +++ b/third_party/mthreads/proton/common/include/TraceDataIO/TraceWriter.h @@ -0,0 +1,71 @@ +#ifndef PROTON_COMMON_TRACE_WRITER_H_ +#define PROTON_COMMON_TRACE_WRITER_H_ + +#include "CircularLayoutParser.h" +#include "nlohmann/json.hpp" +#include +#include +#include +#include +#include +#include + +namespace proton { + +struct KernelMetadata { + std::map scopeName; + std::string kernelName; + std::vector callStack; +}; + +using KernelTrace = std::pair, + std::shared_ptr>; + +// StreamTraceWriter handles trace dumping for a single cuda stream. +// If we have multiple stream, simply having a for loop to write to multiple +// files (one for each stream). Other types of per-stream trace writers could +// subclass the StreamTraceWriter such as StreamPerfettoTraceWriter that +// produces a protobuf format trace. +class StreamTraceWriter { +public: + explicit StreamTraceWriter(const std::vector &streamTrace, + const std::string &path); + + virtual ~StreamTraceWriter() = default; + + void dump(); + + virtual void write(std::ostream &outfile) = 0; + +protected: + const std::string path; + const std::vector &streamTrace; +}; + +class StreamChromeTraceWriter : public StreamTraceWriter { +public: + explicit StreamChromeTraceWriter(const std::vector &streamTrace, + const std::string &path); + + void write(std::ostream &outfile) override final; + +private: + void writeKernel(nlohmann::json &object, const KernelTrace &kernelTrace, + const uint64_t minInitTime); + + const std::vector kChromeColor = {"cq_build_passed", + "cq_build_failed", + "thread_state_iowait", + "thread_state_running", + "thread_state_runnable", + "thread_state_unknown", + "rail_response", + "rail_idle", + "rail_load", + "cq_build_attempt_passed", + "cq_build_attempt_failed"}; +}; + +} // namespace proton + +#endif // PROTON_COMMON_TRACE_WRITER_H_ diff --git a/third_party/mthreads/proton/common/lib/CMakeLists.txt b/third_party/mthreads/proton/common/lib/CMakeLists.txt new file mode 100644 index 0000000000..5646a3832a --- /dev/null +++ b/third_party/mthreads/proton/common/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TraceDataIO) diff --git a/third_party/mthreads/proton/common/lib/TraceDataIO/ByteSpan.cpp b/third_party/mthreads/proton/common/lib/TraceDataIO/ByteSpan.cpp new file mode 100644 index 0000000000..f218c8ff7f --- /dev/null +++ b/third_party/mthreads/proton/common/lib/TraceDataIO/ByteSpan.cpp @@ -0,0 +1,77 @@ +#include "TraceDataIO/ByteSpan.h" + +using namespace proton; + +ByteSpan::ByteSpan(const uint8_t *data, size_t size) + : dataPtr(data), dataSize(size), pos(0) { + if (data == nullptr && size > 0) { + throw std::invalid_argument( + "Data pointer cannot be null for non-zero size"); + } +} + +void ByteSpan::checkRemaining(size_t required) const { + if (remaining() < required) { + throw BufferException(""); + } +} + +uint8_t ByteSpan::readUInt8() { + checkRemaining(1); + return dataPtr[pos++]; +} + +int8_t ByteSpan::readInt8() { return static_cast(readUInt8()); } + +uint16_t ByteSpan::readUInt16() { + checkRemaining(2); + uint16_t value = static_cast(dataPtr[pos]) | + (static_cast(dataPtr[pos + 1]) << 8); + pos += 2; + return value; +} + +int16_t ByteSpan::readInt16() { return static_cast(readUInt16()); } + +uint32_t ByteSpan::readUInt32() { + checkRemaining(4); + uint32_t value = static_cast(dataPtr[pos]) | + (static_cast(dataPtr[pos + 1]) << 8) | + (static_cast(dataPtr[pos + 2]) << 16) | + (static_cast(dataPtr[pos + 3]) << 24); + pos += 4; + return value; +} + +int32_t ByteSpan::readInt32() { return static_cast(readUInt32()); } + +uint64_t ByteSpan::readUInt64() { + checkRemaining(8); + uint64_t value = static_cast(dataPtr[pos]) | + (static_cast(dataPtr[pos + 1]) << 8) | + (static_cast(dataPtr[pos + 2]) << 16) | + (static_cast(dataPtr[pos + 3]) << 24) | + (static_cast(dataPtr[pos + 4]) << 32) | + (static_cast(dataPtr[pos + 5]) << 40) | + (static_cast(dataPtr[pos + 6]) << 48) | + (static_cast(dataPtr[pos + 7]) << 56); + pos += 8; + return value; +} + +int64_t ByteSpan::readInt64() { return static_cast(readUInt64()); } + +void ByteSpan::skip(size_t count) { + checkRemaining(count); + pos += count; +} + +void ByteSpan::seek(size_t position) { + if (position > dataSize) { + throw BufferException(""); + } + pos = position; +} + +BufferException::BufferException(const std::string &message) + : std::runtime_error(message) {} diff --git a/third_party/mthreads/proton/common/lib/TraceDataIO/CMakeLists.txt b/third_party/mthreads/proton/common/lib/TraceDataIO/CMakeLists.txt new file mode 100644 index 0000000000..cd81e4f932 --- /dev/null +++ b/third_party/mthreads/proton/common/lib/TraceDataIO/CMakeLists.txt @@ -0,0 +1,7 @@ +add_proton_library(ProtonTraceDataIO + ByteSpan.cpp + EntryDecoder.cpp + Parser.cpp + CircularLayoutParser.cpp + TraceWriter.cpp +) diff --git a/third_party/mthreads/proton/common/lib/TraceDataIO/CircularLayoutParser.cpp b/third_party/mthreads/proton/common/lib/TraceDataIO/CircularLayoutParser.cpp new file mode 100644 index 0000000000..3573921ac3 --- /dev/null +++ b/third_party/mthreads/proton/common/lib/TraceDataIO/CircularLayoutParser.cpp @@ -0,0 +1,254 @@ +#include "TraceDataIO/CircularLayoutParser.h" +#include +#include +#include +#include + +using namespace proton; + +CircularLayoutParser::CircularLayoutParser( + ByteSpan &buffer, const CircularLayoutParserConfig &config) + : ParserBase(buffer, config), decoder(buffer) { + result = std::make_shared(); +} + +std::shared_ptr CircularLayoutParser::getResult() { + return result; +} + +void CircularLayoutParser::parse() { + auto &uidVec = getConfig().uidVec; + assert(uidVec.size()); + assert(std::is_sorted(uidVec.begin(), uidVec.end())); + + int numBlocks = getConfig().numBlocks; + const int scratchMemSize = getConfig().scratchMemSize; + uint32_t pos = buffer.position(); + for (int i = 0; i < numBlocks; i++) { + buffer.seek(pos); + parseBlock(); + pos += scratchMemSize; + } +} + +const CircularLayoutParserConfig &CircularLayoutParser::getConfig() const { + return static_cast(config); +} + +void CircularLayoutParser::parseMetadata() { + uint32_t preamble = decoder.decode()->value; + if (preamble != kPreamble) + throw PreambleException("Invalid preamble"); + auto &bt = result->blockTraces.emplace_back(); + bt.blockId = decoder.decode()->value; + bt.procId = decoder.decode()->value; + bt.bufSize = decoder.decode()->value; + bt.initTime = decoder.decode()->value; + bt.preFinalTime = decoder.decode()->value; + bt.postFinalTime = decoder.decode()->value; + + std::vector countVec; + for (int i = 0; i < getConfig().totalUnits; i++) { + countVec.push_back(decoder.decode()->value); + } + + // Each event is 8 bytes + int maxCountPerUnit = bt.bufSize / getConfig().uidVec.size() / 8; + + for (auto uid : getConfig().uidVec) { + // Each event is 2 words (8 bytes) and countVec captures the number of words + // of each warp captured during profiling + auto count = countVec[uid]; + auto numEvent = count / 2; + + if (numEvent > maxCountPerUnit) { + std::cerr << "Warning (cta" << bt.blockId << ", warp" << uid + << "): first " << numEvent - maxCountPerUnit + << " events are dropped due to insufficient buffer size (" + << maxCountPerUnit << "/" << numEvent << ")" << std::endl; + } + + auto &trace = bt.traces.emplace_back(); + trace.uid = uid; + trace.count = count; + } +} + +void CircularLayoutParser::parseProfileEvents() { + auto &bt = result->blockTraces.back(); + const int bufferSize = bt.bufSize; + const int numSegments = getConfig().uidVec.size(); + const int segmentByteSize = bufferSize / numSegments; + auto position = buffer.position(); + for (int i = 0; i < numSegments; i++) { + buffer.seek(position); + auto &trace = bt.traces[i]; + parseSegment(segmentByteSize, trace); + position += segmentByteSize; + } +} + +void CircularLayoutParser::parseSegment( + int segmentByteSize, CircularLayoutParserResult::Trace &trace) { + + auto state = ParseState::INIT; + int idealSize = trace.count * kWordSize; + int byteSize = std::min(idealSize, segmentByteSize); + const int maxNumEntries = byteSize / (kWordSize * kWordsPerEntry); + + std::unordered_map activeEvent; + std::unordered_map scopeState; + + for (int i = 0; i < maxNumEntries; i++) { + try { + auto entry = decoder.decode(); + if (!activeEvent.count(entry->scopeId)) { + activeEvent[entry->scopeId] = + CircularLayoutParserResult::ProfileEvent(); + } + auto &activeProfileEvent = activeEvent[entry->scopeId]; + + auto prevState = ParseState::INIT; + if (scopeState.count(entry->scopeId)) + prevState = scopeState[entry->scopeId]; + + if (entry->isStart) { + if (prevState == ParseState::INIT || prevState == ParseState::END) { + activeProfileEvent.first = entry; + scopeState[entry->scopeId] = ParseState::START; + } else { + throw ScopeMisMatchException("Scope mismatch: start after start"); + } + } else { + if (prevState == ParseState::START) { + activeProfileEvent.second = entry; + scopeState[entry->scopeId] = ParseState::END; + + if (activeProfileEvent.first->cycle > + activeProfileEvent.second->cycle) { + throw ClockOverflowException("Clock overflow"); + } + trace.profileEvents.push_back(activeProfileEvent); + } else { + throw ScopeMisMatchException("Scope mismatch: end after end"); + } + } + } catch (const ScopeMisMatchException &e) { + reportException(e, buffer.position()); + } catch (const ClockOverflowException &e) { + reportException(e, buffer.position()); + } + } +} + +void CircularLayoutParser::parseBlock() { + try { + parseMetadata(); + parseProfileEvents(); + } catch (const PreambleException &e) { + reportException(e, buffer.position()); + } +} + +PreambleException::PreambleException(const std::string &msg) + : ParserException(msg, ExceptionSeverity::ERROR) {} + +ScopeMisMatchException::ScopeMisMatchException(const std::string &msg) + : ParserException(msg, ExceptionSeverity::WARNING) {} + +ClockOverflowException::ClockOverflowException(const std::string &msg) + : ParserException(msg, ExceptionSeverity::ERROR) {} + +namespace { +Device decodeDevice(const uint32_t dev) { + Device device; + switch (dev) { + case 1: + device.type = DeviceType::CUDA; + device.arch = ""; + break; + case 2: + device.type = DeviceType::HIP; + device.arch = ""; + break; + default: + break; + } + return device; +} + +void shift(CircularLayoutParserResult::Trace &trace, const uint64_t cost, + const uint64_t timeBase) { + for (auto &event : trace.profileEvents) { + if (event.first->cycle >= timeBase) + event.first->cycle -= cost; + if (event.second->cycle >= timeBase) + event.second->cycle -= cost; + } +} +} // namespace + +std::shared_ptr +proton::readCircularLayoutTrace(ByteSpan &buffer, bool applyTimeShift) { + CircularLayoutParserConfig config; + auto decoder = EntryDecoder(buffer); + uint32_t version = decoder.decode()->value; + assert(version == 1 && "Version mismatch"); + buffer.skip(8); + uint32_t payloadOffset = decoder.decode()->value; + uint32_t payloadSize = decoder.decode()->value; + uint32_t device = decoder.decode()->value; + config.device = decodeDevice(device); + config.numBlocks = decoder.decode()->value; + config.totalUnits = decoder.decode()->value; + config.scratchMemSize = decoder.decode()->value; + uint32_t uidNum = decoder.decode()->value; + + config.uidVec.clear(); + for (int i = 0; i < uidNum; i++) { + uint32_t uid = decoder.decode()->value; + config.uidVec.push_back(uid); + } + + buffer.seek(payloadOffset); + auto parser = std::make_unique(buffer, config); + parser->parse(); + auto result = parser->getResult(); + + // Shift the clocks to reduce the constant profiling overhead + if (applyTimeShift) { + const uint64_t cost = getTimeShiftCost(config); + timeShift(cost, result); + } + + return result; +} + +void proton::timeShift(const uint64_t cost, + std::shared_ptr result) { + for (auto &bt : result->blockTraces) { + for (auto &trace : bt.traces) { + for (auto &event : trace.profileEvents) { + const uint64_t startTimeBase = event.first->cycle; + shift(trace, cost, startTimeBase); + + const uint64_t endTimeBase = event.second->cycle; + shift(trace, cost, endTimeBase); + + // Adjust the cycle for tiny events below the profiling precision + if (event.second->cycle < event.first->cycle) { + event.second->cycle = event.first->cycle + cost / 2; + } + } + } + } +} + +uint64_t proton::getTimeShiftCost(const CircularLayoutParserConfig &config) { + if (config.device.type == DeviceType::CUDA) + return 7; + else if (config.device.type == DeviceType::HIP) + return 36; + + return 0; +} diff --git a/third_party/mthreads/proton/common/lib/TraceDataIO/EntryDecoder.cpp b/third_party/mthreads/proton/common/lib/TraceDataIO/EntryDecoder.cpp new file mode 100644 index 0000000000..94c0ff4332 --- /dev/null +++ b/third_party/mthreads/proton/common/lib/TraceDataIO/EntryDecoder.cpp @@ -0,0 +1,34 @@ +#include "TraceDataIO/EntryDecoder.h" + +using namespace proton; + +std::ostream &operator<<(std::ostream &os, const EntryBase &obj) { + obj.print(os); + return os; +} + +void I32Entry::print(std::ostream &os) const { os << value; } + +template <> void proton::decodeFn(ByteSpan &buffer, I32Entry &entry) { + entry.value = buffer.readInt32(); +} + +void I64Entry::print(std::ostream &os) const { os << value; } + +template <> void proton::decodeFn(ByteSpan &buffer, I64Entry &entry) { + entry.value = buffer.readInt64(); +} + +void CycleEntry::print(std::ostream &os) const { + std::string prefix = isStart ? "S" : "E"; + os << prefix + std::to_string(scopeId) + "C" + std::to_string(cycle); +} + +template <> +void proton::decodeFn(ByteSpan &buffer, CycleEntry &entry) { + uint32_t tagClkUpper = buffer.readUInt32(); + entry.isStart = (tagClkUpper & 0x80000000) == 0; + entry.scopeId = (tagClkUpper & 0x7F800000) >> 23; + uint64_t clkLower = buffer.readUInt32(); + entry.cycle = static_cast(tagClkUpper & 0x7FF) << 32 | clkLower; +} diff --git a/third_party/mthreads/proton/common/lib/TraceDataIO/Parser.cpp b/third_party/mthreads/proton/common/lib/TraceDataIO/Parser.cpp new file mode 100644 index 0000000000..819db45244 --- /dev/null +++ b/third_party/mthreads/proton/common/lib/TraceDataIO/Parser.cpp @@ -0,0 +1,25 @@ +#include "TraceDataIO/Parser.h" + +using namespace proton; + +ParserException::ParserException(const std::string &msg, ExceptionSeverity sev) + : std::runtime_error(msg), severity(sev) {} + +ParserBase::ParserBase(ByteSpan &buffer, const ParserConfig &config) + : buffer(buffer), config(config) {} + +void ParserBase::reportException(const ParserException &e, size_t pos) { + + if (e.severity == ExceptionSeverity::ERROR || + config.printLevel == ParserConfig::PrintMode::ALL) { + std::cerr << "ParserException [offset=" << pos << "]: " << e.what() + << std::endl; + } + + if (e.severity == ExceptionSeverity::WARNING) + return; + + throw e; +} + +const ParserConfig &ParserBase::getConfig() const { return config; } diff --git a/third_party/mthreads/proton/common/lib/TraceDataIO/TraceWriter.cpp b/third_party/mthreads/proton/common/lib/TraceDataIO/TraceWriter.cpp new file mode 100644 index 0000000000..a945aa64bd --- /dev/null +++ b/third_party/mthreads/proton/common/lib/TraceDataIO/TraceWriter.cpp @@ -0,0 +1,249 @@ +#include "TraceDataIO/TraceWriter.h" +#include +#include +#include +#include +#include + +using namespace proton; +using json = nlohmann::json; + +namespace { + +uint64_t getMinInitTime(const std::vector &streamTrace) { + uint64_t minInitTime = std::numeric_limits::max(); + for (const auto &kernelTrace : streamTrace) + for (const auto &bt : kernelTrace.first->blockTraces) { + if (bt.initTime < minInitTime) { + minInitTime = bt.initTime; + } + } + return minInitTime; +} + +} // namespace + +StreamTraceWriter::StreamTraceWriter( + const std::vector &streamTrace, const std::string &path) + : streamTrace(streamTrace), path(path) {} + +void StreamTraceWriter::dump() { + std::ofstream outfile; + + if (path.empty()) { + std::cerr << "Trace file path can't be empty!"; + return; + } + + outfile.open(path); + if (!outfile.is_open()) { + std::cerr << "Failed to open trace file: " << path << std::endl; + return; + } + + write(outfile); + + outfile.close(); +} + +StreamChromeTraceWriter::StreamChromeTraceWriter( + const std::vector &streamTrace, const std::string &path) + : StreamTraceWriter(streamTrace, path) {} + +void StreamChromeTraceWriter::write(std::ostream &outfile) { + if (streamTrace.empty()) { + std::cerr << "Failed to write the trace file: empty trace!" << std::endl; + return; + } + + json object = {{"displayTimeUnit", "ns"}, {"traceEvents", json::array()}}; + + const auto minInitTime = getMinInitTime(streamTrace); + + for (const auto &kernelTrace : streamTrace) { + writeKernel(object, kernelTrace, minInitTime); + } + outfile << object.dump() << "\n"; +} + +namespace { +using BlockTraceVec = + std::vector; + +void populateTraceInfo(std::shared_ptr result, + std::map &blockToMinCycle, + std::map &procToBlockTraces) { + for (auto &bt : result->blockTraces) { + // Find the minimum cycle for each block + uint64_t minCycle = std::numeric_limits::max(); + for (auto &trace : bt.traces) + for (auto &event : trace.profileEvents) + if (event.first->cycle < minCycle) + minCycle = event.first->cycle; + blockToMinCycle[bt.blockId] = minCycle; + + // Group block traces by proc id + int procId = bt.procId; + if (!procToBlockTraces.count(procId)) { + procToBlockTraces[procId] = {}; + } + procToBlockTraces[procId].push_back(&bt); + } +} + +std::vector assignLineIds( + const std::vector &trace) { + + std::vector result(trace.size()); + + if (trace.empty()) { + return result; + } + + // Create indexed events and sort by start time + std::vector> + indexedEvents; + indexedEvents.reserve(trace.size()); + + for (size_t i = 0; i < trace.size(); ++i) { + indexedEvents.push_back({i, trace[i]}); + } + + std::sort(indexedEvents.begin(), indexedEvents.end(), + [](const auto &a, const auto &b) { + return a.second.first->cycle < b.second.first->cycle; + }); + + // For each line, store all the intervals + std::vector>> lines; + + for (const auto &[originalIdx, event] : indexedEvents) { + uint64_t startTime = event.first->cycle; + uint64_t endTime = event.second->cycle; + + // Find the first line where this event can be placed + int lineIdx = 0; + bool foundLine = false; + + for (; lineIdx < lines.size(); ++lineIdx) { + const auto &lineIntervals = lines[lineIdx]; + bool canPlace = true; + + // Check for overlap with any interval on this line + for (const auto &[intervalStart, intervalEnd] : lineIntervals) { + // Check if there's any overlap + if (startTime < intervalEnd && endTime > intervalStart) { + canPlace = false; + break; + } + } + + if (canPlace) { + foundLine = true; + break; + } + } + + // If no suitable line found, create a new one + if (!foundLine) { + lineIdx = lines.size(); + lines.push_back({}); + } + + // Add the event to the line + lines[lineIdx].push_back({startTime, endTime}); + result[originalIdx] = lineIdx; + } + + return result; +} + +} // namespace + +void StreamChromeTraceWriter::writeKernel(json &object, + const KernelTrace &kernelTrace, + const uint64_t minInitTime) { + auto result = kernelTrace.first; + auto metadata = kernelTrace.second; + + json callStack = json::array(); + for (auto const &frame : metadata->callStack) { + callStack.push_back(frame); + } + + int curColorIndex = 0; + // scope id -> color index in chrome color + std::map scopeColor; + // block id -> min cycle observed + std::map blockToMinCycle; + // proc id -> block traces + std::map procToBlockTraces; + + populateTraceInfo(result, blockToMinCycle, procToBlockTraces); + + std::string name; + std::string pid; + std::string category; + std::string tid; + for (auto &[procId, blockVec] : procToBlockTraces) { + for (auto *bt : blockVec) { + int ctaId = bt->blockId; + for (auto &trace : bt->traces) { + int warpId = trace.uid; + auto lineInfo = assignLineIds(trace.profileEvents); + int eventIdx = 0; + for (auto &event : trace.profileEvents) { + int lineId = lineInfo[eventIdx]; + int scopeId = event.first->scopeId; + if (!scopeColor.count(scopeId)) { + scopeColor[scopeId] = curColorIndex; + curColorIndex = (curColorIndex + 1) % kChromeColor.size(); + } + const std::string &color = kChromeColor[scopeColor[scopeId]]; + pid = metadata->kernelName + " Core" + std::to_string(procId) + + " CTA" + std::to_string(ctaId); + tid = "warp " + std::to_string(warpId) + " (line " + + std::to_string(lineId) + ")"; + category = metadata->kernelName; + if (!metadata->scopeName.count(scopeId)) + name = "scope_" + std::to_string(scopeId); + else + name = metadata->scopeName.at(scopeId); + + // Unit: MHz, we assume freq is 1000MHz (1GHz) + double freq = 1000.0; + + // Global time is in `ns` unit. With 1GHz assumption, we + // could subtract with blockToMInCycle: (ns - ns) / 1GHz - cycle + int64_t cycleAdjust = + static_cast(bt->initTime - minInitTime) - + static_cast(blockToMinCycle[ctaId]); + int64_t ts = static_cast(event.first->cycle) + cycleAdjust; + int64_t dur = + static_cast(event.second->cycle) - event.first->cycle; + + json element; + element["cname"] = color; + element["name"] = name; + element["cat"] = category; + element["ph"] = "X"; + element["pid"] = pid; + element["tid"] = tid; + element["ts"] = static_cast(ts) / freq; + element["dur"] = static_cast(dur) / freq; + json args; + args["Init Time (ns)"] = bt->initTime; + args["Post Final Time (ns)"] = bt->postFinalTime; + args["Finalization Time (ns)"] = bt->postFinalTime - bt->preFinalTime; + args["Frequency (MHz)"] = freq; + element["args"] = args; + element["args"]["call_stack"] = callStack; + + object["traceEvents"].push_back(element); + + eventIdx++; + } + } + } + } +} diff --git a/third_party/mthreads/proton/csrc/CMakeLists.txt b/third_party/mthreads/proton/csrc/CMakeLists.txt new file mode 100644 index 0000000000..5eba4657e9 --- /dev/null +++ b/third_party/mthreads/proton/csrc/CMakeLists.txt @@ -0,0 +1,5 @@ +add_proton_library(Proton + Proton.cpp +) + +add_subdirectory(lib) diff --git a/third_party/mthreads/proton/csrc/Proton.cpp b/third_party/mthreads/proton/csrc/Proton.cpp new file mode 100644 index 0000000000..1495d842fd --- /dev/null +++ b/third_party/mthreads/proton/csrc/Proton.cpp @@ -0,0 +1,209 @@ +#include "Proton.h" + +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +using namespace proton; + +// For simplicity, the Python interface restricts metrics to int64_t and double. +// without uint64_t. Allowing types such as uint64_t vs. int64_t would force +// users to handle subtle type differences for the same metric name, which would +// be confusing and error-prone. +using PythonMetricValueType = std::variant; +namespace { + +std::map convertPythonMetrics( + const std::map &metrics) { + std::map converted; + for (const auto &[name, value] : metrics) { + converted.emplace(name, std::visit( + [](auto &&v) -> MetricValueType { + return MetricValueType(v); + }, + value)); + } + return converted; +} + +} // namespace + +static void initProton(pybind11::module &&m) { + using ret = pybind11::return_value_policy; + using namespace pybind11::literals; + + // Accept raw integer pointers from Python (e.g., Tensor.data_ptr()) instead + // of requiring a PyCapsule, which matches how tensor metric values are passed + // in transform_tensor_metrics. + pybind11::class_(m, "TensorMetric") + .def(pybind11::init<>()) + .def(pybind11::init([](uintptr_t ptr, size_t index) { + return TensorMetric{reinterpret_cast(ptr), index}; + }), + pybind11::arg("ptr"), pybind11::arg("index")) + .def_property_readonly("ptr", + [](const TensorMetric &metric) { + return reinterpret_cast(metric.ptr); + }) + .def_property_readonly( + "index", [](const TensorMetric &metric) { return metric.index; }); + + m.attr("metric_int64_index") = + pybind11::cast(variant_index_v); + m.attr("metric_double_index") = + pybind11::cast(variant_index_v); + + m.def( + "start", + [](const std::string &path, const std::string &contextSourceName, + const std::string &dataName, const std::string &profilerName, + const std::string &mode) { + auto sessionId = SessionManager::instance().addSession( + path, profilerName, contextSourceName, dataName, mode); + SessionManager::instance().activateSession(sessionId); + return sessionId; + }, + pybind11::arg("path"), pybind11::arg("contextSourceName"), + pybind11::arg("dataName"), pybind11::arg("profilerName"), + pybind11::arg("mode") = ""); + + m.def("activate", [](size_t sessionId) { + SessionManager::instance().activateSession(sessionId); + }); + + m.def("activate_all", + []() { SessionManager::instance().activateAllSessions(); }); + + m.def("deactivate", [](size_t sessionId, bool flushing) { + SessionManager::instance().deactivateSession(sessionId, flushing); + }); + + m.def("deactivate_all", [](bool flushing) { + SessionManager::instance().deactivateAllSessions(flushing); + }); + + m.def("finalize", [](size_t sessionId, const std::string &outputFormat) { + SessionManager::instance().finalizeSession(sessionId, outputFormat); + }); + + m.def("finalize_all", [](const std::string &outputFormat) { + SessionManager::instance().finalizeAllSessions(outputFormat); + }); + + m.def("record_scope", []() { return Scope::getNewScopeId(); }); + + m.def("enter_scope", [](size_t scopeId, const std::string &name) { + SessionManager::instance().enterScope(Scope(scopeId, name)); + }); + + m.def("exit_scope", [](size_t scopeId, const std::string &name) { + SessionManager::instance().exitScope(Scope(scopeId, name)); + }); + + m.def("enter_op", [](size_t scopeId, const std::string &name) { + SessionManager::instance().enterOp(Scope(scopeId, name)); + }); + + m.def("exit_op", [](size_t scopeId, const std::string &name) { + SessionManager::instance().exitOp(Scope(scopeId, name)); + }); + + m.def("init_function_metadata", + [](uint64_t functionId, const std::string &functionName, + const std::vector> &scopeIdNames, + const std::vector> &scopeIdParents, + const std::string &metadataPath) { + SessionManager::instance().initFunctionMetadata( + functionId, functionName, scopeIdNames, scopeIdParents, + metadataPath); + }); + + m.def("enter_instrumented_op", [](uint64_t streamId, uint64_t functionId, + uint64_t buffer, size_t size) { + SessionManager::instance().enterInstrumentedOp( + streamId, functionId, reinterpret_cast(buffer), size); + }); + + m.def("exit_instrumented_op", [](uint64_t streamId, uint64_t functionId, + uint64_t buffer, size_t size) { + SessionManager::instance().exitInstrumentedOp( + streamId, functionId, reinterpret_cast(buffer), size); + }); + + m.def("enter_state", [](const std::string &state) { + SessionManager::instance().setState(state); + }); + + m.def("exit_state", + []() { SessionManager::instance().setState(std::nullopt); }); + + m.def( + "add_metrics", + [](size_t scopeId, + const std::map &metrics, + const std::map &tensorMetrics) { + auto convertedMetrics = convertPythonMetrics(metrics); + SessionManager::instance().addMetrics(scopeId, convertedMetrics, + tensorMetrics); + }, + pybind11::arg("scopeId"), pybind11::arg("metrics"), + pybind11::arg("tensorMetrics") = std::map()); + + m.def("set_metric_kernels", + [](uintptr_t tensorMetricKernel, uintptr_t scalarMetricKernel, + uintptr_t stream) { + SessionManager::instance().setMetricKernels( + reinterpret_cast(tensorMetricKernel), + reinterpret_cast(scalarMetricKernel), + reinterpret_cast(stream)); + }); + + m.def("get_context_depth", [](size_t sessionId) { + return SessionManager::instance().getContextDepth(sessionId); + }); + + m.def( + "get_data", + [](size_t sessionId, size_t phase) { + return SessionManager::instance().getData(sessionId, phase); + }, + pybind11::arg("sessionId"), pybind11::arg("phase")); + + m.def( + "get_data_msgpack", + [](size_t sessionId, size_t phase) { + auto data = SessionManager::instance().getDataMsgPack(sessionId, phase); + return pybind11::bytes(reinterpret_cast(data.data()), + data.size()); + }, + pybind11::arg("sessionId"), pybind11::arg("phase")); + m.def( + "clear_data", + [](size_t sessionId, size_t phase, bool clearUpToPhase) { + SessionManager::instance().clearData(sessionId, phase, clearUpToPhase); + }, + pybind11::arg("sessionId"), pybind11::arg("phase"), + pybind11::arg("clearUpToPhase") = false); + m.def( + "advance_data_phase", + [](size_t sessionId) { + return SessionManager::instance().advanceDataPhase(sessionId); + }, + pybind11::arg("sessionId")); + m.def( + "is_data_phase_complete", + [](size_t sessionId, size_t phase) { + return SessionManager::instance().isDataPhaseComplete(sessionId, phase); + }, + pybind11::arg("sessionId"), pybind11::arg("phase")); +} + +PYBIND11_MODULE(libproton, m) { + m.doc() = "Python bindings to the Proton API"; + initProton(std::move(m.def_submodule("proton"))); +} diff --git a/third_party/mthreads/proton/csrc/include/Context/Context.h b/third_party/mthreads/proton/csrc/include/Context/Context.h new file mode 100644 index 0000000000..2844d6ebd5 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Context/Context.h @@ -0,0 +1,155 @@ +#ifndef PROTON_CONTEXT_CONTEXT_H_ +#define PROTON_CONTEXT_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +/// A context is a named object. +struct Context { + std::string name{}; + + Context() = default; + Context(const std::string &name) : name(name) {} + virtual ~Context() = default; + + bool operator==(const Context &other) const { return name == other.name; } + bool operator!=(const Context &other) const { return !(*this == other); } + bool operator<(const Context &other) const { return name < other.name; } + bool operator>(const Context &other) const { return name > other.name; } + bool operator<=(const Context &other) const { return !(*this > other); } + bool operator>=(const Context &other) const { return !(*this < other); } +}; + +/// A context source is an object that can provide a list of contexts. +class ContextSource { +public: + ContextSource() = default; + virtual ~ContextSource() = default; + + std::vector getContexts() { + auto contexts = getContextsImpl(); + if (state.has_value()) { + contexts.push_back(state.value()); + } + return contexts; + } + + void setState(std::optional state) { ContextSource::state = state; } + + virtual void clear() { ContextSource::state = std::nullopt; } + + virtual size_t getDepth() = 0; + +protected: + virtual std::vector getContextsImpl() = 0; + static thread_local std::optional state; +}; + +/// A scope is a context with a unique identifier. +struct Scope : public Context { + const static size_t DummyScopeId = std::numeric_limits::max(); + static std::atomic scopeIdCounter; + + static size_t getNewScopeId() { return scopeIdCounter++; } + + size_t scopeId{}; + + explicit Scope(size_t scopeId) : Context(), scopeId(scopeId) {} + + explicit Scope(const std::string &name) : Context(name) { + scopeId = getNewScopeId(); + } + + Scope(size_t scopeId, const std::string &name) + : scopeId(scopeId), Context(name) {} + + Scope() : Scope(DummyScopeId, "") {} + + bool operator==(const Scope &other) const { + return scopeId == other.scopeId && name == other.name; + } + bool operator!=(const Scope &other) const { return !(*this == other); } + bool operator<(const Scope &other) const { + return scopeId < other.scopeId || name < other.name; + } + bool operator>(const Scope &other) const { + return scopeId > other.scopeId || name > other.name; + } + bool operator<=(const Scope &other) const { return !(*this > other); } + bool operator>=(const Scope &other) const { return !(*this < other); } +}; + +/// A scope interface allows to instrument handles before and after a scope. +/// Scopes can be nested. +class ScopeInterface { +public: + ScopeInterface() = default; + virtual ~ScopeInterface() = default; + virtual void enterScope(const Scope &scope) = 0; + virtual void exitScope(const Scope &scope) = 0; +}; + +/// An op interface allows to instrument handles before and after an operation, +/// which cannot be nested. +class OpInterface { +public: + OpInterface() = default; + virtual ~OpInterface() = default; + + void enterOp(const Scope &scope) { + if (isOpInProgress()) { + return; + } + startOp(scope); + setOpInProgress(true); + } + + void exitOp(const Scope &scope) { + if (!isOpInProgress()) { + return; + } + stopOp(scope); + setOpInProgress(false); + } + +protected: + bool isOpInProgress() { return opInProgress[this]; } + void setOpInProgress(bool value) { + opInProgress[this] = value; + if (opInProgress.size() > MAX_CACHE_OBJECTS && !value) + opInProgress.erase(this); + } + virtual void startOp(const Scope &scope) = 0; + virtual void stopOp(const Scope &scope) = 0; + +private: + inline static const int MAX_CACHE_OBJECTS = 10; + static thread_local std::map opInProgress; +}; + +class InstrumentationInterface { +public: + InstrumentationInterface() = default; + virtual ~InstrumentationInterface() = default; + + virtual void initFunctionMetadata( + uint64_t functionId, const std::string &functionName, + const std::vector> &scopeIdNames, + const std::vector> &scopeIdParents, + const std::string &metadataPath) = 0; + virtual void enterInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size) = 0; + virtual void exitInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size) = 0; +}; + +} // namespace proton + +#endif // PROTON_CONTEXT_CONTEXT_H_ diff --git a/third_party/mthreads/proton/csrc/include/Context/Python.h b/third_party/mthreads/proton/csrc/include/Context/Python.h new file mode 100644 index 0000000000..f58f32b766 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Context/Python.h @@ -0,0 +1,21 @@ +#ifndef PROTON_CONTEXT_PYTHON_H_ +#define PROTON_CONTEXT_PYTHON_H_ + +#include "Context.h" + +namespace proton { + +/// Unwind the Python stack and early return a list of contexts. +class PythonContextSource : public ContextSource { +public: + PythonContextSource() = default; + + size_t getDepth() override; + +private: + std::vector getContextsImpl() override; +}; + +} // namespace proton + +#endif // PROTON_CONTEXT_PYTHON_H_ diff --git a/third_party/mthreads/proton/csrc/include/Context/Shadow.h b/third_party/mthreads/proton/csrc/include/Context/Shadow.h new file mode 100644 index 0000000000..b69b5f456d --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Context/Shadow.h @@ -0,0 +1,49 @@ +#ifndef PROTON_CONTEXT_SHADOW_H_ +#define PROTON_CONTEXT_SHADOW_H_ + +#include "Context.h" +#include + +namespace proton { + +/// ShadowContextSource is designed to: +/// +/// - Maintain a main context stack for the main thread. +/// - Provide thread-local context stacks for individual threads. +/// - Allow threads to inherit and shadow the main context stack with their +/// own user-defined scopes. +/// +/// This implementation is suited for use cases like PyTorch, where: +/// +/// - The main thread initializes the main context stack during session setup. +/// - The backward phase spawns multiple CPU threads. +class ShadowContextSource : public ContextSource, public ScopeInterface { +public: + ShadowContextSource() { + mainContextStack = &threadContextStack[this]; + threadContextInitialized[this] = true; + } + + void enterScope(const Scope &scope) override; + + void exitScope(const Scope &scope) override; + + size_t getDepth() override; + + void clear() override; + +private: + std::vector getContextsImpl() override; + + void initializeThreadContext(); + + std::vector *mainContextStack{}; + static thread_local std::map + threadContextInitialized; + static thread_local std::map> + threadContextStack; +}; + +} // namespace proton + +#endif // PROTON_CONTEXT_SHADOW_H_ diff --git a/third_party/mthreads/proton/csrc/include/Data/Data.h b/third_party/mthreads/proton/csrc/include/Data/Data.h new file mode 100644 index 0000000000..aec1b6e056 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Data/Data.h @@ -0,0 +1,196 @@ +#ifndef PROTON_DATA_DATA_H_ +#define PROTON_DATA_DATA_H_ + +#include "Context/Context.h" +#include "Metric.h" +#include "PhaseStore.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +enum class OutputFormat { Hatchet, HatchetMsgPack, ChromeTrace, Count }; + +/// An "entry" is a data specific unit of operation, e.g., a node in a tree +/// data structure or an event in a trace data structure. +struct DataEntry { + /// `entryId` is a unique identifier for the entry in the data. + size_t id{Scope::DummyScopeId}; + /// `phase` indicates which phase the entry belongs to. + size_t phase{0}; + /// `metrics` is a map from metric kind to metric accumulator associated + /// with the entry. + /// Flexible metrics cannot be directly stored here since they maybe added by + /// both the frontend and the backend. + /// Use `Data::addMetrics` and `Data::addMetrics` to add flexible + /// metrics. + std::reference_wrapper>> metrics; + + explicit DataEntry(size_t id, size_t phase, + std::map> &metrics) + : id(id), phase(phase), metrics(metrics) {} + + void upsertMetric(std::unique_ptr metric) { + if (!metric) + return; + auto &metricsMap = metrics.get(); + auto it = metricsMap.find(metric->getKind()); + if (it == metricsMap.end()) { + metricsMap.emplace(metric->getKind(), std::move(metric)); + } else { + it->second->updateMetric(*metric); + } + } +}; + +class Data : public ScopeInterface { +public: + static constexpr size_t kNoCompletePhase = std::numeric_limits::max(); + + struct PhaseInfo { + size_t current{0}; + size_t completeUpTo{kNoCompletePhase}; + + bool isComplete(size_t phase) const { + return completeUpTo != kNoCompletePhase && completeUpTo >= phase; + } + }; + + Data(const std::string &path, ContextSource *contextSource = nullptr) + : path(path), contextSource(contextSource) {} + virtual ~Data() = default; + + /// Get the path associated with the data. + const std::string &getPath() const { return path; } + + /// Get the contexts associated with the data. + std::vector getContexts() const { + return contextSource->getContexts(); + } + + /// Dump the data to the given output format. + void dump(const std::string &outputFormat); + + /// Clear all non-persistent fields in the data. + /// If `clearUpToPhase` is false, clear the given phase only. + /// Otherwise, clear all phases up to and including the given phase. + void clear(size_t phase, bool clearUpToPhase = false); + + /// Advance to the next phase. + size_t advancePhase(); + + /// Mark phases up to `phase` as complete. + void completePhase(size_t phase); + + /// Atomically get current and complete phases. + PhaseInfo getPhaseInfo() const; + + /// Add an op to the data of the current phase. + /// If `opName` is empty, just use the current context as is. + /// Otherwise obtain the current context and append `opName` to it. Return the + /// entry id of the added op. + virtual DataEntry addOp(const std::string &opName = {}) = 0; + + /// Add an op with custom contexts to the data. + /// This is often used when context source is not available or when + /// the profiler itself needs to supply the contexts, such as + /// instruction samples in GPUs whose contexts are + /// synthesized from the instruction address (no unwinder). + /// + /// `phase` is the phase the op should be added to. This is important for + /// asynchronous profilers, where the current phase may have advanced by the + /// time the profiler needs to attach a child op. + virtual DataEntry addOp(size_t phase, size_t entryId, + const std::vector &contexts) = 0; + + /// Record a batch of named metrics for a scope to the data of the current + /// phase. + /// + /// This is primarily intended for user-defined metrics defined in Python and + /// directly associated with a scope. + /// `metrics` is a map from metric name to value to be applied to `scopeId`. + virtual void + addMetrics(size_t scopeId, + const std::map &metrics) = 0; + + /// Record a batch of named metrics for an entry. + /// + /// This is primarily intended for user-defined metrics defined in Python and + /// added lazily by the backend profiler. + /// `metrics` is a map from metric name to value to be applied to `entryId`. + /// + /// The same as `addOp`, `phase` is important for asynchronous profilers. + virtual void + addMetrics(size_t phase, size_t entryId, + const std::map &metrics) = 0; + + /// To Json + virtual std::string toJsonString(size_t phase) const = 0; + + /// To MsgPack + virtual std::vector toMsgPack(size_t phase) const = 0; + +protected: + /// The actual implementations + virtual void doDump(std::ostream &os, OutputFormat outputFormat, + size_t phase) const = 0; + virtual OutputFormat getDefaultOutputFormat() const = 0; + + void initPhaseStore(PhaseStoreBase &store); + + template T *currentPhasePtrAs() { + return static_cast(currentPhasePtr); + } + + template T *phasePtrAs(size_t phase) { + return static_cast(phaseStore->getPtr(phase)); + } + + [[nodiscard]] std::unique_lock + lockIfCurrentPhase(size_t phase) { + std::unique_lock lock(mutex, std::defer_lock); + const auto currentPhaseValue = currentPhase.load(std::memory_order_relaxed); + // Note that currentPhase is not locked here and can get incremented after + // this point. Correctness can still be guaranteed as no threads other than + // the profiler thread will access the data after phase advancement. + if (phase == currentPhaseValue) { + lock.lock(); + } + // Otherwise, no need to lock for other phases since they won't be updated + // by the application thread + return lock; + } + + std::atomic currentPhase{0}; + std::size_t completeUpToPhase{kNoCompletePhase}; + std::set activePhases{}; + + mutable std::shared_mutex mutex; + const std::string path{}; + ContextSource *contextSource{}; + +private: + PhaseStoreBase *phaseStore{}; + void *currentPhasePtr{}; +}; + +typedef std::map DataToEntryMap; + +OutputFormat parseOutputFormat(const std::string &outputFormat); + +const std::string outputFormatToString(OutputFormat outputFormat); + +} // namespace proton + +#endif // PROTON_DATA_DATA_H_ diff --git a/third_party/mthreads/proton/csrc/include/Data/Metric.h b/third_party/mthreads/proton/csrc/include/Data/Metric.h new file mode 100644 index 0000000000..0c8ccdbd01 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Data/Metric.h @@ -0,0 +1,528 @@ +#ifndef PROTON_DATA_METRIC_H_ +#define PROTON_DATA_METRIC_H_ + +#include "Runtime/Runtime.h" +#include "Utility/String.h" +#include "Utility/Traits.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +enum class MetricKind { Flexible, Kernel, PCSampling, Cycle, Count }; + +using MetricValueType = std::variant; + +inline const char *getTypeNameForIndex(std::size_t idx) { + switch (idx) { + case 0: + return "uint64_t"; + case 1: + return "int64_t"; + case 2: + return "double"; + case 3: + return "std::string"; + default: + return ""; + } +} + +inline const size_t getMetricValueSize(size_t index) { + switch (index) { + case 0: + return sizeof(uint64_t); + case 1: + return sizeof(int64_t); + case 2: + return sizeof(double); + case 3: + throw std::runtime_error("[PROTON] MetricValueType string size is unknown"); + default: + throw std::runtime_error("[PROTON] Unknown MetricValueType index"); + } +} + +/// A metric is a class that can be associated with a context. +/// `Metric` is the base class for all metrics. +/// Each `Metric` has a name and a set of values. +/// Each value could be of type `uint64_t`, `int64_t`, or `double`, +/// Each value can be inclusive (inc), exclusive (exc), or a property (pty). +/// Inclusive values are aggregated by addition and can be propagated to the +/// parent. +/// Exclusive values can be aggregated at a context but cannot be +/// propagated to the parent. +/// Property values are not aggregated and cannot be propagated to the parent. +class Metric { +public: + Metric(MetricKind kind, size_t size) : kind(kind), values(size) {} + + virtual ~Metric() = default; + + virtual const std::string &getName() const = 0; + + virtual const std::string &getValueName(int valueId) const = 0; + + virtual bool isProperty(int valueId) const = 0; + + virtual bool isExclusive(int valueId) const = 0; + + const std::vector &getValues() const { return values; } + + const MetricValueType &getValue(int valueId) const { return values[valueId]; } + + /// Update a specific value id with the new value. + void updateValue(int valueId, MetricValueType value) { + // Enforce type consistency: once a valueId has a type, it must not change. + if (values[valueId].index() != value.index()) { + throw std::runtime_error( + std::string("Metric value type mismatch for valueId ") + + std::to_string(valueId) + " (" + getValueName(valueId) + ")" + + ": current=" + getTypeNameForIndex(values[valueId].index()) + + ", new=" + getTypeNameForIndex(value.index())); + } + // Handle string and other values separately + if (std::holds_alternative(value)) { + values[valueId] = std::get(value); + } else { + std::visit( + [&](auto &¤tValue, auto &&otherValue) { + using CurrentType = std::decay_t; + using ValueType = std::decay_t; + if constexpr (std::is_same_v) { + if (isProperty(valueId)) { + currentValue = otherValue; + } else { + currentValue += otherValue; + } + } + }, + values[valueId], value); + } + } + + /// Update all values of the metric with the same value. + void updateValue(MetricValueType value) { + for (int i = 0; i < values.size(); ++i) { + updateValue(i, value); + } + } + + /// Update all values with another metric. + void updateMetric(const Metric &other) { + for (int i = 0; i < values.size(); ++i) { + updateValue(i, other.values[i]); + } + } + + MetricKind getKind() const { return kind; } + +private: + const MetricKind kind; + +protected: + std::vector values; +}; + +/// A flexible metric is provided by users but not the backend profiling API. +/// Each flexible metric has a single value. +class FlexibleMetric : public Metric { +public: + FlexibleMetric(const std::string &valueName, + std::variant value) + : Metric(MetricKind::Flexible, 1), valueName(valueName) { + this->exclusive = endWith(valueName, "(exc)"); + this->property = endWith(valueName, "(pty)"); + this->valueName = trim(replace(this->valueName, "(exc)", "")); + this->valueName = trim(replace(this->valueName, "(pty)", "")); + std::visit([&](auto &&v) { this->values[0] = v; }, value); + } + + FlexibleMetric(const std::string &valueName, + std::variant value, bool property, + bool exclusive) + : Metric(MetricKind::Flexible, 1), valueName(valueName), + property(property), exclusive(exclusive) { + std::visit([&](auto &&v) { this->values[0] = v; }, value); + } + + const std::string &getName() const override { return name; } + + const std::string &getValueName(int valueId) const override { + return valueName; + } + + bool isProperty(int valueId) const override { return property; } + + bool isExclusive(int valueId) const override { return exclusive; } + +private: + bool property{}; + bool exclusive{}; + const static inline std::string name = "FlexibleMetric"; + std::string valueName; +}; + +class KernelMetric : public Metric { +public: + enum kernelMetricKind : int { + StartTime, + EndTime, + Invocations, + Duration, + DeviceId, + DeviceType, + StreamId, + Count, + }; + + KernelMetric() : Metric(MetricKind::Kernel, kernelMetricKind::Count) {} + + KernelMetric(uint64_t startTime, uint64_t endTime, uint64_t invocations, + uint64_t deviceId, uint64_t deviceType, uint64_t streamId) + : KernelMetric() { + this->values[StartTime] = startTime; + this->values[EndTime] = endTime; + this->values[Invocations] = invocations; + this->values[Duration] = endTime - startTime; + this->values[DeviceId] = deviceId; + this->values[DeviceType] = deviceType; + this->values[StreamId] = streamId; + } + + const std::string &getName() const override { return name; } + + const std::string &getValueName(int valueId) const override { + return VALUE_NAMES[valueId]; + } + + bool isProperty(int valueId) const override { return PROPERTY[valueId]; } + + bool isExclusive(int valueId) const override { return EXCLUSIVE[valueId]; } + +private: + const static inline bool PROPERTY[kernelMetricKind::Count] = { + true, true, false, false, true, true, true}; + const static inline bool EXCLUSIVE[kernelMetricKind::Count] = { + false, false, false, false, true, true, true}; + const static inline std::string VALUE_NAMES[kernelMetricKind::Count] = { + "start_time (ns)", "end_time (ns)", "count", "time (ns)", + "device_id", "device_type", "stream_id", + }; + const static inline std::string name = "KernelMetric"; +}; + +class PCSamplingMetric : public Metric { +public: + enum PCSamplingMetricKind : int { + NumSamples, + NumStalledSamples, + StalledBranchResolving, + StalledNoInstruction, + StalledShortScoreboard, + StalledWait, + StalledLongScoreboard, + StalledTexThrottle, + StalledBarrier, + StalledMembar, + StalledIMCMiss, + StalledMIOThrottle, + StalledMathPipeThrottle, + StalledDrain, + StalledLGThrottle, + StalledNotSelected, + StalledMisc, + StalledDispatchStall, + StalledSleeping, + StalledSelected, + Count, + }; + + PCSamplingMetric() + : Metric(MetricKind::PCSampling, PCSamplingMetricKind::Count) {} + + PCSamplingMetric(PCSamplingMetricKind kind, uint64_t samples, + uint64_t stalledSamples) + : PCSamplingMetric() { + this->values[kind] = stalledSamples; + this->values[PCSamplingMetricKind::NumSamples] = samples; + this->values[PCSamplingMetricKind::NumStalledSamples] = stalledSamples; + } + + const std::string &getName() const override { return name; } + + const std::string &getValueName(int valueId) const override { + return VALUE_NAMES[valueId]; + } + + bool isProperty(int valueId) const override { return false; } + bool isExclusive(int valueId) const override { return false; } + + const static inline std::string VALUE_NAMES[PCSamplingMetricKind::Count] = { + "num_samples", + "num_stalled_samples", + "stalled_branch_resolving", + "stalled_no_instruction", + "stalled_short_scoreboard", + "stalled_wait", + "stalled_long_scoreboard", + "stalled_tex_throttle", + "stalled_barrier", + "stalled_membar", + "stalled_imc_miss", + "stalled_mio_throttle", + "stalled_math_pipe_throttle", + "stalled_drain", + "stalled_lg_throttle", + "stalled_not_Selected", + "stalled_misc", + "stalled_dispatch_stall", + "stalled_sleeping", + "stalled_selected", + }; + const static inline std::string name = "PCSamplingMetric"; +}; + +class CycleMetric : public Metric { +public: + enum CycleMetricKind : int { + StartCycle, + EndCycle, + Duration, + NormalizedDuration, + KernelId, + KernelName, + BlockId, + ProcessorId, + UnitId, + DeviceId, + DeviceType, + TimeShiftCost, + InitTime, + PreFinalTime, + PostFinalTime, + Count, + }; + + CycleMetric() : Metric(MetricKind::Cycle, CycleMetricKind::Count) {} + + CycleMetric(uint64_t startCycle, uint64_t endCycle, uint64_t duration, + double normalizedDuration, uint64_t kernelId, + const std::string &kernelName, uint64_t blockId, + uint64_t processorId, uint64_t unitId, uint64_t deviceId, + uint64_t deviceType, uint64_t timeShiftCost, uint64_t initTime, + uint64_t preFinalTime, uint64_t postFinalTime) + : CycleMetric() { + this->values[StartCycle] = startCycle; + this->values[EndCycle] = endCycle; + this->values[Duration] = duration; + this->values[NormalizedDuration] = normalizedDuration; + this->values[KernelId] = kernelId; + this->values[KernelName] = kernelName; + this->values[BlockId] = blockId; + this->values[ProcessorId] = processorId; + this->values[UnitId] = unitId; + this->values[DeviceId] = deviceId; + this->values[DeviceType] = deviceType; + this->values[TimeShiftCost] = timeShiftCost; + this->values[InitTime] = initTime; + this->values[PreFinalTime] = preFinalTime; + this->values[PostFinalTime] = postFinalTime; + } + + const std::string &getName() const override { return name; } + + const std::string &getValueName(int valueId) const override { + return VALUE_NAMES[valueId]; + } + + bool isProperty(int valueId) const override { return PROPERTY[valueId]; } + + bool isExclusive(int valueId) const override { return EXCLUSIVE[valueId]; } + +private: + const static inline bool PROPERTY[CycleMetricKind::Count] = { + false, false, false, false, true, true, true, true, + true, true, true, true, false, false, false}; + const static inline bool EXCLUSIVE[CycleMetricKind::Count] = { + false, false, true, true, true, true, true, true, + true, true, true, true, false, false, false}; + const static inline std::string VALUE_NAMES[CycleMetricKind::Count] = { + "start_cycle", "end_cycle", "cycles", "normalized_cycles", + "kernel_id", "kernel_name", "block_id", "processor_id", + "unit_id", "device_id", "device_type", "time_shift_cost", + "init_time", "pre_final_time", "post_final_time"}; + const static inline std::string name = "CycleMetric"; +}; + +/// Each TensorMetric represents a scalar metric stored in a device buffer. +struct TensorMetric { + uint8_t *ptr{}; // device pointer + size_t index{}; // MetricValueType index +}; + +/// Collect tensor metrics from device to host. +std::map +collectTensorMetrics(Runtime *runtime, + const std::map &tensorMetrics, + void *stream); + +/// A MetricBuffer stores tensor metrics generated by GPU kernels. +/// The synchronization behaviors are handled by the runtime of the device. +/// A kernel can be associated with multiple tensor metrics but we do not +/// store the association on the device side. +/// +/// Here's the layout of the buffer and it's meta data that are maintained on +/// the host: +/// +/// host -> -------- kernel0 -------- +/// / \ +/// [device0] -> metric buffer -> {metric_id, value, metric_id, value, ...} +/// | /|\ +/// | | +/// | deviceOffsetPtr ------------- +/// | devicePtr +class MetricBuffer { +public: + struct MetricDescriptor { + size_t id{}; + size_t typeIndex{}; + std::string name{}; + }; + +public: + MetricBuffer(size_t capacity, Runtime *runtime, bool mappedHostBuffer = true) + : capacity(capacity), runtime(runtime), + mappedHostBuffer(mappedHostBuffer) {} + + ~MetricBuffer(); + + void receive(const std::map &scalarMetrics, + const std::map &tensorMetrics, + void *tensorMetricKernel, void *scalarMetricKernel, + void *stream); + + void reserve() { getOrCreateBuffer(); } + + Runtime *getRuntime() const { return runtime; } + + // no sync flush + template void peek(Device *device, Func callback) { + std::lock_guard lock(bufferMutex); + auto it = deviceBuffers.find(device); + if (it != deviceBuffers.end()) { + auto &buffer = it->second; + callback(buffer.hostPtr); + } + } + + template void flush(Func callback, bool flushAll = false) { + std::vector> buffersToFlush; + if (flushAll) { + std::lock_guard lock(bufferMutex); + for (auto &[device, buffer] : deviceBuffers) { + buffersToFlush.emplace_back(device, buffer); + } + } else { + buffersToFlush.emplace_back(runtime->getDevice(), getOrCreateBuffer()); + } + for (auto &[device, buffer] : buffersToFlush) { + synchronize(buffer); + callback(device, buffer.hostPtr); + } + } + + size_t getCapacity() const { return capacity; } + + MetricDescriptor &getMetricDescriptor(size_t id) { + std::shared_lock lock(metricDescriptorMutex); + auto it = metricDescriptors.find(id); + if (it == metricDescriptors.end()) { + throw std::runtime_error("[PROTON] MetricBuffer: unknown metric id: " + + std::to_string(id)); + } + return it->second; + } + +private: + struct DeviceBuffer { + uint8_t *devicePtr{}; + uint8_t *deviceOffsetPtr{}; + uint8_t *hostPtr{}; + uint64_t *hostOffset{}; + void *priorityStream{}; + }; + + DeviceBuffer &getOrCreateBuffer(); + + void queue(size_t metricId, TensorMetric tensorMetric, void *kernel, + void *stream); + + void queue(size_t metricId, MetricValueType scalarMetric, void *kernel, + void *stream); + + void synchronize(DeviceBuffer &buffer); + + template + size_t getMetricIndex(const MetricT &metric) const { + using MetricType = std::decay_t; + if constexpr (std::is_same_v) { + return metric.index(); + } else if constexpr (std::is_same_v) { + return metric.index; + } else { + static_assert(always_false::value, + "Unsupported metric type for getMetricIndex"); + } + } + + template + void queueMetrics(const MetricsT &metrics, void *kernel, void *stream) { + for (const auto &[name, metric] : metrics) { + size_t index = getMetricIndex(metric); + auto descriptor = getOrCreateMetricDescriptor(name, index); + queue(descriptor.id, metric, kernel, stream); + } + } + + MetricDescriptor getOrCreateMetricDescriptor(const std::string &name, + size_t typeIndex); + +protected: + static std::atomic metricId; + static std::map metricDescriptors; + static std::map metricNameToId; + static std::shared_mutex metricDescriptorMutex; + + size_t capacity; // byte + Runtime *runtime{}; + const bool mappedHostBuffer{true}; + + std::map deviceBuffers; + std::mutex bufferMutex; +}; + +class MetricInterface { +public: + virtual ~MetricInterface() = default; + + virtual void + addMetrics(size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics) = 0; + + virtual void setMetricKernels(void *tensorMetricKernel, + void *scalarMetricKernel, void *stream) = 0; +}; + +} // namespace proton + +#endif // PROTON_DATA_METRIC_H_ diff --git a/third_party/mthreads/proton/csrc/include/Data/PhaseStore.h b/third_party/mthreads/proton/csrc/include/Data/PhaseStore.h new file mode 100644 index 0000000000..d79b3801fc --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Data/PhaseStore.h @@ -0,0 +1,111 @@ +#ifndef PROTON_DATA_PHASE_STORE_H_ +#define PROTON_DATA_PHASE_STORE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +class PhaseStoreBase { +public: + virtual ~PhaseStoreBase() = default; + + virtual void *getPtr(size_t phase) = 0; + virtual void *createPtr(size_t phase) = 0; + virtual void clearUpToInclusive(size_t phase) = 0; + virtual void clearPhase(size_t phase) = 0; +}; + +template class PhaseStore final : public PhaseStoreBase { +public: + PhaseStore() = default; + ~PhaseStore() override = default; + + struct Slot { + mutable std::shared_mutex mutex; + std::unique_ptr value; + }; + + void *createPtr(size_t phase) override { + std::shared_ptr slot; + { + std::unique_lock lock(phasesMutex); + auto &entry = phases[phase]; + if (!entry) + entry = std::make_shared(); + slot = entry; + } + { + std::unique_lock slotLock(slot->mutex); + if (!slot->value) // slot value might not exist yet or been cleared + slot->value = std::make_unique(); + return slot->value.get(); + } + } + + void *getPtr(size_t phase) override { return getSlot(phase)->value.get(); } + + void clearUpToInclusive(size_t phase) override { + clearRangeInclusive(0, phase); + } + + void clearPhase(size_t phase) override { clearRangeInclusive(phase, phase); } + + template decltype(auto) withPtr(size_t phase, FnT &&fn) const { + auto slot = getSlot(phase); + std::shared_lock slotLock(slot->mutex); + return std::forward(fn)(slot->value.get()); + } + +private: + void clearRangeInclusive(size_t beginPhase, size_t endPhase) { + std::vector> slotsToClear; + { + std::shared_lock lock(phasesMutex); + auto it = phases.lower_bound(beginPhase); + auto endIt = phases.upper_bound(endPhase); + for (; it != endIt; ++it) { + if (it->second) { + slotsToClear.push_back(it->second); + } + } + } + + // Free the heavy per-phase payloads under per-phase locks, without blocking + // unrelated phases from being accessed via the store map. + for (auto &slot : slotsToClear) { + std::unique_lock slotLock(slot->mutex); + slot->value.reset(); + } + + // Finally, prune the cleared phases from the map. + { + std::unique_lock lock(phasesMutex); + phases.erase(phases.lower_bound(beginPhase), + phases.upper_bound(endPhase)); + } + } + + std::shared_ptr getSlot(size_t phase) const { + std::shared_lock lock(phasesMutex); + auto it = phases.find(phase); + if (it == phases.end() || !it->second) { + throw std::runtime_error("[PROTON] Phase " + std::to_string(phase) + + " has no data."); + } + return it->second; + } + + mutable std::shared_mutex phasesMutex; + std::map> phases; +}; + +} // namespace proton + +#endif // PROTON_DATA_PHASE_STORE_H_ diff --git a/third_party/mthreads/proton/csrc/include/Data/TraceData.h b/third_party/mthreads/proton/csrc/include/Data/TraceData.h new file mode 100644 index 0000000000..6877ddd2d2 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Data/TraceData.h @@ -0,0 +1,58 @@ +#ifndef PROTON_DATA_TRACE_DATA_H_ +#define PROTON_DATA_TRACE_DATA_H_ + +#include "Data.h" +#include +#include + +namespace proton { + +class TraceData : public Data { +public: + TraceData(const std::string &path, ContextSource *contextSource = nullptr); + virtual ~TraceData(); + + std::string toJsonString(size_t phase) const override; + + std::vector toMsgPack(size_t phase) const override; + + DataEntry addOp(const std::string &name) override; + + DataEntry addOp(size_t phase, size_t eventId, + const std::vector &contexts) override; + + void + addMetrics(size_t scopeId, + const std::map &metrics) override; + + void + addMetrics(size_t phase, size_t entryId, + const std::map &metrics) override; + + class Trace; + +protected: + // ScopeInterface + void enterScope(const Scope &scope) override final; + + void exitScope(const Scope &scope) override final; + +private: + // Data + void doDump(std::ostream &os, OutputFormat outputFormat, + size_t phase) const override; + + OutputFormat getDefaultOutputFormat() const override { + return OutputFormat::ChromeTrace; + } + + void dumpChromeTrace(std::ostream &os, size_t phase) const; + + PhaseStore tracePhases; + // ScopeId -> EventId + std::unordered_map scopeIdToEventId; +}; + +} // namespace proton + +#endif // PROTON_DATA_TRACE_DATA_H_ diff --git a/third_party/mthreads/proton/csrc/include/Data/TreeData.h b/third_party/mthreads/proton/csrc/include/Data/TreeData.h new file mode 100644 index 0000000000..1b1745f1bd --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Data/TreeData.h @@ -0,0 +1,72 @@ +#ifndef PROTON_DATA_TREE_DATA_H_ +#define PROTON_DATA_TREE_DATA_H_ + +#include "Context/Context.h" +#include "Data.h" +#include "nlohmann/json.hpp" +#include +#include +#include +#include + +using json = nlohmann::json; + +namespace proton { + +class TreeData : public Data { +public: + TreeData(const std::string &path, ContextSource *contextSource); + virtual ~TreeData(); + + TreeData(const std::string &path) : TreeData(path, nullptr) {} + + std::string toJsonString(size_t phase) const override; + + std::vector toMsgPack(size_t phase) const override; + + DataEntry addOp(const std::string &name) override; + + DataEntry addOp(size_t phase, size_t contextId, + const std::vector &contexts) override; + + void + addMetrics(size_t scopeId, + const std::map &metrics) override; + + void + addMetrics(size_t phase, size_t entryId, + const std::map &metrics) override; + +protected: + // ScopeInterface + void enterScope(const Scope &scope) override; + + void exitScope(const Scope &scope) override; + +private: + // `tree` and `scopeIdToContextId` can be accessed by both the user thread and + // the background threads concurrently, so methods that access them should be + // protected by a (shared) mutex. + class Tree; + json buildHatchetJson(TreeData::Tree *tree) const; + std::vector buildHatchetMsgPack(TreeData::Tree *tree) const; + + // Data + void doDump(std::ostream &os, OutputFormat outputFormat, + size_t phase) const override; + + OutputFormat getDefaultOutputFormat() const override { + return OutputFormat::Hatchet; + } + + void dumpHatchet(std::ostream &os, size_t phase) const; + void dumpHatchetMsgPack(std::ostream &os, size_t phase) const; + + PhaseStore treePhases; + // ScopeId -> ContextId + std::unordered_map scopeIdToContextId; +}; + +} // namespace proton + +#endif // PROTON_DATA_TREE_DATA_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/Dispatch.h b/third_party/mthreads/proton/csrc/include/Driver/Dispatch.h new file mode 100644 index 0000000000..920151302d --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/Dispatch.h @@ -0,0 +1,168 @@ +#ifndef PROTON_DRIVER_DISPATCH_H_ +#define PROTON_DRIVER_DISPATCH_H_ + +#include + +#include "Utility/Env.h" +#include +#include + +#define DISPATCH_ARGS_0() +#define DISPATCH_ARGS_1(t1) t1 v1 +#define DISPATCH_ARGS_2(t1, t2) t1 v1, t2 v2 +#define DISPATCH_ARGS_3(t1, t2, t3) t1 v1, t2 v2, t3 v3 +#define DISPATCH_ARGS_4(t1, t2, t3, t4) t1 v1, t2 v2, t3 v3, t4 v4 +#define DISPATCH_ARGS_5(t1, t2, t3, t4, t5) t1 v1, t2 v2, t3 v3, t4 v4, t5 v5 +#define DISPATCH_ARGS_6(t1, t2, t3, t4, t5, t6) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6 +#define DISPATCH_ARGS_7(t1, t2, t3, t4, t5, t6, t7) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6, t7 v7 +#define DISPATCH_ARGS_8(t1, t2, t3, t4, t5, t6, t7, t8) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6, t7 v7, t8 v8 +#define DISPATCH_ARGS_9(t1, t2, t3, t4, t5, t6, t7, t8, t9) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6, t7 v7, t8 v8, t9 v9 +#define DISPATCH_ARGS_10(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6, t7 v7, t8 v8, t9 v9, t10 v10 +#define DISPATCH_ARGS_11(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6, t7 v7, t8 v8, t9 v9, t10 v10, \ + t11 v11 +#define DISPATCH_ARGS_12(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12) \ + t1 v1, t2 v2, t3 v3, t4 v4, t5 v5, t6 v6, t7 v7, t8 v8, t9 v9, t10 v10, \ + t11 v11, t12 v12 +#define DISPATCH_ARGS_N(_12, _11, _10, _9, _8, _7, _6, _5, _4, _3, _2, _1, _0, \ + N, ...) \ + DISPATCH_ARGS##N +#define DISPATCH_ARGS(...) \ + DISPATCH_ARGS_N(_0, ##__VA_ARGS__, _12, _11, _10, _9, _8, _7, _6, _5, _4, \ + _3, _2, _1, _0)(__VA_ARGS__) + +#define DISPATCH_VALS_0() +#define DISPATCH_VALS_1(t1) , v1 +#define DISPATCH_VALS_2(t1, t2) , v1, v2 +#define DISPATCH_VALS_3(t1, t2, t3) , v1, v2, v3 +#define DISPATCH_VALS_4(t1, t2, t3, t4) , v1, v2, v3, v4 +#define DISPATCH_VALS_5(t1, t2, t3, t4, t5) , v1, v2, v3, v4, v5 +#define DISPATCH_VALS_6(t1, t2, t3, t4, t5, t6) , v1, v2, v3, v4, v5, v6 +#define DISPATCH_VALS_7(t1, t2, t3, t4, t5, t6, t7) , v1, v2, v3, v4, v5, v6, v7 +#define DISPATCH_VALS_8(t1, t2, t3, t4, t5, t6, t7, t8) \ + , v1, v2, v3, v4, v5, v6, v7, v8 +#define DISPATCH_VALS_9(t1, t2, t3, t4, t5, t6, t7, t8, t9) \ + , v1, v2, v3, v4, v5, v6, v7, v8, v9 +#define DISPATCH_VALS_10(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) \ + , v1, v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define DISPATCH_VALS_11(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \ + , v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define DISPATCH_VALS_12(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12) \ + , v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define DISPATCH_VALS_N(_12, _11, _10, _9, _8, _7, _6, _5, _4, _3, _2, _1, _0, \ + N, ...) \ + DISPATCH_VALS##N +#define DISPATCH_VALS(...) \ + DISPATCH_VALS_N(_0, ##__VA_ARGS__, _12, _11, _10, _9, _8, _7, _6, _5, _4, \ + _3, _2, _1, _0)(__VA_ARGS__) + +#define DEFINE_DISPATCH_TEMPLATE(CheckSuccess, FuncName, ExternLib, FuncType, \ + ...) \ + template <> \ + ExternLib::RetType FuncName(DISPATCH_ARGS(__VA_ARGS__)) { \ + typedef typename ExternLib::RetType (*FuncType##_t)(__VA_ARGS__); \ + static FuncType##_t func = nullptr; \ + return Dispatch::exec( \ + func, #FuncType DISPATCH_VALS(__VA_ARGS__)); \ + } + +#define DEFINE_DISPATCH(ExternLib, FuncName, FuncType, ...) \ + DEFINE_DISPATCH_TEMPLATE(true, FuncName, ExternLib, FuncType, __VA_ARGS__) \ + DEFINE_DISPATCH_TEMPLATE(false, FuncName, ExternLib, FuncType, __VA_ARGS__) + +namespace proton { + +struct ExternLibBase { + using RetType = int; // Generic type, can be overridden in derived structs + static constexpr const char *name = ""; // Placeholder + static constexpr const char *symbolName{}; // Placeholder + static constexpr const char *pathEnv{}; // Placeholder + static constexpr RetType success = 0; // Placeholder + ExternLibBase() = delete; + ExternLibBase(const ExternLibBase &) = delete; + ExternLibBase &operator=(const ExternLibBase &) = delete; + static inline void *lib{nullptr}; +}; + +template class Dispatch { +public: + Dispatch() = delete; + + static void init(const char *name, void **lib) { + if (*lib == nullptr) { + // If not found, try to load it from the default path + auto dir = + ExternLib::pathEnv == nullptr ? "" : getStrEnv(ExternLib::pathEnv); + if (!dir.empty()) { + auto fullPath = dir + "/" + name; + *lib = dlopen(fullPath.c_str(), RTLD_LOCAL | RTLD_LAZY); + } else { + // Only if the default path is not set, we try to load it from the + // system. + // First reuse the existing handle + *lib = dlopen(name, RTLD_NOLOAD); + if (*lib == nullptr) { + // If not found, try to load it from LD_LIBRARY_PATH + *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY); + } + } + } + if (*lib == nullptr) { + throw std::runtime_error("Could not load `" + std::string(name) + "`"); + } + } + + static void check(typename ExternLib::RetType ret, const char *functionName) { + if (ret != ExternLib::success) { + throw std::runtime_error("Failed to execute " + + std::string(functionName) + " with error " + + std::to_string(ret)); + } + } + + template + static typename ExternLib::RetType + exec(FnT &handler, const char *functionName, Args... args) { + init(ExternLib::name, &ExternLib::lib); + if (handler == nullptr) { + handler = reinterpret_cast(dlsym(ExternLib::lib, functionName)); + if (handler == nullptr) { + throw std::runtime_error("Failed to load " + + std::string(ExternLib::name)); + } + } + auto ret = handler(args...); + if constexpr (CheckSuccess) { + check(ret, functionName); + } + return ret; + } + + static std::string getLibPath() { + if (ExternLib::lib == nullptr) { + // Force initialization + Dispatch::init(ExternLib::name, &ExternLib::lib); + if (ExternLib::lib == nullptr) { + return ""; + } + } + if (ExternLib::lib != nullptr) { + void *sym = dlsym(ExternLib::lib, + ExternLib::symbolName); // pick any known symbol + Dl_info info; + if (dladdr(sym, &info)) { + return info.dli_fname; + } + } + return ""; + } +}; + +} // namespace proton + +#endif // PROTON_DRIVER_DISPATCH_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/GPU/CudaApi.h b/third_party/mthreads/proton/csrc/include/Driver/GPU/CudaApi.h new file mode 100644 index 0000000000..0e778300f1 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/GPU/CudaApi.h @@ -0,0 +1,72 @@ +#ifndef PROTON_DRIVER_GPU_CUDA_API_H_ +#define PROTON_DRIVER_GPU_CUDA_API_H_ + +#include "Device.h" +#include "cuda.h" + +namespace proton { + +namespace cuda { + +template CUresult init(int flags); + +template CUresult ctxSynchronize(); + +template CUresult ctxGetCurrent(CUcontext *pctx); + +template CUresult ctxGetDevice(CUdevice *device); + +template +CUresult ctxGetStreamPriorityRange(int *leastPriority, int *greatestPriority); + +template +CUresult deviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); + +template CUresult deviceGet(CUdevice *device, int ordinal); + +template +CUresult streamCreateWithPriority(CUstream *pStream, unsigned int flags, + int priority); + +template CUresult streamSynchronize(CUstream stream); + +template CUresult streamDestroy(CUstream stream); + +template +CUresult memcpyDToHAsync(void *dst, CUdeviceptr src, size_t count, + CUstream stream); + +template +CUresult memsetD32Async(CUdeviceptr dst, unsigned int ui, size_t N, + CUstream stream); + +template +CUresult memAlloc(CUdeviceptr *dptr, size_t bytesize); + +template CUresult memFree(CUdeviceptr dptr); + +template CUresult memAllocHost(void **pp, size_t bytesize); + +template +CUresult memHostAlloc(void **pp, size_t bytesize, unsigned int flags); + +template +CUresult memHostGetDevicePointer(CUdeviceptr *pdptr, void *p, + unsigned int flags); + +template CUresult memFreeHost(void *p); + +template +CUresult launchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void **kernelParams, void **extra); + +Device getDevice(uint64_t index); + +} // namespace cuda + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_CUDA_API_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/GPU/CuptiApi.h b/third_party/mthreads/proton/csrc/include/Driver/GPU/CuptiApi.h new file mode 100644 index 0000000000..5e62ae7028 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/GPU/CuptiApi.h @@ -0,0 +1,129 @@ +#ifndef PROTON_DRIVER_GPU_CUPTI_API_H_ +#define PROTON_DRIVER_GPU_CUPTI_API_H_ + +#include "Driver/Dispatch.h" +#include "cupti.h" +#include "cupti_pcsampling.h" +#include + +namespace proton { + +namespace cupti { + +struct ExternLibCupti : public ExternLibBase { + using RetType = CUptiResult; + static constexpr const char *name = "libcupti.so"; + static constexpr const char *symbolName = "cuptiUnsubscribe"; + static constexpr const char *pathEnv = "TRITON_CUPTI_LIB_PATH"; + static constexpr RetType success = CUPTI_SUCCESS; + static inline void *lib = nullptr; +}; + +template CUptiResult getVersion(uint32_t *version); + +template +CUptiResult getContextId(CUcontext context, uint32_t *pCtxId); + +template +CUptiResult activityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc funcBufferRequested, + CUpti_BuffersCallbackCompleteFunc funcBufferCompleted); + +template +CUptiResult subscribe(CUpti_SubscriberHandle *subscriber, + CUpti_CallbackFunc callback, void *userdata); + +template +CUptiResult enableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain); + +template +CUptiResult enableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, CUpti_CallbackId cbid); + +template +CUptiResult activityEnableContext(CUcontext context, CUpti_ActivityKind kind); + +template +CUptiResult activityDisableContext(CUcontext context, CUpti_ActivityKind kind); + +template +CUptiResult activityEnable(CUpti_ActivityKind kind); + +template +CUptiResult activityDisable(CUpti_ActivityKind kind); + +template CUptiResult activityFlushAll(uint32_t flag); + +template +CUptiResult activityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, + CUpti_Activity **record); + +template +CUptiResult +activityPushExternalCorrelationId(CUpti_ExternalCorrelationKind kind, + uint64_t id); + +template +CUptiResult activityPopExternalCorrelationId(CUpti_ExternalCorrelationKind kind, + uint64_t *lastId); + +template +CUptiResult activitySetAttribute(CUpti_ActivityAttribute attr, + size_t *valueSize, void *value); + +template CUptiResult activityEnableHWTrace(uint8_t enable); + +template +CUptiResult unsubscribe(CUpti_SubscriberHandle subscriber); + +template CUptiResult finalize(); + +template +CUptiResult getGraphExecId(CUgraphExec graph, uint32_t *pId); + +template +CUptiResult getGraphId(CUgraph graph, uint32_t *pId); + +template +CUptiResult getGraphNodeId(CUgraphNode node, uint64_t *pId); + +template +CUptiResult getCubinCrc(CUpti_GetCubinCrcParams *pParams); + +template +CUptiResult +getSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams); + +template +CUptiResult +pcSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams); + +template +CUptiResult +pcSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams); + +template +CUptiResult pcSamplingSetConfigurationAttribute( + CUpti_PCSamplingConfigurationInfoParams *pParams); + +template +CUptiResult pcSamplingEnable(CUpti_PCSamplingEnableParams *pParams); + +template +CUptiResult pcSamplingDisable(CUpti_PCSamplingDisableParams *pParams); + +template +CUptiResult pcSamplingGetData(CUpti_PCSamplingGetDataParams *pParams); + +template +CUptiResult pcSamplingStart(CUpti_PCSamplingStartParams *pParams); + +template +CUptiResult pcSamplingStop(CUpti_PCSamplingStopParams *pParams); + +} // namespace cupti + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_CUPTI_API_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/GPU/HipApi.h b/third_party/mthreads/proton/csrc/include/Driver/GPU/HipApi.h new file mode 100644 index 0000000000..12af24a30e --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/GPU/HipApi.h @@ -0,0 +1,79 @@ +#ifndef PROTON_DRIVER_GPU_HIP_API_H_ +#define PROTON_DRIVER_GPU_HIP_API_H_ + +#include "Device.h" +#include "hip/hip_runtime_api.h" + +namespace proton { + +namespace hip { + +template +hipError_t launchKernel(hipFunction_t f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, + hipStream_t stream, void **kernelParams, void **extra); + +template hipError_t ctxGetDevice(hipDevice_t *device); + +template hipError_t deviceSynchronize(); + +template +hipError_t deviceGetAttribute(int *value, hipDeviceAttribute_t attribute, + int deviceId); + +template hipError_t getDeviceCount(int *count); + +template +hipError_t getDeviceProperties(hipDeviceProp_t *prop, int deviceId); + +Device getDevice(uint64_t index); + +template +hipError_t ctxGetStreamPriorityRange(int *leastPriority, int *greatestPriority); + +template +hipError_t streamCreateWithPriority(hipStream_t *pStream, unsigned int flags, + int priority); + +template hipError_t streamSynchronize(hipStream_t stream); + +template hipError_t streamDestroy(hipStream_t stream); + +template +hipError_t memcpyDToHAsync(void *dst, hipDeviceptr_t src, size_t count, + hipStream_t stream); + +template +hipError_t memsetD32Async(hipDeviceptr_t dst, int value, size_t count, + hipStream_t stream); + +template +hipError_t memAlloc(hipDeviceptr_t *dptr, size_t bytesize); + +template hipError_t memFree(hipDeviceptr_t dptr); + +const std::string getHipArchName(uint64_t index); + +const char *getKernelNameRef(const hipFunction_t f); + +const char *getKernelNameRefByPtr(const void *hostFunction, hipStream_t stream); + +template +hipError_t memAllocHost(void **pp, size_t bytesize); + +template +hipError_t memHostAlloc(void **pp, size_t bytesize, unsigned int flags); + +template +hipError_t memHostGetDevicePointer(hipDeviceptr_t *pdptr, void *p, + unsigned int flags); + +template hipError_t memFreeHost(void *p); + +} // namespace hip + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_HIP_API_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/GPU/HsaApi.h b/third_party/mthreads/proton/csrc/include/Driver/GPU/HsaApi.h new file mode 100644 index 0000000000..f7f4abc26a --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/GPU/HsaApi.h @@ -0,0 +1,23 @@ +#ifndef PROTON_DRIVER_GPU_HSA_API_H_ +#define PROTON_DRIVER_GPU_HSA_API_H_ + +#include "Device.h" +#include "hsa/hsa_ext_amd.h" + +namespace proton { + +namespace hsa { + +template +hsa_status_t agentGetInfo(hsa_agent_t agent, hsa_agent_info_t attribute, + void *value); + +hsa_status_t iterateAgents(hsa_status_t (*callback)(hsa_agent_t agent, + void *data), + void *data); + +} // namespace hsa + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_HSA_API_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/GPU/NvtxApi.h b/third_party/mthreads/proton/csrc/include/Driver/GPU/NvtxApi.h new file mode 100644 index 0000000000..68ace94b3a --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/GPU/NvtxApi.h @@ -0,0 +1,20 @@ +#ifndef PROTON_DRIVER_GPU_NVTX_API_H_ +#define PROTON_DRIVER_GPU_NVTX_API_H_ + +#include + +namespace proton { + +namespace nvtx { + +void enable(); + +void disable(); + +std::string getMessageFromRangePushA(const void *params); + +} // namespace nvtx + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_NVTX_API_H_ diff --git a/third_party/mthreads/proton/csrc/include/Driver/GPU/RoctracerApi.h b/third_party/mthreads/proton/csrc/include/Driver/GPU/RoctracerApi.h new file mode 100644 index 0000000000..5b881fee46 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Driver/GPU/RoctracerApi.h @@ -0,0 +1,95 @@ +#ifndef PROTON_DRIVER_GPU_ROCTRACER_API_H_ +#define PROTON_DRIVER_GPU_ROCTRACER_API_H_ + +#include "Driver/Dispatch.h" +#include "roctracer/roctracer.h" + +namespace proton { + +namespace roctracer { + +struct ExternLibRoctracer : public ExternLibBase { + using RetType = roctracer_status_t; + static constexpr const char *name = "libroctracer64.so"; + static constexpr const char *symbolName = "roctracer_start"; + static constexpr const char *pathEnv{}; + static constexpr RetType success = ROCTRACER_STATUS_SUCCESS; + static inline void *lib = nullptr; +}; + +template +roctracer_status_t setProperties(roctracer_domain_t domain, void *properties); + +template +roctracer_status_t getTimestamp(roctracer_timestamp_t *timestamp); + +void start(); + +void stop(); + +// +// Callbacks +// + +template +roctracer_status_t enableDomainCallback(activity_domain_t domain, + activity_rtapi_callback_t callback, + void *arg); + +template +roctracer_status_t disableDomainCallback(activity_domain_t domain); + +template +roctracer_status_t enableOpCallback(activity_domain_t domain, uint32_t op, + activity_rtapi_callback_t callback, + void *arg); + +template +roctracer_status_t disableOpCallback(activity_domain_t domain, uint32_t op); + +// +// Activity +// + +template +roctracer_status_t openPool(const roctracer_properties_t *properties); + +template roctracer_status_t closePool(); + +template +roctracer_status_t enableOpActivity(activity_domain_t domain, uint32_t op); + +template +roctracer_status_t enableDomainActivity(activity_domain_t domain); + +template +roctracer_status_t disableOpActivity(activity_domain_t domain, uint32_t op); + +template +roctracer_status_t disableDomainActivity(activity_domain_t domain); + +template roctracer_status_t flushActivity(); + +template +roctracer_status_t getNextRecord(const activity_record_t *record, + const activity_record_t **next); + +char *getOpString(uint32_t domain, uint32_t op, uint32_t kind); + +// +// External correlation +// + +template +roctracer_status_t +activityPushExternalCorrelationId(activity_correlation_id_t id); + +template +roctracer_status_t +activityPopExternalCorrelationId(activity_correlation_id_t *last_id); + +} // namespace roctracer + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_ROCTRACER_API_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h b/third_party/mthreads/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h new file mode 100644 index 0000000000..bc9cb26733 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h @@ -0,0 +1,142 @@ +#ifndef PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ +#define PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ + +#include "CuptiProfiler.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Map.h" +#include "Utility/Set.h" +#include "Utility/Singleton.h" +#include +#include + +namespace proton { + +struct CubinData { + size_t cubinCrc; + const char *cubin; + size_t cubinSize; + + struct LineInfoKey { + uint32_t functionIndex; + uint64_t pcOffset; + + bool operator<(const LineInfoKey &other) const { + return functionIndex < other.functionIndex || + (functionIndex == other.functionIndex && + pcOffset < other.pcOffset); + } + }; + + struct LineInfoValue { + uint32_t lineNumber{}; + const std::string functionName{}; + const std::string dirName{}; + const std::string fileName{}; + + LineInfoValue() = default; + + LineInfoValue(uint32_t lineNumber, const std::string &functionName, + const std::string &dirName, const std::string &fileName) + : lineNumber(lineNumber), functionName(functionName), dirName(dirName), + fileName(fileName) {} + }; + + std::map lineInfo; +}; + +struct ConfigureData { + ConfigureData() = default; + + ~ConfigureData() { + if (stallReasonNames) { + for (size_t i = 0; i < numStallReasons; i++) { + if (stallReasonNames[i]) + std::free(stallReasonNames[i]); + } + std::free(stallReasonNames); + } + if (stallReasonIndices) + std::free(stallReasonIndices); + if (pcSamplingData.pPcData) { + for (size_t i = 0; i < numValidStallReasons; ++i) { + std::free(pcSamplingData.pPcData[i].stallReason); + } + std::free(pcSamplingData.pPcData); + } + } + + void initialize(CUcontext context); + + CUpti_PCSamplingConfigurationInfo configureStallReasons(); + CUpti_PCSamplingConfigurationInfo configureSamplingPeriod(); + CUpti_PCSamplingConfigurationInfo configureSamplingBuffer(); + CUpti_PCSamplingConfigurationInfo configureScratchBuffer(); + CUpti_PCSamplingConfigurationInfo configureHardwareBufferSize(); + CUpti_PCSamplingConfigurationInfo configureStartStopControl(); + CUpti_PCSamplingConfigurationInfo configureCollectionMode(); + + // The amount of data reserved on the GPU + static constexpr size_t HardwareBufferSize = 128 * 1024 * 1024; + // The amount of data copied from the hardware buffer each time + static constexpr size_t ScratchBufferSize = 16 * 1024 * 1024; + // The number of PCs copied from the scratch buffer each time + static constexpr size_t DataBufferPCCount = 1024; + // The sampling period in cycles = 2^frequency + static constexpr uint32_t DefaultFrequency = 10; + + CUcontext context{}; + uint32_t contextId; + uint32_t numStallReasons{}; + uint32_t numValidStallReasons{}; + char **stallReasonNames{}; + uint32_t *stallReasonIndices{}; + std::map stallReasonIndexToMetricIndex{}; + std::set notIssuedStallReasonIndices{}; + CUpti_PCSamplingData pcSamplingData{}; + // The memory storing configuration information has to be kept alive during + // the profiling session + std::vector configurationInfos; +}; + +class CuptiPCSampling : public Singleton { + +public: + CuptiPCSampling() = default; + virtual ~CuptiPCSampling() = default; + + void initialize(CUcontext context); + + void start(CUcontext context); + + void stop(CUcontext context, const DataToEntryMap &dataToEntry); + + void finalize(CUcontext context); + + void loadModule(const char *cubin, size_t cubinSize); + + void unloadModule(const char *cubin, size_t cubinSize); + +private: + ConfigureData *getConfigureData(uint32_t contextId); + + CubinData *getCubinData(uint64_t cubinCrc); + + void processPCSamplingData(ConfigureData *configureData, + const DataToEntryMap &dataToEntry); + + ThreadSafeMap contextIdToConfigureData; + // In case the same cubin is loaded multiple times, we need to keep track of + // all of them + ThreadSafeMap> + cubinCrcToCubinData; + ThreadSafeSet contextInitialized; + + std::atomic pcSamplingStarted{false}; + std::mutex pcSamplingMutex{}; + std::mutex contextMutex{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h b/third_party/mthreads/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h new file mode 100644 index 0000000000..ed471d25c8 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h @@ -0,0 +1,22 @@ +#ifndef PROTON_PROFILER_CUPTI_PROFILER_H_ +#define PROTON_PROFILER_CUPTI_PROFILER_H_ + +#include "Profiler/GPUProfiler.h" + +namespace proton { + +class CuptiProfiler : public GPUProfiler { +public: + CuptiProfiler(); + virtual ~CuptiProfiler(); + +private: + struct CuptiProfilerPimpl; + + virtual void + doSetMode(const std::vector &modeAndOptions) override; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_CUPTI_PROFILER_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/mthreads/proton/csrc/include/Profiler/GPUProfiler.h new file mode 100644 index 0000000000..1ec86ee1ef --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/GPUProfiler.h @@ -0,0 +1,292 @@ +#ifndef PROTON_PROFILER_GPU_PROFILER_H_ +#define PROTON_PROFILER_GPU_PROFILER_H_ + +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Profiler.h" +#include "Profiler/Graph.h" +#include "Session/Session.h" +#include "Utility/Atomic.h" +#include "Utility/Env.h" +#include "Utility/Map.h" +#include "Utility/Table.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +namespace detail { + +void flushDataPhasesImpl( + const bool periodicFlushEnabled, const std::string &periodicFlushingFormat, + std::map &dataFlushedPhases, + const std::map> + &dataPhases, + PendingGraphPool *pendingGraphPool); + +void updateDataPhases( + std::map> + &dataPhases, + Data *data, size_t phase); + +void setPeriodicFlushingMode(bool &periodicFlushingEnabled, + std::string &periodicFlushingFormat, + const std::vector &modeAndOptions, + const char *profilerName); +} // namespace detail + +// Singleton: Each concrete GPU profiler, e.g., +// CuptiProfiler, should be a singleton. +template +class GPUProfiler : public Profiler, + public OpInterface, + public Singleton { +public: + GPUProfiler() = default; + virtual ~GPUProfiler() = default; + + using CorrIdToExternIdMap = + ThreadSafeMap>; + + struct ExternIdState { + // ----non-graph launch fields---- + DataToEntryMap dataToEntry; + // Sometimes the kernel name cannot be retrieved in application threads + // for reasons like uninitialize CUDA context. + bool isMissingName{true}; + // ----graph launch fields---- + // For graph launches, the launch correlation id fans out into multiple + // kernel activity records. We track the expected fanout here and keep + // updating it when we have processed each kernel activity record. + size_t numNodes{1}; + + struct GraphNodeState { + // If the node is launched as a metric kernel, ignore it's timing data. + bool isMetricNode{false}; + bool isMissingName{true}; + + void setEntry(Data *data, const DataEntry &entry) { + dataToEntry.insert_or_assign(data, entry); + } + + const DataEntry *findEntry(Data *data) const { + auto it = dataToEntry.find(data); + if (it == dataToEntry.end()) + return nullptr; + return &it->second; + } + + template void forEachEntry(FnT &&fn) { + for (auto &[data, entry] : dataToEntry) + fn(data, entry); + } + + DataToEntryMap dataToEntry; + }; + + using GraphNodeStateTable = RangeTable; + + // graphNodeId -> (per-Data entry) + GraphNodeStateTable graphNodeIdToState; + }; + + using ExternIdToStateMap = + ThreadSafeMap>; + +protected: + // OpInterface + void startOp(const Scope &scope) override { + this->threadState.scopeStack.push_back(scope); + for (auto *data : dataSet) { + auto entry = data->addOp(scope.name); + threadState.dataToEntry.insert_or_assign(data, entry); + } + } + + void stopOp(const Scope &scope) override { + this->threadState.scopeStack.pop_back(); + threadState.dataToEntry.clear(); + } + + void flushDataPhases( + std::map &dataFlushedPhases, + const std::map> + &dataPhases, + PendingGraphPool *pendingGraphPool) { + detail::flushDataPhasesImpl(periodicFlushingEnabled, periodicFlushingFormat, + dataFlushedPhases, dataPhases, + pendingGraphPool); + } + + // Profiler + virtual void doStart() override { pImpl->doStart(); } + virtual void doFlush() override { pImpl->doFlush(); } + virtual void doStop() override { pImpl->doStop(); } + virtual void doAddMetrics( + size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics) override { + pImpl->doAddMetrics(scopeId, scalarMetrics, tensorMetrics); + } + + struct ThreadState { + ConcreteProfilerT &profiler; + SessionManager &sessionManager = SessionManager::instance(); + std::vector scopeStack; // Used for nvtx range or triton op tracking + DataToEntryMap dataToEntry; + bool isApiExternOp{false}; + bool isStreamCapturing{false}; + bool isMetricKernelLaunching{false}; + + ThreadState(ConcreteProfilerT &profiler) : profiler(profiler) {} + + void enterOp(const Scope &scope) { + if (profiler.isOpInProgress()) // Already in a triton op + return; + // Enter a new GPU API op + isApiExternOp = true; + profiler.enterOp(scope); + } + + void exitOp() { + if (!profiler.isOpInProgress() || !isApiExternOp) + return; + profiler.exitOp(scopeStack.back()); + isApiExternOp = false; + } + + void enterScope(const std::string &name) { + Scope scope(name); + scopeStack.push_back(scope); + sessionManager.enterScope(scope); + } + + void exitScope() { + sessionManager.exitScope(scopeStack.back()); + scopeStack.pop_back(); + } + }; + + struct Correlation { + std::atomic maxSubmittedCorrelationId{0}; + std::atomic maxCompletedCorrelationId{0}; + // Mapping from a native profiler correlation id to an external id. + CorrIdToExternIdMap corrIdToExternId; + // Mapping from an external id to graph-node states + ExternIdToStateMap externIdToState; + + Correlation() = default; + + void submit(uint64_t correlationId) { + atomicMax(maxSubmittedCorrelationId, correlationId); + } + + void complete(uint64_t correlationId) { + atomicMax(maxCompletedCorrelationId, correlationId); + } + + // Correlate the correlationId with the last externId + void correlate(uint64_t correlationId, size_t externId, size_t numNodes, + bool isMissingName, const DataToEntryMap &dataToEntry) { + corrIdToExternId.insert(correlationId, externId); + externIdToState.upsert(externId, [&](ExternIdState &state) { + state.numNodes = numNodes; + state.dataToEntry = dataToEntry; + state.isMissingName = isMissingName; + }); + } + + template + void flush(uint64_t maxRetries, uint64_t sleepUs, FlushFnT &&flushFn) { + flushFn(); + auto submittedId = maxSubmittedCorrelationId.load(); + auto completedId = maxCompletedCorrelationId.load(); + auto retries = maxRetries; + while ((completedId < submittedId) && retries > 0) { + std::this_thread::sleep_for(std::chrono::microseconds(sleepUs)); + flushFn(); + completedId = maxCompletedCorrelationId.load(); + --retries; + } + } + }; + + static thread_local ThreadState threadState; + + std::unique_ptr metricBuffer; + std::unique_ptr pendingGraphPool; + + Correlation correlation; + + // Use the pimpl idiom to hide the implementation details. This lets us avoid + // including the cupti header from this header. The cupti header and the + // equivalent header from AMD define conflicting macros, so we want to use + // those headers only within cpp files. + class GPUProfilerPimplInterface { + public: + GPUProfilerPimplInterface(ConcreteProfilerT &profiler) + : profiler(profiler) {} + virtual ~GPUProfilerPimplInterface() = default; + + virtual void doStart() = 0; + virtual void doFlush() = 0; + virtual void doStop() = 0; + + void + doAddMetrics(size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics) { + if (threadState.isStreamCapturing) { // Graph capture mode + threadState.isMetricKernelLaunching = true; + // Launch metric kernels + profiler.metricBuffer->receive( + scalarMetrics, tensorMetrics, profiler.tensorMetricKernel, + profiler.scalarMetricKernel, profiler.metricKernelStream); + threadState.isMetricKernelLaunching = false; + } else { // Eager mode, directly copy + // Populate tensor metrics + auto tensorMetricsHost = + collectTensorMetrics(profiler.metricBuffer->getRuntime(), + tensorMetrics, profiler.metricKernelStream); + auto &dataToEntry = threadState.dataToEntry; + if (dataToEntry.empty()) { + // Add metrics to a specific scope + for (auto *data : profiler.dataSet) { + data->addMetrics(scopeId, scalarMetrics); + data->addMetrics(scopeId, tensorMetricsHost); + } + } else { + // Add metrics to the current op + for (auto [data, entry] : dataToEntry) { + data->addMetrics(entry.phase, entry.id, scalarMetrics); + data->addMetrics(entry.phase, entry.id, tensorMetricsHost); + } + } + } + } + + protected: + ConcreteProfilerT &profiler; + }; + + std::unique_ptr pImpl; + + bool pcSamplingEnabled{false}; + bool periodicFlushingEnabled{false}; + std::string periodicFlushingFormat{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_GPU_PROFILER_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Graph.h b/third_party/mthreads/proton/csrc/include/Profiler/Graph.h new file mode 100644 index 0000000000..adaa27bf43 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Graph.h @@ -0,0 +1,124 @@ +#ifndef PROTON_PROFILER_GRAPH_H_ +#define PROTON_PROFILER_GRAPH_H_ + +#include "Context/Context.h" +#include "Data/Metric.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +class Data; +class Runtime; + +struct GraphState { + using Callpath = std::vector; + + struct NodeState { + // Mapping from Data object to captured callpath. + std::map captureContexts; + // A unique id for the graph node + uint64_t nodeId{}; + // Whether the node is missing name + bool isMissingName{}; + // Whether the node is a metric kernel node + bool isMetricNode{}; + }; + + // Capture tag to identify captured call paths + static constexpr const char *captureTag = ""; + using NodeStateRef = std::reference_wrapper; + // Cached per-Data callpath groups: Data -> (callpath -> [nodeStates...]) + std::map>> + dataToCallpathToNodeStates; + // Mapping from node id to node state, has to be ordered based on node id + // which is the order of node creation + std::map nodeIdToState; + // Identify whether a node is a metric kernel node. + // NOTE: This set has to be ordered to match the node creation order. + std::set metricKernelNodeIds; + // If the graph is launched after profiling started, + // we need to throw an error and this error is only thrown once + bool captureStatusChecked{}; + // A unique id for the graph and graphExec instances; they don't overlap + uint32_t graphId{}; + // Total number of GPU kernels launched by this graph + size_t numNodes{1}; +}; + +struct PendingGraphQueue { + struct PendingGraph { + size_t numNodes; + std::map> dataToEntryIds; + }; + + std::vector pendingGraphs; + // The start buffer offset in the metric buffer for this queue + size_t startBufferOffset{}; + // Total number of metric nodes in the pending graphs + size_t numNodes{}; + // Device where the pending graphs are recorded + void *device{}; + // Phase + size_t phase{}; + + explicit PendingGraphQueue(size_t startBufferOffset, size_t phase, + void *device) + : startBufferOffset(startBufferOffset), phase(phase), device(device) {} + + void push(size_t numNodes, + const std::map> &dataToEntryIds) { + pendingGraphs.emplace_back(PendingGraph{numNodes, dataToEntryIds}); + this->numNodes += numNodes; + } +}; + +class PendingGraphPool { +public: + explicit PendingGraphPool(MetricBuffer *metricBuffer) + : metricBuffer(metricBuffer), runtime(metricBuffer->getRuntime()) {} + + void push(size_t phase, + const std::map> &dataToEntryIds, + size_t numNodes); + + // No GPU synchronization, No CPU locks + void peek(size_t phase); + + // Synchronize and flush all pending graph + bool flushAll(); + + // Check if we need to flush all before pushing new pending graph + bool flushIfNeeded(size_t numNodes); + +private: + struct Slot { + mutable std::mutex mutex; + std::optional queue; + }; + + // The current starting buffer offset in the metric buffer + // device -> offset + std::map deviceBufferOffset{}; + // How much remaining capacity in the metric buffer we have + // device -> capacity + std::map deviceRemainingCapacity{}; + MetricBuffer *metricBuffer{}; + Runtime *runtime{}; + mutable std::mutex mutex; + std::map>> pool; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_GRAPH_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h b/third_party/mthreads/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h new file mode 100644 index 0000000000..b1f829beb5 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h @@ -0,0 +1,78 @@ +#ifndef PROTON_PROFILER_INSTRUMENTATION_PROFILER_H_ +#define PROTON_PROFILER_INSTRUMENTATION_PROFILER_H_ + +#include "Context/Context.h" +#include "Device.h" +#include "Metadata.h" +#include "Profiler/Profiler.h" +#include "Runtime/Runtime.h" +#include "TraceDataIO/Parser.h" +#include "Utility/Singleton.h" + +namespace proton { + +class InstrumentationProfiler : public Profiler, + public InstrumentationInterface, + public OpInterface, + public Singleton { +public: + InstrumentationProfiler() = default; + virtual ~InstrumentationProfiler(); + +protected: + // Profiler + virtual void doStart() override; + virtual void doFlush() override; + virtual void doStop() override; + virtual void + doSetMode(const std::vector &modeAndOptions) override; + virtual void doAddMetrics( + size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics) override; + + // InstrumentationInterface + void initFunctionMetadata( + uint64_t functionId, const std::string &functionName, + const std::vector> &scopeIdNames, + const std::vector> &scopeIdParentIds, + const std::string &metadataPath) override; + void enterInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size) override; + void exitInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size) override; + + // OpInterface + void startOp(const Scope &scope) override { + for (auto data : dataSet) { + dataToEntryMap.insert_or_assign(data, data->addOp(scope.name)); + } + } + void stopOp(const Scope &scope) override { dataToEntryMap.clear(); } + +private: + std::shared_ptr getParserConfig(uint64_t functionId, + size_t bufferSize) const; + + Runtime *runtime; + // device -> deviceStream + std::map deviceStreams; + std::map modeOptions; + uint8_t *hostBuffer{nullptr}; + // functionId -> scopeId -> scopeName + std::map> functionScopeIdNames; + // functionId -> scopeId -> contexts + std::map>> + functionScopeIdContexts; + ; + // functionId -> functionName + std::map functionNames; + // functionId -> metadata + std::map functionMetadata; + // data -> scopeId + DataToEntryMap dataToEntryMap; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_INSTRUMENTATION_PROFILER_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Instrumentation/Metadata.h b/third_party/mthreads/proton/csrc/include/Profiler/Instrumentation/Metadata.h new file mode 100644 index 0000000000..eddea3f933 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Instrumentation/Metadata.h @@ -0,0 +1,30 @@ +#ifndef PROTON_PROFILER_INSTRUMENTATION_METADATA_H_ +#define PROTON_PROFILER_INSTRUMENTATION_METADATA_H_ + +#include + +namespace proton { + +class InstrumentationMetadata { + +public: + InstrumentationMetadata(const std::string &metadataPath) + : metadataPath(metadataPath) { + parse(); + } + + size_t getScratchMemorySize() const { return scratchMemorySize; } + + size_t getNumWarps() const { return numWarps; } + +private: + void parse(); + + const std::string metadataPath; + size_t scratchMemorySize{}; + size_t numWarps{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_INSTRUMENTATION_METADATA_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Profiler.h b/third_party/mthreads/proton/csrc/include/Profiler/Profiler.h new file mode 100644 index 0000000000..bbaa81545b --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Profiler.h @@ -0,0 +1,141 @@ +#ifndef PROTON_PROFILER_PROFILER_H_ +#define PROTON_PROFILER_PROFILER_H_ + +#include "Data/Data.h" +#include "Data/Metric.h" +#include "Utility/Singleton.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +/// A profiler contains utilities provided by the profiler library to +/// collect and analyze performance data. +class Profiler : public MetricInterface { +public: + Profiler() = default; + + virtual ~Profiler() = default; + + /// Start the profiler. + /// If the profiler is already started, this function does nothing. + Profiler *start() { + if (!this->started) { + this->started = true; + this->doStart(); + } + return this; + } + + /// Flush the profiler's data from the device to the host. + /// It doesn't stop the profiler. + Profiler *flush() { + this->doFlush(); + // Treat all phases up to currentPhase - 1 as flushed, even if a phase has + // no GPU activity records (i.e., nothing to flush from device to host). + for (auto *data : this->getDataSet()) { + const auto phaseInfo = data->getPhaseInfo(); + if (phaseInfo.current == 0) + continue; + data->completePhase(phaseInfo.current - 1); + } + return this; + } + + /// Stop the profiler. + /// Do real stop if there's no data to collect. + Profiler *stop() { + if (!this->started) { + return this; + } + if (this->dataSet.empty()) { + this->started = false; + this->doStop(); + } + return this; + } + + /// Register a data object to the profiler. + /// A profiler can yield metrics to multiple data objects. + Profiler *registerData(Data *data) { + std::unique_lock lock(mutex); + dataSet.insert(data); + return this; + } + + /// Unregister a data object from the profiler. + Profiler *unregisterData(Data *data) { + std::unique_lock lock(mutex); + dataSet.erase(data); + return this; + } + + /// Get the set of data objects registered to the profiler. + std::set getDataSet() const { + std::shared_lock lock(mutex); + return dataSet; + } + + Profiler *setMode(const std::vector &modeAndOptions) { + std::unique_lock lock(mutex); + this->modeAndOptions = modeAndOptions; + this->doSetMode(modeAndOptions); + return this; + } + + std::vector getMode() const { + std::shared_lock lock(mutex); + return modeAndOptions; + } + + void addMetrics( + size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics) override { + std::unique_lock lock(mutex); + this->doAddMetrics(scopeId, scalarMetrics, tensorMetrics); + } + + /// These fields are not persistent, function pointers will be changed + /// when modules and contexts are switched. + /// So we just set them as thread local storage before the application kernel + /// starts or after the application kernel ends. + void setMetricKernels(void *tensorMetricKernel, void *scalarMetricKernel, + void *stream) override { + this->tensorMetricKernel = tensorMetricKernel; + this->scalarMetricKernel = scalarMetricKernel; + this->metricKernelStream = stream; + } + +protected: + virtual void doStart() = 0; + virtual void doFlush() = 0; + virtual void doStop() = 0; + virtual void doSetMode(const std::vector &modeAndOptions) = 0; + virtual void + doAddMetrics(size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics) = 0; + + mutable std::shared_mutex mutex; + std::set dataSet; + static thread_local void *tensorMetricKernel; + static thread_local void *scalarMetricKernel; + static thread_local void *metricKernelStream; + +private: + bool started{}; + std::vector modeAndOptions{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_PROFILER_H_ diff --git a/third_party/mthreads/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h b/third_party/mthreads/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h new file mode 100644 index 0000000000..bb79e1dfed --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h @@ -0,0 +1,22 @@ +#ifndef PROTON_PROFILER_ROCTRACER_PROFILER_H_ +#define PROTON_PROFILER_ROCTRACER_PROFILER_H_ + +#include "Profiler/GPUProfiler.h" + +namespace proton { + +class RoctracerProfiler : public GPUProfiler { +public: + RoctracerProfiler(); + virtual ~RoctracerProfiler(); + +private: + struct RoctracerProfilerPimpl; + + virtual void + doSetMode(const std::vector &modeAndOptions) override; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_ROCTRACER_PROFILER_H_ diff --git a/third_party/mthreads/proton/csrc/include/Proton.h b/third_party/mthreads/proton/csrc/include/Proton.h new file mode 100644 index 0000000000..92e2fdf0ae --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Proton.h @@ -0,0 +1,9 @@ +#ifndef PROTON_H_ +#define PROTON_H_ + +#include "Context/Context.h" +#include "Data/Data.h" +#include "Data/Metric.h" +#include "Session/Session.h" + +#endif // PROTON_H_ diff --git a/third_party/mthreads/proton/csrc/include/Runtime/CudaRuntime.h b/third_party/mthreads/proton/csrc/include/Runtime/CudaRuntime.h new file mode 100644 index 0000000000..62a8dd2f1b --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Runtime/CudaRuntime.h @@ -0,0 +1,42 @@ +#ifndef PROTON_RUNTIME_CUDA_RUNTIME_H_ +#define PROTON_RUNTIME_CUDA_RUNTIME_H_ + +#include "Runtime.h" +#include "Utility/Singleton.h" + +namespace proton { + +class CudaRuntime : public Singleton, public Runtime { +public: + CudaRuntime() : Runtime(DeviceType::CUDA) {} + ~CudaRuntime() = default; + + void launchKernel(void *kernel, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, + unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, void *stream, + void **kernelParams, void **extra) override; + void memset(void *devicePtr, uint32_t value, size_t size, + void *stream) override; + void allocateHostBuffer(uint8_t **buffer, size_t size, bool mapped) override; + void getHostDevicePointer(uint8_t *hostPtr, uint8_t **devicePtr) override; + void freeHostBuffer(uint8_t *buffer) override; + void allocateDeviceBuffer(uint8_t **buffer, size_t size) override; + void freeDeviceBuffer(uint8_t *buffer) override; + void copyDeviceToHostAsync(void *dst, const void *src, size_t size, + void *stream) override; + void *getDevice() override; + void *getPriorityStream() override; + void synchronizeStream(void *stream) override; + void synchronizeDevice() override; + void destroyStream(void *stream) override; + void + processHostBuffer(uint8_t *hostBuffer, size_t hostBufferSize, + uint8_t *deviceBuffer, size_t deviceBufferSize, + void *stream, + std::function callback) override; +}; + +} // namespace proton + +#endif // PROTON_RUNTIME_CUDA_RUNTIME_H_ diff --git a/third_party/mthreads/proton/csrc/include/Runtime/HipRuntime.h b/third_party/mthreads/proton/csrc/include/Runtime/HipRuntime.h new file mode 100644 index 0000000000..0d2934f3cf --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Runtime/HipRuntime.h @@ -0,0 +1,42 @@ +#ifndef PROTON_RUNTIME_HIP_RUNTIME_H_ +#define PROTON_RUNTIME_HIP_RUNTIME_H_ + +#include "Runtime.h" +#include "Utility/Singleton.h" + +namespace proton { + +class HipRuntime : public Singleton, public Runtime { +public: + HipRuntime() : Runtime(DeviceType::HIP) {} + ~HipRuntime() = default; + + void launchKernel(void *kernel, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, + unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, void *stream, + void **kernelParams, void **extra) override; + void memset(void *devicePtr, uint32_t value, size_t size, + void *stream) override; + void allocateHostBuffer(uint8_t **buffer, size_t size, bool mapped) override; + void getHostDevicePointer(uint8_t *hostPtr, uint8_t **devicePtr) override; + void freeHostBuffer(uint8_t *buffer) override; + void allocateDeviceBuffer(uint8_t **buffer, size_t size) override; + void freeDeviceBuffer(uint8_t *buffer) override; + void copyDeviceToHostAsync(void *dst, const void *src, size_t size, + void *stream) override; + void *getDevice() override; + void *getPriorityStream() override; + void synchronizeStream(void *stream) override; + void synchronizeDevice() override; + void destroyStream(void *stream) override; + void + processHostBuffer(uint8_t *hostBuffer, size_t hostBufferSize, + uint8_t *deviceBuffer, size_t deviceBufferSize, + void *stream, + std::function callback) override; +}; + +} // namespace proton + +#endif // PROTON_RUNTIME_HIP_RUNTIME_H_ diff --git a/third_party/mthreads/proton/csrc/include/Runtime/Runtime.h b/third_party/mthreads/proton/csrc/include/Runtime/Runtime.h new file mode 100644 index 0000000000..a09deaac31 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Runtime/Runtime.h @@ -0,0 +1,66 @@ +#ifndef PROTON_RUNTIME_RUNTIME_H_ +#define PROTON_RUNTIME_RUNTIME_H_ + +#include +#include +#include + +#include "Device.h" + +namespace proton { + +/// Abstract base class for different runtime implementations +class Runtime { +public: + Runtime(DeviceType deviceType) : deviceType(deviceType) {} + virtual ~Runtime() = default; + + virtual void launchKernel(void *kernel, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, + void *stream, void **kernelParams, + void **extra) = 0; + + virtual void memset(void *devicePtr, uint32_t value, size_t size, + void *stream) = 0; + + virtual void allocateHostBuffer(uint8_t **buffer, size_t size, + bool mapped = false) = 0; + + virtual void getHostDevicePointer(uint8_t *hostPtr, uint8_t **devicePtr) = 0; + + virtual void freeHostBuffer(uint8_t *buffer) = 0; + + virtual void allocateDeviceBuffer(uint8_t **buffer, size_t size) = 0; + + virtual void freeDeviceBuffer(uint8_t *buffer) = 0; + + virtual void copyDeviceToHostAsync(void *dst, const void *src, size_t size, + void *stream) = 0; + + virtual void *getDevice() = 0; + + virtual void *getPriorityStream() = 0; + + virtual void destroyStream(void *stream) = 0; + + virtual void synchronizeStream(void *stream) = 0; + + virtual void synchronizeDevice() = 0; + + virtual void + processHostBuffer(uint8_t *hostBuffer, size_t hostBufferSize, + uint8_t *deviceBuffer, size_t deviceBufferSize, + void *stream, + std::function callback) = 0; + + DeviceType getDeviceType() const { return deviceType; } + +protected: + DeviceType deviceType; +}; + +} // namespace proton + +#endif // PROTON_RUNTIME_RUNTIME_H_ diff --git a/third_party/mthreads/proton/csrc/include/Session/Session.h b/third_party/mthreads/proton/csrc/include/Session/Session.h new file mode 100644 index 0000000000..e089c6cabf --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Session/Session.h @@ -0,0 +1,241 @@ +#ifndef PROTON_SESSION_SESSION_H_ +#define PROTON_SESSION_SESSION_H_ + +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Utility/Singleton.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +class Profiler; +class Data; + +/// A session is a collection of profiler, context source, and data objects. +/// There could be multiple sessions in the system, each can correspond to a +/// different duration, or the same duration but with different configurations. +class Session { +public: + ~Session() = default; + + void activate(); + + void deactivate(bool flushing); + + void finalize(const std::string &outputFormat); + + size_t getContextDepth(); + + Profiler *getProfiler() const { return profiler; } + +private: + Session(size_t id, const std::string &path, Profiler *profiler, + std::unique_ptr contextSource, + std::unique_ptr data) + : id(id), path(path), profiler(profiler), + contextSource(std::move(contextSource)), data(std::move(data)) {} + + template std::vector getInterfaces() { + std::vector interfaces; + // There's an implicit order between contextSource and profiler/data. The + // latter two rely on the contextSource to obtain the context, so we need to + // add the contextSource first. + if (auto interface = dynamic_cast(contextSource.get())) { + interfaces.push_back(interface); + } + if (auto interface = dynamic_cast(profiler)) { + interfaces.push_back(interface); + } + if (auto interface = dynamic_cast(data.get())) { + interfaces.push_back(interface); + } + return interfaces; + } + + const std::string path{}; + size_t id{}; + Profiler *profiler{}; + std::unique_ptr contextSource{}; + std::unique_ptr data{}; + + friend class SessionManager; +}; + +/// A session manager is responsible for managing the lifecycle of sessions. +/// There's a single and unique session manager in the system. +class SessionManager : public Singleton { +public: + SessionManager() = default; + ~SessionManager() = default; + + size_t addSession(const std::string &path, const std::string &profilerName, + const std::string &contextSourceName, + const std::string &dataName, const std::string &mode); + + void finalizeSession(size_t sessionId, const std::string &outputFormat); + + void finalizeAllSessions(const std::string &outputFormat); + + void activateSession(size_t sessionId); + + void activateAllSessions(); + + void deactivateSession(size_t sessionId, bool flushing); + + void deactivateAllSessions(bool flushing); + + size_t getContextDepth(size_t sessionId); + + std::vector getDataMsgPack(size_t sessionId, size_t phase); + + std::string getData(size_t sessionId, size_t phase); + + void clearData(size_t sessionId, size_t phase, bool clearUpToPhase = false); + + size_t advanceDataPhase(size_t sessionId); + + bool isDataPhaseComplete(size_t sessionId, size_t phase); + + void enterScope(const Scope &scope); + + void exitScope(const Scope &scope); + + void enterOp(const Scope &scope); + + void exitOp(const Scope &scope); + + void initFunctionMetadata( + uint64_t functionId, const std::string &functionName, + const std::vector> &scopeIdNames, + const std::vector> &scopeIdParents, + const std::string &metadataPath); + + void enterInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size); + + void exitInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size); + + void addMetrics(size_t scopeId, + const std::map &scalarMetrics, + const std::map &tensorMetrics); + + void setMetricKernels(void *tensorMetricKernel, void *scalarMetricKernel, + void *stream); + + void setState(std::optional context); + +private: + Profiler *validateAndSetProfilerMode(Profiler *profiler, + const std::string &mode); + + std::unique_ptr makeSession(size_t id, const std::string &path, + const std::string &profilerName, + const std::string &contextSourceName, + const std::string &dataName, + const std::string &mode); + + Session *getSessionOrThrow(size_t sessionId); + + void activateSessionImpl(size_t sessionId); + + void deActivateSessionImpl(size_t sessionId, bool flushing); + + size_t getSessionId(const std::string &path) { return sessionPaths[path]; } + + bool hasSession(const std::string &path) { + return sessionPaths.find(path) != sessionPaths.end(); + } + + bool hasSession(size_t sessionId) { + return sessions.find(sessionId) != sessions.end(); + } + + void removeSession(size_t sessionId); + + template + void updateInterfaceCount(size_t sessionId, Counter &interfaceCounts) { + auto interfaces = sessions[sessionId]->getInterfaces(); + for (auto *interface : interfaces) { + auto it = std::find_if( + interfaceCounts.begin(), interfaceCounts.end(), + [interface](const auto &pair) { return pair.first == interface; }); + + if (it != interfaceCounts.end()) { + if constexpr (isRegistering) { + ++it->second; + } else { + --it->second; + if (it->second == 0) { + interfaceCounts.erase(it); + } + } + } else if constexpr (isRegistering) { + interfaceCounts.emplace_back(interface, 1); + } + } + } + + template + void registerInterface(size_t sessionId, Counter &interfaceCounts) { + updateInterfaceCount(sessionId, interfaceCounts); + } + + template + void unregisterInterface(size_t sessionId, Counter &interfaceCounts) { + updateInterfaceCount(sessionId, interfaceCounts); + } + + template + void executeInterface(Counter &interfaceCounts, FnT &&fn, + bool isReversed = false) { + auto process = [&](auto &entry) { + if (entry.second > 0) { + fn(entry.first); + } + }; + + if (isReversed) { + for (auto it = interfaceCounts.rbegin(); it != interfaceCounts.rend(); + ++it) { + process(*it); + } + } else { + for (auto &entry : interfaceCounts) { + process(entry); + } + } + } + + mutable std::mutex mutex; + + size_t nextSessionId{}; + // path -> session id + std::map sessionPaths; + // session id -> active + std::map sessionActive; + // session id -> session + std::map> sessions; + // {scope, active count} + std::vector> scopeInterfaceCounts; + // {op, active count} + std::vector> opInterfaceCounts; + // {instrumentation, active count} + std::vector> + instrumentationInterfaceCounts; + // {metric, active count} + std::vector> metricInterfaceCounts; + // {context source, active count} + std::vector> contextSourceCounts; +}; + +} // namespace proton + +#endif // PROTON_SESSION_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Atomic.h b/third_party/mthreads/proton/csrc/include/Utility/Atomic.h new file mode 100644 index 0000000000..0f759e0d61 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Atomic.h @@ -0,0 +1,39 @@ +#ifndef PROTON_UTILITY_ATOMIC_H_ +#define PROTON_UTILITY_ATOMIC_H_ + +#include +#include + +namespace proton { + +template T atomicMax(std::atomic &target, T value) { + T current = target.load(); + while (current < value && !target.compare_exchange_weak(current, value)) + ; + return current; +} + +template T atomicMin(std::atomic &target, T value) { + T current = target.load(); + while (current > value && !target.compare_exchange_weak(current, value)) + ; + return current; +} + +template +void doubleCheckedLock(Condition enterCondition, std::mutex &lock, + Function function) { + if (!enterCondition()) + return; + + std::unique_lock guard(lock); + + if (!enterCondition()) + return; + + function(); +} + +} // namespace proton + +#endif // PROTON_UTILITY_ATOMIC_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Env.h b/third_party/mthreads/proton/csrc/include/Utility/Env.h new file mode 100644 index 0000000000..f4f9ad056b --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Env.h @@ -0,0 +1,40 @@ +#ifndef PROTON_UTILITY_ENV_H_ +#define PROTON_UTILITY_ENV_H_ + +#include +#include +#include +#include + +namespace proton { + +static std::mutex getenv_mutex; + +inline int64_t getIntEnv(const std::string &env, int64_t defaultValue) { + std::lock_guard lock(getenv_mutex); + const char *s = std::getenv(env.c_str()); + if (s == nullptr) + return defaultValue; + return std::stoll(s); +} + +inline bool getBoolEnv(const std::string &env, bool defaultValue) { + std::lock_guard lock(getenv_mutex); + const char *s = std::getenv(env.c_str()); + if (s == nullptr) + return defaultValue; + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::string getStrEnv(const std::string &env) { + std::lock_guard lock(getenv_mutex); + const char *s = std::getenv(env.c_str()); + return std::string(s ? s : ""); +} + +} // namespace proton + +#endif // PROTON_UTILITY_ENV_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Errors.h b/third_party/mthreads/proton/csrc/include/Utility/Errors.h new file mode 100644 index 0000000000..09c44025dc --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Errors.h @@ -0,0 +1,15 @@ +#ifndef PROTON_UTILITY_ERRORS_H_ +#define PROTON_UTILITY_ERRORS_H_ + +#include + +namespace proton { + +class NotImplemented : public std::logic_error { +public: + NotImplemented() : std::logic_error("Not yet implemented") {}; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_ERRORS_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Map.h b/third_party/mthreads/proton/csrc/include/Utility/Map.h new file mode 100644 index 0000000000..81af3890bd --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Map.h @@ -0,0 +1,107 @@ +#ifndef PROTON_UTILITY_MAP_H_ +#define PROTON_UTILITY_MAP_H_ + +#include +#include +#include +#include + +namespace proton { + +/// A simple thread safe map with read/write lock. +template > +class ThreadSafeMap { +public: + ThreadSafeMap() = default; + + template void upsert(const Key &key, FnT &&fn) { + std::unique_lock lock(mutex); + fn(map[key]); + } + + template bool withRead(const Key &key, FnT &&fn) const { + std::shared_lock lock(mutex); + auto it = map.find(key); + if (it == map.end()) { + return false; + } + fn(it->second); + return true; + } + + template bool withWrite(const Key &key, FnT &&fn) { + std::unique_lock lock(mutex); + auto it = map.find(key); + if (it == map.end()) { + return false; + } + fn(it->second); + return true; + } + + Value &operator[](const Key &key) { + std::unique_lock lock(mutex); + return map[key]; + } + + Value &operator[](Key &&key) { + std::unique_lock lock(mutex); + return map[std::move(key)]; + } + + Value &at(const Key &key) { + std::shared_lock lock(mutex); + return map.at(key); + } + + Value &at(const Key &key) const { + std::shared_lock lock(mutex); + return map.at(key); + } + + void insert(const Key &key, const Value &value) { + std::unique_lock lock(mutex); + map[key] = value; + } + + bool contain(const Key &key) const { + std::shared_lock lock(mutex); + auto it = map.find(key); + if (it == map.end()) + return false; + return true; + } + + bool erase(const Key &key) { + std::unique_lock lock(mutex); + return map.erase(key) > 0; + } + + void clear() { + std::unique_lock lock(mutex); + map.clear(); + } + + size_t size() const { + std::shared_lock lock(mutex); + return map.size(); + } + + std::optional> find(const Key &key) { + std::shared_lock lock(mutex); + auto it = map.find(key); + if (it == map.end()) { + return std::nullopt; + } + return std::ref(it->second); + } + +private: + Container map; + mutable std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MAP_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/MsgPackWriter.h b/third_party/mthreads/proton/csrc/include/Utility/MsgPackWriter.h new file mode 100644 index 0000000000..639b4cdfad --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/MsgPackWriter.h @@ -0,0 +1,33 @@ +#ifndef PROTON_UTILITY_MSGPACK_WRITER_H_ +#define PROTON_UTILITY_MSGPACK_WRITER_H_ + +#include +#include +#include +#include + +namespace proton { + +// See https://msgpack.org/index.html for the specification. +class MsgPackWriter { +public: + void reserve(size_t bytes); + + std::vector take() &&; + + void packNil(); + void packBool(bool value); + void packUInt(uint64_t value); + void packInt(int64_t value); + void packDouble(double value); + void packStr(std::string_view value); + void packArray(uint32_t size); + void packMap(uint32_t size); + +private: + std::vector out; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MSGPACK_WRITER_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Numeric.h b/third_party/mthreads/proton/csrc/include/Utility/Numeric.h new file mode 100644 index 0000000000..61de2aa305 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Numeric.h @@ -0,0 +1,21 @@ +#ifndef PROTON_UTILITY_NUMERIC_H_ +#define PROTON_UTILITY_NUMERIC_H_ + +#include + +namespace proton { + +template constexpr T nextPowerOfTwo(T value) { + if (value < 1) { + return 1; + } + --value; // Decrement to handle the case where value is already a power of two + for (size_t i = 1; i < sizeof(T) * 8; i <<= 1) { + value |= value >> i; // Propagate the highest set bit to the right + } + return value + 1; // Increment to get the next power of two +} + +} // namespace proton + +#endif // PROTON_UTILITY_NUMERIC_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Set.h b/third_party/mthreads/proton/csrc/include/Utility/Set.h new file mode 100644 index 0000000000..7f2184b7cf --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Set.h @@ -0,0 +1,45 @@ +#ifndef PROTON_UTILITY_SET_H_ +#define PROTON_UTILITY_SET_H_ + +#include +#include + +namespace proton { + +/// A simple thread safe set with read/write lock. +template > +class ThreadSafeSet { +public: + ThreadSafeSet() = default; + + void insert(const Key &key) { + std::unique_lock lock(mutex); + set.insert(key); + } + + bool contain(const Key &key) const { + std::shared_lock lock(mutex); + auto it = set.find(key); + if (it == set.end()) + return false; + return true; + } + + bool erase(const Key &key) { + std::unique_lock lock(mutex); + return set.erase(key) > 0; + } + + void clear() { + std::unique_lock lock(mutex); + set.clear(); + } + +private: + Container set; + mutable std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MAP_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Singleton.h b/third_party/mthreads/proton/csrc/include/Utility/Singleton.h new file mode 100644 index 0000000000..f91fef1437 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Singleton.h @@ -0,0 +1,22 @@ +#ifndef PROTON_UTILITY_SINGLETON_H_ +#define PROTON_UTILITY_SINGLETON_H_ + +namespace proton { + +template class Singleton { +public: + Singleton(const Singleton &) = delete; + Singleton &operator=(const Singleton &) = delete; + + static T &instance() { + static T _; + return _; + } + +protected: + Singleton() = default; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_SINGLETON_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/String.h b/third_party/mthreads/proton/csrc/include/Utility/String.h new file mode 100644 index 0000000000..219d251ab3 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/String.h @@ -0,0 +1,69 @@ +#ifndef PROTON_UTILITY_STRING_H_ +#define PROTON_UTILITY_STRING_H_ + +#include +#include + +namespace proton { + +inline std::string toLower(const std::string &str) { + std::string lower; + for (auto c : str) { + lower += tolower(c); + } + return lower; +} + +inline std::string replace(const std::string &str, const std::string &src, + const std::string &dst) { + std::string replaced = str; + size_t pos = replaced.find(src); + while (pos != std::string::npos) { + replaced.replace(pos, src.length(), dst); + pos += dst.length(); + pos = replaced.find(src, pos); + } + return replaced; +} + +inline bool endWith(const std::string &str, const std::string &sub) { + if (str.length() < sub.length()) { + return false; + } + return str.compare(str.length() - sub.length(), sub.length(), sub) == 0; +} + +inline std::string trim(const std::string &str) { + size_t start = 0; + size_t end = str.length(); + while (start < end && isspace(str[start])) { + start++; + } + while (end > start && isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + +inline std::vector split(const std::string &str, + const std::string &delim) { + std::vector result; + size_t start = 0; + size_t end = str.find(delim); + while (end != std::string::npos) { + result.push_back(str.substr(start, end - start)); + start = end + delim.length(); + end = str.find(delim, start); + } + result.push_back(str.substr(start, end)); + return result; +} + +inline std::string formatFileLineFunction(const std::string &file, int line, + const std::string &function) { + return file + ":" + std::to_string(line) + "@" + function; +} + +} // namespace proton + +#endif // PROTON_UTILITY_STRING_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Table.h b/third_party/mthreads/proton/csrc/include/Utility/Table.h new file mode 100644 index 0000000000..830eb81cba --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Table.h @@ -0,0 +1,83 @@ +#ifndef PROTON_UTILITY_TABLE_H_ +#define PROTON_UTILITY_TABLE_H_ + +#include +#include +#include +#include +#include + +namespace proton { + +// Dense table for ids in a contiguous range [minId, maxId]. +template class RangeTable { + static_assert(std::is_integral_v, "RangeTable IdT must be integral"); + +public: + void resetRange(IdT minIdValue, IdT maxIdValue) { + if (maxIdValue < minIdValue) { + clear(); + return; + } + minId = minIdValue; + auto size = static_cast(maxIdValue - minIdValue + 1); + nodes.clear(); + nodes.resize(size); + present.assign(size, false); + } + + void clear() { + minId = 0; + nodes.clear(); + present.clear(); + } + + std::pair tryEmplace(IdT id) { + if (!inRange(id)) + return {nullptr, false}; + auto index = indexFor(id); + bool inserted = !present[index]; + present[index] = true; + return {&nodes[index], inserted}; + } + + T &emplace(IdT id) { + auto index = indexFor(id); + present[index] = true; + return nodes[index]; + } + + T *find(IdT id) { + if (!inRange(id)) + return nullptr; + auto index = indexFor(id); + return present[index] ? &nodes[index] : nullptr; + } + + const T *find(IdT id) const { + if (!inRange(id)) + return nullptr; + auto index = indexFor(id); + return present[index] ? &nodes[index] : nullptr; + } + + bool empty() const { return nodes.empty(); } + +private: + bool inRange(IdT id) const { + if (nodes.empty() || id < minId) + return false; + auto offset = static_cast(id - minId); + return offset < nodes.size(); + } + + size_t indexFor(IdT id) const { return static_cast(id - minId); } + + IdT minId{0}; + std::vector nodes; + std::vector present; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_TABLE_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Traits.h b/third_party/mthreads/proton/csrc/include/Utility/Traits.h new file mode 100644 index 0000000000..b7b073f278 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Traits.h @@ -0,0 +1,33 @@ +#ifndef PROTON_UTILITY_TRAITS_H_ +#define PROTON_UTILITY_TRAITS_H_ + +#include +#include + +namespace proton { + +namespace details { + +template struct variant_index; + +template struct variant_index> { + static constexpr std::size_t value = []() constexpr { + std::size_t i = 0; + (void)((std::is_same_v ? true : (++i, false)) || ...); + return i; + }(); +}; + +} // namespace details +template +struct is_one_of : std::disjunction...> {}; + +template struct always_false : std::false_type {}; + +template +inline constexpr std::size_t variant_index_v = + details::variant_index::value; + +} // namespace proton + +#endif // PROTON_UTILITY_TRAITS_H_ diff --git a/third_party/mthreads/proton/csrc/include/Utility/Vector.h b/third_party/mthreads/proton/csrc/include/Utility/Vector.h new file mode 100644 index 0000000000..651e1bea29 --- /dev/null +++ b/third_party/mthreads/proton/csrc/include/Utility/Vector.h @@ -0,0 +1,82 @@ +#ifndef PROTON_UTILITY_VECTOR_H_ +#define PROTON_UTILITY_VECTOR_H_ + +#include +#include +#include +#include + +namespace proton { + +/// A simple thread safe vector with read/write lock. +template > +class ThreadSafeVector { +public: + ThreadSafeVector() = default; + + void push_back(const Value &value) { + std::unique_lock lock(mutex); + vector.push_back(value); + } + + void push_back(Value &&value) { + std::unique_lock lock(mutex); + vector.push_back(std::move(value)); + } + + template void emplace_back(Args &&...args) { + std::unique_lock lock(mutex); + vector.emplace_back(std::forward(args)...); + } + + bool contain(const Value &value) { + std::shared_lock lock(mutex); + return std::find(vector.begin(), vector.end(), value) != vector.end(); + } + + bool erase(const Value &value) { + std::unique_lock lock(mutex); + auto it = std::find(vector.begin(), vector.end(), value); + if (it == vector.end()) + return false; + vector.erase(it); + return true; + } + + bool pop_back(Value &value) { + std::unique_lock lock(mutex); + if (vector.empty()) + return false; + value = vector.back(); + vector.pop_back(); + return true; + } + + void clear() { + std::unique_lock lock(mutex); + vector.clear(); + } + + size_t size() { + std::shared_lock lock(mutex); + return vector.size(); + } + + bool empty() { + std::shared_lock lock(mutex); + return vector.empty(); + } + + Container snapshot() { + std::shared_lock lock(mutex); + return vector; + } + +private: + Container vector; + std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_VECTOR_H_ diff --git a/third_party/mthreads/proton/csrc/lib/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/CMakeLists.txt new file mode 100644 index 0000000000..05312eb5b5 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/CMakeLists.txt @@ -0,0 +1,7 @@ +add_subdirectory(Context) +add_subdirectory(Data) +add_subdirectory(Utility) +add_subdirectory(Driver) +add_subdirectory(Runtime) +add_subdirectory(Profiler) +add_subdirectory(Session) diff --git a/third_party/mthreads/proton/csrc/lib/Context/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Context/CMakeLists.txt new file mode 100644 index 0000000000..456c04b115 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Context/CMakeLists.txt @@ -0,0 +1,5 @@ +add_proton_library(ProtonContext + Context.cpp + Python.cpp + Shadow.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Context/Context.cpp b/third_party/mthreads/proton/csrc/lib/Context/Context.cpp new file mode 100644 index 0000000000..ba3d67f44e --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Context/Context.cpp @@ -0,0 +1,12 @@ +#include "Context/Context.h" + +namespace proton { + +/*static*/ thread_local std::optional ContextSource::state = + std::nullopt; + +std::atomic Scope::scopeIdCounter{1}; + +/*static*/ thread_local std::map OpInterface::opInProgress; + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Context/Python.cpp b/third_party/mthreads/proton/csrc/lib/Context/Python.cpp new file mode 100644 index 0000000000..d4dbadae89 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Context/Python.cpp @@ -0,0 +1,83 @@ +#include "Context/Python.h" +#include "Utility/String.h" +#include "pybind11/pybind11.h" +#include +#include + +namespace proton { + +namespace { + +// bpo-42262 added Py_NewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) +PyObject *_Py_NewRef(PyObject *obj) { + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef((PyObject *)(obj)) +#endif + +// bpo-42262 added Py_XNewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_XNewRef) +PyObject *_Py_XNewRef(PyObject *obj) { + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef((PyObject *)(obj)) +#endif + +PyCodeObject *getFrameCodeObject(PyFrameObject *frame) { + assert(frame != nullptr); + return PyFrame_GetCode(frame); +} + +PyFrameObject *getFrameBack(PyFrameObject *frame) { + assert(frame != nullptr); + return PyFrame_GetBack(frame); +} + +std::string unpackPyobject(PyObject *pyObject) { + if (PyBytes_Check(pyObject)) { + size_t size = PyBytes_GET_SIZE(pyObject); + return std::string(PyBytes_AS_STRING(pyObject), size); + } + if (PyUnicode_Check(pyObject)) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + Py_ssize_t size; + const char *data = PyUnicode_AsUTF8AndSize(pyObject, &size); + if (!data) { + return ""; + } + return std::string(data, (size_t)size); + } + return ""; +} + +} // namespace + +std::vector PythonContextSource::getContextsImpl() { + pybind11::gil_scoped_acquire gil; + + PyFrameObject *frame = PyEval_GetFrame(); + Py_XINCREF(frame); + + std::vector contexts; + while (frame != nullptr) { + PyCodeObject *f_code = getFrameCodeObject(frame); + size_t lineno = PyFrame_GetLineNumber(frame); + size_t firstLineNo = f_code->co_firstlineno; + std::string file = unpackPyobject(f_code->co_filename); + std::string function = unpackPyobject(f_code->co_name); + auto pythonFrame = formatFileLineFunction(file, lineno, function); + contexts.push_back(Context(pythonFrame)); + auto newFrame = getFrameBack(frame); + Py_DECREF(frame); + frame = newFrame; + } + std::reverse(contexts.begin(), contexts.end()); + return contexts; +} + +size_t PythonContextSource::getDepth() { return getContextsImpl().size(); } + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Context/Shadow.cpp b/third_party/mthreads/proton/csrc/lib/Context/Shadow.cpp new file mode 100644 index 0000000000..70366e6b99 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Context/Shadow.cpp @@ -0,0 +1,52 @@ +#include "Context/Shadow.h" + +#include +#include + +namespace proton { + +void ShadowContextSource::initializeThreadContext() { + if (!threadContextInitialized[this]) { + threadContextStack[this] = *mainContextStack; + threadContextInitialized[this] = true; + } +} + +void ShadowContextSource::enterScope(const Scope &scope) { + initializeThreadContext(); + threadContextStack[this].push_back(scope); +} + +std::vector ShadowContextSource::getContextsImpl() { + initializeThreadContext(); + return threadContextStack[this]; +} + +size_t ShadowContextSource::getDepth() { + initializeThreadContext(); + return threadContextStack[this].size(); +} + +void ShadowContextSource::exitScope(const Scope &scope) { + if (threadContextStack[this].empty()) { + throw std::runtime_error("Context stack is empty"); + } + if (threadContextStack[this].back() != scope) { + throw std::runtime_error("Context stack is not balanced"); + } + threadContextStack[this].pop_back(); +} + +void ShadowContextSource::clear() { + ContextSource::clear(); + threadContextStack[this].clear(); + threadContextInitialized[this] = false; +} + +/*static*/ thread_local std::map + ShadowContextSource::threadContextInitialized; + +/*static*/ thread_local std::map> + ShadowContextSource::threadContextStack; + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Data/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Data/CMakeLists.txt new file mode 100644 index 0000000000..4fa4282a23 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Data/CMakeLists.txt @@ -0,0 +1,6 @@ +add_proton_library(ProtonData + Data.cpp + Metric.cpp + TraceData.cpp + TreeData.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Data/Data.cpp b/third_party/mthreads/proton/csrc/lib/Data/Data.cpp new file mode 100644 index 0000000000..76fd4d8fc5 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Data/Data.cpp @@ -0,0 +1,120 @@ +#include "Data/Data.h" +#include "Utility/String.h" + +#include +#include +#include + +#include + +namespace proton { + +void Data::initPhaseStore(PhaseStoreBase &store) { + phaseStore = &store; + currentPhasePtr = phaseStore->createPtr(0); + activePhases.insert(0); +} + +size_t Data::advancePhase() { + std::unique_lock lock(mutex); + const auto nextPhase = currentPhase.load(std::memory_order_relaxed) + 1; + currentPhasePtr = phaseStore->createPtr(nextPhase); + activePhases.insert(nextPhase); + currentPhase.store(nextPhase, std::memory_order_release); + return nextPhase; +} + +void Data::clear(size_t phase, bool clearUpToPhase) { + // No locking needed. + // If phase == currentPhase, we expect users to call clear right after + // deactivating the profiler, without any GPU events in between. + // If phase < currentPhase, clearing a past phase is safe without locks. + if (clearUpToPhase) + phaseStore->clearUpToInclusive(phase); + else + phaseStore->clearPhase(phase); + + std::unique_lock lock(mutex); + if (clearUpToPhase) { + for (auto it = activePhases.begin(); it != activePhases.end();) { + if (*it <= phase) { + it = activePhases.erase(it); + } else { + ++it; + } + } + } else { + activePhases.erase(phase); + } + + // In case the current phase is cleared, recreate its pointer. + const auto phaseToRecreate = currentPhase.load(std::memory_order_relaxed); + currentPhasePtr = phaseStore->createPtr(phaseToRecreate); + activePhases.insert(phaseToRecreate); +} + +void Data::completePhase(size_t phase) { + std::unique_lock lock(mutex); + if (completeUpToPhase == kNoCompletePhase || phase > completeUpToPhase) + completeUpToPhase = phase; +} + +Data::PhaseInfo Data::getPhaseInfo() const { + std::shared_lock lock(mutex); + return PhaseInfo{currentPhase.load(std::memory_order_relaxed), + completeUpToPhase}; +} + +void Data::dump(const std::string &outputFormat) { + std::shared_lock lock(mutex); + + OutputFormat outputFormatEnum = outputFormat.empty() + ? getDefaultOutputFormat() + : parseOutputFormat(outputFormat); + + for (auto phase : activePhases) { + std::unique_ptr out; + if (path.empty() || path == "-") { + out.reset(new std::ostream(std::cout.rdbuf())); // Redirecting to cout + } else { + auto suffix = currentPhase.load(std::memory_order_relaxed) == 0 + ? "" + : ".part_" + std::to_string(phase); + const auto filePath = + path + suffix + "." + outputFormatToString(outputFormatEnum); + const auto fileMode = + (outputFormatEnum == OutputFormat::HatchetMsgPack) + ? (std::ios::out | std::ios::binary | std::ios::trunc) + : (std::ios::out | std::ios::trunc); + out.reset( + new std::ofstream(filePath, fileMode)); // Opening a file for output + } + doDump(*out, outputFormatEnum, phase); + } +} + +OutputFormat parseOutputFormat(const std::string &outputFormat) { + if (toLower(outputFormat) == "hatchet") { + return OutputFormat::Hatchet; + } else if (toLower(outputFormat) == "hatchet_msgpack") { + return OutputFormat::HatchetMsgPack; + } else if (toLower(outputFormat) == "chrome_trace") { + return OutputFormat::ChromeTrace; + } else { + throw std::runtime_error("Unknown output format: " + outputFormat); + } +} + +const std::string outputFormatToString(OutputFormat outputFormat) { + if (outputFormat == OutputFormat::Hatchet) { + return "hatchet"; + } else if (outputFormat == OutputFormat::HatchetMsgPack) { + return "hatchet_msgpack"; + } else if (outputFormat == OutputFormat::ChromeTrace) { + return "chrome_trace"; + } + throw std::runtime_error("Unknown output format: " + + std::to_string(static_cast(outputFormat))); +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Data/Metric.cpp b/third_party/mthreads/proton/csrc/lib/Data/Metric.cpp new file mode 100644 index 0000000000..e562447558 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Data/Metric.cpp @@ -0,0 +1,199 @@ +#include "Data/Metric.h" + +#include +#include +#include + +namespace proton { + +std::map + MetricBuffer::metricDescriptors; +std::map MetricBuffer::metricNameToId; +std::shared_mutex MetricBuffer::metricDescriptorMutex; + +std::atomic MetricBuffer::metricId{0}; + +MetricBuffer::~MetricBuffer() { + for (auto &[device, buffer] : deviceBuffers) { + runtime->freeHostBuffer(buffer.hostPtr); + runtime->freeHostBuffer(reinterpret_cast(buffer.hostOffset)); + if (!mappedHostBuffer) { + runtime->freeDeviceBuffer(buffer.devicePtr); + runtime->freeDeviceBuffer(buffer.deviceOffsetPtr); + } + if (buffer.priorityStream) { + runtime->destroyStream(buffer.priorityStream); + } + } +} + +void MetricBuffer::receive( + const std::map &scalarMetrics, + const std::map &tensorMetrics, + void *tensorMetricKernel, void *scalarMetricKernel, void *stream) { + queueMetrics(tensorMetrics, tensorMetricKernel, stream); + queueMetrics(scalarMetrics, scalarMetricKernel, stream); +} + +MetricBuffer::MetricDescriptor +MetricBuffer::getOrCreateMetricDescriptor(const std::string &name, + size_t typeIndex) { + { + std::shared_lock lock(metricDescriptorMutex); + auto nameIt = metricNameToId.find(name); + if (nameIt != metricNameToId.end()) { + auto &descriptor = metricDescriptors.at(nameIt->second); + if (descriptor.typeIndex != typeIndex) { + throw std::runtime_error( + "[PROTON] MetricBuffer: type mismatch for metric " + name + + ": current=" + getTypeNameForIndex(descriptor.typeIndex) + + ", new=" + getTypeNameForIndex(typeIndex)); + } + return descriptor; + } + } + + std::unique_lock lock(metricDescriptorMutex); + // Check again in case another thread inserted while we were upgrading the + // lock + auto nameIt = metricNameToId.find(name); + if (nameIt != metricNameToId.end()) { + auto &descriptor = metricDescriptors.at(nameIt->second); + if (descriptor.typeIndex != typeIndex) { + throw std::runtime_error( + "[PROTON] MetricBuffer: type mismatch for metric " + name + + ": current=" + getTypeNameForIndex(descriptor.typeIndex) + + ", new=" + getTypeNameForIndex(typeIndex)); + } + return descriptor; + } + + auto newMetricId = metricId.fetch_add(1); + MetricDescriptor descriptor{newMetricId, typeIndex, name}; + metricDescriptors.emplace(newMetricId, descriptor); + metricNameToId.emplace(name, newMetricId); + return descriptor; +} + +std::map +collectTensorMetrics(Runtime *runtime, + const std::map &tensorMetrics, + void *stream) { + std::map tensorMetricsHost; + for (auto &[name, tensorMetric] : tensorMetrics) { + uint64_t metricBits = 0; + runtime->copyDeviceToHostAsync(&metricBits, tensorMetric.ptr, + sizeof(uint64_t), stream); + runtime->synchronizeStream(stream); + if (tensorMetric.index == variant_index_v) { + double value = 0.0; + std::memcpy(&value, &metricBits, sizeof(value)); + tensorMetricsHost[name] = value; + } else if (tensorMetric.index == + variant_index_v) { + int64_t value = 0; + std::memcpy(&value, &metricBits, sizeof(value)); + tensorMetricsHost[name] = value; + } + } + return tensorMetricsHost; +} + +void MetricBuffer::queue(size_t metricId, TensorMetric tensorMetric, + void *kernel, void *stream) { + auto &buffer = getOrCreateBuffer(); + uint64_t size = capacity / sizeof(uint64_t); + void *globalScratchPtr = nullptr; + void *profileScratchPtr = nullptr; + void *kernelParams[] = {reinterpret_cast(&buffer.devicePtr), + reinterpret_cast(&buffer.deviceOffsetPtr), + reinterpret_cast(&size), + reinterpret_cast(&metricId), + reinterpret_cast(&tensorMetric.ptr), + reinterpret_cast(&globalScratchPtr), + reinterpret_cast(&profileScratchPtr)}; + runtime->launchKernel(kernel, 1, 1, 1, 32, 1, 1, 0, stream, kernelParams, + nullptr); +} + +void MetricBuffer::queue(size_t metricId, MetricValueType scalarMetric, + void *kernel, void *stream) { + auto &buffer = getOrCreateBuffer(); + uint64_t size = capacity / sizeof(uint64_t); + uint64_t metricBits = std::visit( + [](auto &&value) -> uint64_t { + using T = std::decay_t; + if constexpr (std::is_same_v) { + throw std::runtime_error( + "[PROTON] String metrics are not supported in MetricBuffer"); + } else { + static_assert(sizeof(T) == sizeof(uint64_t), + "MetricValueType alternative must be 8 bytes"); + uint64_t bits = 0; + std::memcpy(&bits, &value, sizeof(bits)); + return bits; + } + }, + scalarMetric); + void *globalScratchPtr = nullptr; + void *profileScratchPtr = nullptr; + void *kernelParams[] = {reinterpret_cast(&buffer.devicePtr), + reinterpret_cast(&buffer.deviceOffsetPtr), + reinterpret_cast(&size), + reinterpret_cast(&metricId), + reinterpret_cast(&metricBits), + reinterpret_cast(&globalScratchPtr), + reinterpret_cast(&profileScratchPtr)}; + runtime->launchKernel(kernel, 1, 1, 1, 32, 1, 1, 0, stream, kernelParams, + nullptr); +} + +void MetricBuffer::synchronize(DeviceBuffer &buffer) { + runtime->synchronizeDevice(); + if (mappedHostBuffer) { + // Buffer lives in mapped host memory; avoid treating mapped pointers as + // device allocations (e.g. cuMemcpyDtoH / cuMemset) which can error. + return; + } + runtime->copyDeviceToHostAsync(buffer.hostPtr, buffer.devicePtr, capacity, + buffer.priorityStream); + runtime->copyDeviceToHostAsync(buffer.hostOffset, buffer.deviceOffsetPtr, + sizeof(uint64_t), buffer.priorityStream); + runtime->memset(buffer.deviceOffsetPtr, 0, sizeof(uint64_t), + buffer.priorityStream); + runtime->synchronizeStream(buffer.priorityStream); // Ensure memset is done +} + +MetricBuffer::DeviceBuffer &MetricBuffer::getOrCreateBuffer() { + std::lock_guard lock(bufferMutex); + auto device = runtime->getDevice(); + if (deviceBuffers.find(device) == deviceBuffers.end()) { + deviceBuffers[device] = DeviceBuffer{}; + auto &buffer = deviceBuffers.at(device); + if (mappedHostBuffer) { + runtime->allocateHostBuffer(&buffer.hostPtr, capacity, /*mapped=*/true); + runtime->getHostDevicePointer(buffer.hostPtr, &buffer.devicePtr); + runtime->allocateHostBuffer( + reinterpret_cast(&buffer.hostOffset), sizeof(uint64_t), + /*mapped=*/true); + runtime->getHostDevicePointer( + reinterpret_cast(buffer.hostOffset), + &buffer.deviceOffsetPtr); + *buffer.hostOffset = 0; + } else { + runtime->allocateDeviceBuffer(&buffer.devicePtr, capacity); + runtime->allocateDeviceBuffer(&buffer.deviceOffsetPtr, sizeof(uint64_t)); + runtime->allocateHostBuffer(&buffer.hostPtr, capacity, /*mapped=*/false); + runtime->allocateHostBuffer( + reinterpret_cast(&buffer.hostOffset), sizeof(uint64_t), + /*mapped=*/false); + buffer.priorityStream = runtime->getPriorityStream(); + runtime->memset(buffer.deviceOffsetPtr, 0, sizeof(uint64_t), + buffer.priorityStream); + runtime->synchronizeStream(buffer.priorityStream); + } + } + return deviceBuffers.at(device); +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Data/TraceData.cpp b/third_party/mthreads/proton/csrc/lib/Data/TraceData.cpp new file mode 100644 index 0000000000..82c9034b9a --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Data/TraceData.cpp @@ -0,0 +1,515 @@ +#include "Data/TraceData.h" +#include "TraceDataIO/TraceWriter.h" +#include "Utility/MsgPackWriter.h" +#include "nlohmann/json.hpp" + +#include +#include +#include +#include + +using json = nlohmann::json; + +namespace proton { + +class TraceData::Trace { +public: + struct TraceContext : public Context { + inline static const size_t RootId = 0; + inline static const size_t DummyId = std::numeric_limits::max(); + + TraceContext() = default; + explicit TraceContext(size_t id, const std::string &name) + : id(id), Context(name) {} + TraceContext(size_t id, size_t parentId, const std::string &name) + : id(id), parentId(parentId), Context(name) {} + virtual ~TraceContext() = default; + + void addChild(const Context &context, size_t id) { children[context] = id; } + + bool hasChild(const Context &context) const { + return children.find(context) != children.end(); + } + + size_t getChild(const Context &context) const { + return children.at(context); + } + + size_t getParent() const { return parentId; } + + size_t parentId = DummyId; + size_t id = DummyId; + std::map children = {}; + friend class Trace; + }; + + struct TraceEvent { + TraceEvent() = default; + TraceEvent(size_t id, size_t contextId) : id(id), contextId(contextId) {} + size_t id = 0; + size_t scopeId = Scope::DummyScopeId; + size_t contextId = TraceContext::DummyId; + std::map> metrics = {}; + std::map flexibleMetrics = {}; + + const static inline size_t DummyId = std::numeric_limits::max(); + }; + + Trace() { + traceContextMap.try_emplace(TraceContext::RootId, TraceContext::RootId, + "ROOT"); + } + + size_t addContext(const Context &context, size_t parentId) { + if (traceContextMap[parentId].hasChild(context)) { + return traceContextMap[parentId].getChild(context); + } + auto id = nextTreeContextId++; + traceContextMap.try_emplace(id, id, parentId, context.name); + traceContextMap[parentId].addChild(context, id); + return id; + } + + size_t addContexts(const std::vector &contexts, size_t parentId) { + for (const auto &context : contexts) { + parentId = addContext(context, parentId); + } + return parentId; + } + + size_t addContexts(const std::vector &indices) { + auto parentId = TraceContext::RootId; + for (auto index : indices) { + parentId = addContext(index, parentId); + } + return parentId; + } + + std::vector getContexts(size_t contextId) { + std::vector contexts; + auto it = traceContextMap.find(contextId); + if (it == traceContextMap.end()) { + throw std::runtime_error("Context not found"); + } + std::reference_wrapper context = it->second; + contexts.push_back(context.get()); + while (context.get().parentId != TraceContext::DummyId) { + context = traceContextMap[context.get().parentId]; + contexts.push_back(context.get()); + } + std::reverse(contexts.begin(), contexts.end()); + return contexts; + } + + size_t addEvent(size_t contextId) { + traceEvents.emplace(nextEventId, TraceEvent(nextEventId, contextId)); + return nextEventId++; + } + + bool hasEvent(size_t eventId) { + return traceEvents.find(eventId) != traceEvents.end(); + } + + TraceEvent &getEvent(size_t eventId) { + auto it = traceEvents.find(eventId); + if (it == traceEvents.end()) { + throw std::runtime_error("Event not found"); + } + return it->second; + } + + void removeEvent(size_t eventId) { traceEvents.erase(eventId); } + + const std::map &getEvents() const { return traceEvents; } + +private: + size_t nextTreeContextId = TraceContext::RootId + 1; + size_t nextEventId = 0; + std::map traceEvents; + // tree node id -> trace context + std::map traceContextMap; +}; + +void TraceData::enterScope(const Scope &scope) { + // enterOp and addMetric maybe called from different threads + std::unique_lock lock(mutex); + auto *currentTrace = currentPhasePtrAs(); + std::vector contexts; + if (contextSource != nullptr) + contexts = contextSource->getContexts(); + else + contexts.push_back(scope.name); + auto eventId = currentTrace->addEvent(currentTrace->addContexts(contexts)); + scopeIdToEventId[scope.scopeId] = eventId; +} + +void TraceData::exitScope(const Scope &scope) { + scopeIdToEventId.erase(scope.scopeId); +} + +DataEntry TraceData::addOp(const std::string &name) { + std::unique_lock lock(mutex); + auto *currentTrace = currentPhasePtrAs(); + std::vector contexts; + contexts = contextSource->getContexts(); + if (!name.empty()) // not a placeholder event + contexts.emplace_back(name); + auto contextId = currentTrace->addContexts(contexts); + auto eventId = currentTrace->addEvent(contextId); + auto &event = currentTrace->getEvent(eventId); + return DataEntry(eventId, currentPhase.load(std::memory_order_relaxed), + event.metrics); +} + +DataEntry TraceData::addOp(size_t phase, size_t eventId, + const std::vector &contexts) { + auto lock = lockIfCurrentPhase(phase); + auto *trace = phasePtrAs(phase); + // Add a new context under it and update the context + auto &event = trace->getEvent(eventId); + auto contextId = trace->addContexts(contexts, event.contextId); + auto newEventId = trace->addEvent(contextId); + auto &newEvent = trace->getEvent(newEventId); + return DataEntry(newEventId, phase, newEvent.metrics); +} + +void TraceData::addMetrics( + size_t phase, size_t eventId, + const std::map &metrics) { + auto lock = lockIfCurrentPhase(phase); + auto *trace = phasePtrAs(phase); + auto &event = trace->getEvent(eventId); + for (auto [metricName, metricValue] : metrics) { + if (event.flexibleMetrics.find(metricName) == event.flexibleMetrics.end()) { + event.flexibleMetrics.emplace(metricName, + FlexibleMetric(metricName, metricValue)); + } else { + event.flexibleMetrics.at(metricName).updateValue(metricValue); + } + } +} + +void TraceData::addMetrics( + size_t scopeId, const std::map &metrics) { + std::unique_lock lock(mutex); + auto *currentTrace = currentPhasePtrAs(); + auto eventId = scopeIdToEventId.at(scopeId); + auto &event = currentTrace->getEvent(eventId); + for (auto [metricName, metricValue] : metrics) { + if (event.flexibleMetrics.find(metricName) == event.flexibleMetrics.end()) { + event.flexibleMetrics.emplace(metricName, + FlexibleMetric(metricName, metricValue)); + } else { + event.flexibleMetrics.at(metricName).updateValue(metricValue); + } + } +} + +std::string TraceData::toJsonString(size_t phase) const { + std::ostringstream os; + dumpChromeTrace(os, phase); + return os.str(); +} + +std::vector TraceData::toMsgPack(size_t phase) const { + std::ostringstream os; + dumpChromeTrace(os, phase); + MsgPackWriter writer; + writer.packStr(os.str()); + return std::move(writer).take(); +} + +namespace { + +// Structure to pair CycleMetric with its context for processing +struct CycleMetricWithContext { + const CycleMetric *cycleMetric; + uint32_t contextId; + + CycleMetricWithContext(const CycleMetric *metric, uint32_t ctx) + : cycleMetric(metric), contextId(ctx) {} +}; + +std::vector +convertToTimelineTrace(TraceData::Trace *trace, + std::vector &cycleEvents) { + std::vector results; + + auto getInt64Value = [](const CycleMetric *metric, + CycleMetric::CycleMetricKind kind) { + return std::get(metric->getValue(kind)); + }; + + auto getStringValue = [](const CycleMetric *metric, + CycleMetric::CycleMetricKind kind) { + return std::get(metric->getValue(kind)); + }; + + auto getKernelId = [&](const CycleMetricWithContext &event) { + return getInt64Value(event.cycleMetric, CycleMetric::KernelId); + }; + + auto getBlockId = [&](const CycleMetricWithContext &event) { + return getInt64Value(event.cycleMetric, CycleMetric::BlockId); + }; + + auto getUnitId = [&](const CycleMetricWithContext &event) { + return getInt64Value(event.cycleMetric, CycleMetric::UnitId); + }; + + auto getStartCycle = [&](const CycleMetricWithContext &event) { + return getInt64Value(event.cycleMetric, CycleMetric::StartCycle); + }; + + auto getEndCycle = [&](const CycleMetricWithContext &event) { + return getInt64Value(event.cycleMetric, CycleMetric::EndCycle); + }; + + // Pre-sort all events once + auto &sortedEvents = cycleEvents; + std::sort( + sortedEvents.begin(), sortedEvents.end(), + [&](const CycleMetricWithContext &a, const CycleMetricWithContext &b) { + auto aKernelId = getKernelId(a); + auto bKernelId = getKernelId(b); + if (aKernelId != bKernelId) + return aKernelId < bKernelId; + + auto aBlockId = getBlockId(a); + auto bBlockId = getBlockId(b); + if (aBlockId != bBlockId) + return aBlockId < bBlockId; + + auto aUnitId = getUnitId(a); + auto bUnitId = getUnitId(b); + if (aUnitId != bUnitId) + return aUnitId < bUnitId; + + auto aStartCycle = getStartCycle(a); + auto bStartCycle = getStartCycle(b); + return aStartCycle < bStartCycle; + }); + + size_t eventIndex = 0; + + // Process in perfectly sorted order + while (eventIndex < sortedEvents.size()) { + auto kernelEvent = sortedEvents[eventIndex]; + auto currentKernelId = getKernelId(kernelEvent); + + auto parserResult = std::make_shared(); + auto metadata = std::make_shared(); + std::map scopeIdToName; + std::map scopeNameToId; + int curScopeId = 0; + int64_t timeShiftCost = + getInt64Value(kernelEvent.cycleMetric, CycleMetric::TimeShiftCost); + + // Process all events for current kernel + while (eventIndex < sortedEvents.size() && + getKernelId(sortedEvents[eventIndex]) == currentKernelId) { + + const auto &blockEvent = sortedEvents[eventIndex]; + uint32_t currentBlockId = getBlockId(blockEvent); + uint32_t currentProcId = + getInt64Value(blockEvent.cycleMetric, CycleMetric::ProcessorId); + + CircularLayoutParserResult::BlockTrace blockTrace; + blockTrace.blockId = currentBlockId; + blockTrace.procId = currentProcId; + blockTrace.initTime = + getInt64Value(blockEvent.cycleMetric, CycleMetric::InitTime); + blockTrace.preFinalTime = + getInt64Value(blockEvent.cycleMetric, CycleMetric::PreFinalTime); + blockTrace.postFinalTime = + getInt64Value(blockEvent.cycleMetric, CycleMetric::PostFinalTime); + // Conservative estimation of the number of warps in a CTA. + blockTrace.traces.reserve(16); + + // Process all events for current block-proc + while (eventIndex < sortedEvents.size()) { + const auto ¤tEvent = sortedEvents[eventIndex]; + if (getKernelId(currentEvent) != currentKernelId || + getBlockId(currentEvent) != currentBlockId) { + break; + } + + const auto &uintEvent = sortedEvents[eventIndex]; + uint32_t currentUid = getUnitId(uintEvent); + + CircularLayoutParserResult::Trace unitTrace; + unitTrace.uid = currentUid; + // Estimation the number of events in a unit (warp). + unitTrace.profileEvents.reserve(256); + + // Process all events for current uid + while (eventIndex < sortedEvents.size()) { + const auto &event = sortedEvents[eventIndex]; + if (getKernelId(event) != currentKernelId || + getBlockId(event) != currentBlockId || + getUnitId(event) != currentUid) { + break; + } + + auto scopeName = trace->getContexts(event.contextId).back().name; + if (scopeNameToId.count(scopeName) == 0) { + scopeIdToName[curScopeId] = scopeName; + scopeNameToId[scopeName] = curScopeId; + curScopeId++; + } + + auto startEntry = std::make_shared(); + startEntry->cycle = getStartCycle(event); + startEntry->isStart = true; + startEntry->scopeId = scopeNameToId[scopeName]; + + auto endEntry = std::make_shared(); + endEntry->cycle = getEndCycle(event); + endEntry->isStart = false; + endEntry->scopeId = scopeNameToId[scopeName]; + + unitTrace.profileEvents.emplace_back(startEntry, endEntry); + + eventIndex++; + } + blockTrace.traces.push_back(std::move(unitTrace)); + } + parserResult->blockTraces.push_back(std::move(blockTrace)); + } + std::vector callStack; + if (!sortedEvents.empty()) { + auto contexts = trace->getContexts(kernelEvent.contextId); + if (!contexts.empty()) { + callStack.resize(contexts.size() - 1); + std::transform(contexts.begin(), contexts.end() - 1, callStack.begin(), + [](const Context &c) { return c.name; }); + } + } + metadata->kernelName = + getStringValue(kernelEvent.cycleMetric, CycleMetric::KernelName); + metadata->scopeName = scopeIdToName; + metadata->callStack = std::move(callStack); + if (timeShiftCost > 0) + timeShift(timeShiftCost, parserResult); + results.emplace_back(parserResult, metadata); + } + return results; +} + +void dumpCycleMetricTrace(TraceData::Trace *trace, + std::vector &cycleEvents, + std::ostream &os) { + auto timeline = convertToTimelineTrace(trace, cycleEvents); + auto writer = StreamChromeTraceWriter(timeline, ""); + writer.write(os); +} + +void dumpKernelMetricTrace( + TraceData::Trace *trace, uint64_t minTimeStamp, + std::map> + &streamTraceEvents, + std::ostream &os) { + // for each streamId in ascending order, emit one JSON line + for (auto const &[streamId, events] : streamTraceEvents) { + json object = {{"displayTimeUnit", "us"}, {"traceEvents", json::array()}}; + + for (auto const *event : events) { + auto *kernelMetrics = static_cast( + event->metrics.at(MetricKind::Kernel).get()); + uint64_t startTimeNs = + std::get(kernelMetrics->getValue(KernelMetric::StartTime)); + uint64_t endTimeNs = + std::get(kernelMetrics->getValue(KernelMetric::EndTime)); + // Convert nanoseconds to microseconds for Chrome trace format + double ts = static_cast(startTimeNs - minTimeStamp) / 1000; + double dur = static_cast(endTimeNs - startTimeNs) / 1000; + + auto contextId = event->contextId; + auto contexts = trace->getContexts(contextId); + + json element; + element["name"] = contexts.back().name; + element["cat"] = "kernel"; + element["ph"] = "X"; + element["ts"] = ts; + element["dur"] = dur; + element["tid"] = streamId; // thread id = stream + json callStack = json::array(); + for (auto const &ctx : contexts) { + callStack.push_back(ctx.name); + } + element["args"]["call_stack"] = std::move(callStack); + + object["traceEvents"].push_back(element); + } + + // one JSON object per line + os << object.dump() << "\n"; + } +} +} // namespace + +void TraceData::dumpChromeTrace(std::ostream &os, size_t phase) const { + tracePhases.withPtr(phase, [&](Trace *trace) { + auto &events = trace->getEvents(); + // stream id -> trace event + std::map> streamTraceEvents; + uint64_t minTimeStamp = std::numeric_limits::max(); + bool hasKernelMetrics = false, hasCycleMetrics = false; + // Data structure for efficient cycle metrics conversion + std::map kernelBlockNum; + std::vector cycleEvents; + cycleEvents.reserve(events.size()); + for (auto &entry : events) { + auto &event = entry.second; + if (event.metrics.count(MetricKind::Kernel)) { + auto *kernelMetric = static_cast( + event.metrics.at(MetricKind::Kernel).get()); + auto streamId = + std::get(kernelMetric->getValue(KernelMetric::StreamId)); + streamTraceEvents[streamId].push_back(&event); + + uint64_t startTime = + std::get(kernelMetric->getValue(KernelMetric::StartTime)); + minTimeStamp = std::min(minTimeStamp, startTime); + hasKernelMetrics = true; + } + if (event.metrics.count(MetricKind::Cycle)) { + auto *cycleMetric = static_cast( + event.metrics.at(MetricKind::Cycle).get()); + cycleEvents.emplace_back(cycleMetric, event.contextId); + hasCycleMetrics = true; + } + + if (hasKernelMetrics && hasCycleMetrics) { + throw std::runtime_error("only one active metric type is supported"); + } + } + + if (hasCycleMetrics) { + dumpCycleMetricTrace(trace, cycleEvents, os); + } + + if (hasKernelMetrics) { + dumpKernelMetricTrace(trace, minTimeStamp, streamTraceEvents, os); + } + }); +} + +void TraceData::doDump(std::ostream &os, OutputFormat outputFormat, + size_t phase) const { + if (outputFormat == OutputFormat::ChromeTrace) { + dumpChromeTrace(os, phase); + } else { + throw std::logic_error("Output format not supported"); + } +} + +TraceData::TraceData(const std::string &path, ContextSource *contextSource) + : Data(path, contextSource) { + initPhaseStore(tracePhases); +} + +TraceData::~TraceData() {} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Data/TreeData.cpp b/third_party/mthreads/proton/csrc/lib/Data/TreeData.cpp new file mode 100644 index 0000000000..0f508dba8f --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Data/TreeData.cpp @@ -0,0 +1,740 @@ +#include "Data/TreeData.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Device.h" +#include "Utility/MsgPackWriter.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +namespace { + +const std::array(DeviceType::COUNT)> + kDeviceTypeNames = []() { + std::array(DeviceType::COUNT)> names; + for (size_t i = 0; i < static_cast(DeviceType::COUNT); ++i) { + names[i] = getDeviceTypeString(static_cast(i)); + } + return names; + }(); + +constexpr size_t kMaxRegisteredDeviceIds = 32; + +} // namespace + +class TreeData::Tree { +public: + struct TreeNode : public Context { + inline static const size_t RootId = 0; + inline static const size_t DummyId = std::numeric_limits::max(); + + struct ChildEntry { + std::string_view name; + size_t id = DummyId; + }; + + TreeNode() = default; + explicit TreeNode(size_t id, const std::string &name) + : id(id), Context(name) {} + TreeNode(size_t id, size_t parentId, const std::string &name) + : id(id), parentId(parentId), Context(name) {} + virtual ~TreeNode() = default; + + void addChild(std::string_view childName, size_t id) { + children.push_back({childName, id}); + childIndex.emplace(childName, id); + } + + size_t findChild(std::string_view childName) const { + auto it = childIndex.find(childName); + return it != childIndex.end() ? it->second : DummyId; + } + + size_t parentId = DummyId; + size_t id = DummyId; + std::vector children = {}; + std::unordered_map childIndex = {}; + std::map> metrics = {}; + std::map flexibleMetrics = {}; + friend class Tree; + }; + + Tree() { + treeNodeMap.try_emplace(TreeNode::RootId, TreeNode::RootId, + TreeNode::RootId, "ROOT"); + } + + size_t addNode(const std::vector &contexts, size_t parentId) { + for (const auto &context : contexts) { + parentId = addNode(context, parentId); + } + return parentId; + } + + size_t addNode(const Context &context, size_t parentId) { + auto &parent = treeNodeMap.at(parentId); + std::string_view contextName = context.name; + auto existingChildId = parent.findChild(contextName); + if (existingChildId != TreeNode::DummyId) + return existingChildId; + auto id = nextContextId++; + auto [it, inserted] = + treeNodeMap.try_emplace(id, id, parentId, context.name); + parent.addChild(it->second.name, id); + return id; + } + + size_t addNode(const std::vector &indices) { + auto parentId = TreeNode::RootId; + for (auto index : indices) { + parentId = addNode(index, parentId); + } + return parentId; + } + + TreeNode &getNode(size_t id) { return treeNodeMap.at(id); } + + void upsertFlexibleMetric(size_t contextId, + const FlexibleMetric &flexibleMetric) { + auto &node = treeNodeMap.at(contextId); + auto it = node.flexibleMetrics.find(flexibleMetric.getValueName(0)); + if (it == node.flexibleMetrics.end()) { + node.flexibleMetrics.emplace(flexibleMetric.getValueName(0), + flexibleMetric); + } else { + it->second.updateMetric(flexibleMetric); + } + } + + enum class WalkPolicy { PreOrder, PostOrder }; + + template void walk(FnT &&fn) { + if constexpr (walkPolicy == WalkPolicy::PreOrder) { + walkPreOrder(TreeNode::RootId, fn); + } else if constexpr (walkPolicy == WalkPolicy::PostOrder) { + walkPostOrder(TreeNode::RootId, fn); + } + } + + template void walkPreOrder(size_t contextId, FnT &&fn) { + fn(getNode(contextId)); + for (const auto &child : getNode(contextId).children) { + walkPreOrder(child.id, fn); + } + } + + template void walkPostOrder(size_t contextId, FnT &&fn) { + for (const auto &child : getNode(contextId).children) { + walkPostOrder(child.id, fn); + } + fn(getNode(contextId)); + } + + size_t size() const { return nextContextId; } + +private: + size_t nextContextId = TreeNode::RootId + 1; + // tree node id -> tree node + std::unordered_map treeNodeMap; +}; + +json TreeData::buildHatchetJson(TreeData::Tree *tree) const { + std::vector jsonNodes(tree->size(), nullptr); + json output = json::array(); + output.push_back(json::object()); + jsonNodes[TreeData::Tree::TreeNode::RootId] = &(output.back()); + bool hasKernelMetric = false; + bool hasPCSamplingMetric = false; + bool hasCycleMetric = false; + std::array(DeviceType::COUNT)> deviceIdMasks{}; + tree->template walk( + [&](TreeData::Tree::TreeNode &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + auto &metricsJson = (*jsonNode)["metrics"]; + for (auto &[metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + hasKernelMetric = true; + auto *kernelMetric = static_cast(metric.get()); + uint64_t duration = std::get( + kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = std::get( + kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + if (deviceId < kMaxRegisteredDeviceIds) { + deviceIdMasks[static_cast(deviceType)] |= + (1u << static_cast(deviceId)); + } else { + throw std::runtime_error( + "[PROTON] DeviceId " + std::to_string(deviceId) + + " exceeds MaxRegisteredDeviceIds " + + std::to_string(kMaxRegisteredDeviceIds) + " for deviceType " + + std::to_string(deviceType)); + } + const auto &deviceTypeName = + kDeviceTypeNames[static_cast(deviceType)]; + const auto &durationName = + kernelMetric->getValueName(KernelMetric::Duration); + const auto &invocationsName = + kernelMetric->getValueName(KernelMetric::Invocations); + const auto &deviceIdName = + kernelMetric->getValueName(KernelMetric::DeviceId); + const auto &deviceTypeNameKey = + kernelMetric->getValueName(KernelMetric::DeviceType); + const auto deviceIdStr = std::to_string(deviceId); + + metricsJson[durationName] = duration; + metricsJson[invocationsName] = invocations; + metricsJson[deviceIdName] = deviceIdStr; + metricsJson[deviceTypeNameKey] = deviceTypeName; + } else if (metricKind == MetricKind::PCSampling) { + hasPCSamplingMetric = true; + auto *pcSamplingMetric = + static_cast(metric.get()); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + const auto &valueName = pcSamplingMetric->getValueName(i); + std::visit([&](auto &&value) { metricsJson[valueName] = value; }, + pcSamplingMetric->getValues()[i]); + } + } else if (metricKind == MetricKind::Cycle) { + hasCycleMetric = true; + auto *cycleMetric = static_cast(metric.get()); + uint64_t duration = std::get( + cycleMetric->getValue(CycleMetric::Duration)); + double normalizedDuration = std::get( + cycleMetric->getValue(CycleMetric::NormalizedDuration)); + uint64_t deviceId = std::get( + cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceType = std::get( + cycleMetric->getValue(CycleMetric::DeviceType)); + if (deviceId < kMaxRegisteredDeviceIds) { + deviceIdMasks[static_cast(deviceType)] |= + (1u << static_cast(deviceId)); + } else { + throw std::runtime_error( + "[PROTON] DeviceId " + std::to_string(deviceId) + + " exceeds MaxRegisteredDeviceIds " + + std::to_string(kMaxRegisteredDeviceIds) + " for deviceType " + + std::to_string(deviceType)); + } + const auto &durationName = + cycleMetric->getValueName(CycleMetric::Duration); + const auto &normalizedDurationName = + cycleMetric->getValueName(CycleMetric::NormalizedDuration); + const auto &deviceIdName = + cycleMetric->getValueName(CycleMetric::DeviceId); + const auto &deviceTypeName = + cycleMetric->getValueName(CycleMetric::DeviceType); + const auto deviceIdStr = std::to_string(deviceId); + const auto deviceTypeStr = std::to_string(deviceType); + + metricsJson[durationName] = duration; + metricsJson[normalizedDurationName] = normalizedDuration; + metricsJson[deviceIdName] = deviceIdStr; + metricsJson[deviceTypeName] = deviceTypeStr; + } else if (metricKind == MetricKind::Flexible) { + // Flexible metrics are handled in a different way + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + for (auto &[_, flexibleMetric] : treeNode.flexibleMetrics) { + const auto &valueName = flexibleMetric.getValueName(0); + std::visit([&](auto &&value) { metricsJson[valueName] = value; }, + flexibleMetric.getValues()[0]); + } + auto &childrenArray = (*jsonNode)["children"]; + childrenArray = json::array(); + childrenArray.get_ref().reserve( + treeNode.children.size()); + for (const auto &child : treeNode.children) { + childrenArray.push_back(json::object()); + jsonNodes[child.id] = &childrenArray.back(); + } + }); + + if (hasKernelMetric) { + KernelMetric kernelMetric; + output[TreeData::Tree::TreeNode::RootId]["metrics"] + [kernelMetric.getValueName(KernelMetric::Invocations)] = 0; + output[TreeData::Tree::TreeNode::RootId]["metrics"] + [kernelMetric.getValueName(KernelMetric::Duration)] = 0; + } + if (hasCycleMetric) { + CycleMetric cycleMetric; + output[TreeData::Tree::TreeNode::RootId]["metrics"] + [cycleMetric.getValueName(CycleMetric::Duration)] = 0; + output[TreeData::Tree::TreeNode::RootId]["metrics"] + [cycleMetric.getValueName(CycleMetric::NormalizedDuration)] = 0; + } + if (hasPCSamplingMetric) { + PCSamplingMetric pcSamplingMetric; + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + const auto &valueName = pcSamplingMetric.getValueName(i); + output[TreeData::Tree::TreeNode::RootId]["metrics"][valueName] = 0; + } + } + + output.push_back(json::object()); + auto &deviceJson = output.back(); + for (size_t deviceType = 0; + deviceType < static_cast(DeviceType::COUNT); ++deviceType) { + auto mask = deviceIdMasks[deviceType]; + if (mask == 0) { + continue; + } + + const auto &deviceTypeName = kDeviceTypeNames[deviceType]; + deviceJson[deviceTypeName] = json::object(); + + for (uint64_t deviceId = 0; deviceId < kMaxRegisteredDeviceIds; + ++deviceId) { + if ((mask & (1u << static_cast(deviceId))) == 0) { + continue; + } + Device device = getDevice(static_cast(deviceType), deviceId); + deviceJson[deviceTypeName][std::to_string(deviceId)] = { + {"clock_rate", device.clockRate}, + {"memory_clock_rate", device.memoryClockRate}, + {"bus_width", device.busWidth}, + {"arch", device.arch}, + {"num_sms", device.numSms}}; + } + + if (deviceJson[deviceTypeName].empty()) { + deviceJson.erase(deviceTypeName); + } + } + return output; +} + +std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { + MsgPackWriter writer; + writer.reserve(16 * 1024 * 1024); // 16 MB + + bool hasKernelMetric = false; + bool hasPCSamplingMetric = false; + bool hasCycleMetric = false; + std::array(DeviceType::COUNT)> deviceIdMasks{}; + + auto updateDeviceIdMask = [&](uint64_t deviceType, uint64_t deviceId) { + if (deviceId < kMaxRegisteredDeviceIds) { + deviceIdMasks[static_cast(deviceType)] |= + (1u << static_cast(deviceId)); + } else { + throw std::runtime_error("[PROTON] DeviceId " + std::to_string(deviceId) + + " exceeds MaxRegisteredDeviceIds " + + std::to_string(kMaxRegisteredDeviceIds) + + " for deviceType " + std::to_string(deviceType)); + } + }; + + tree->template walk( + [&](TreeData::Tree::TreeNode &treeNode) { + for (auto &[metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + hasKernelMetric = true; + auto *kernelMetric = static_cast(metric.get()); + uint64_t deviceId = std::get( + kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + updateDeviceIdMask(deviceType, deviceId); + } else if (metricKind == MetricKind::PCSampling) { + hasPCSamplingMetric = true; + } else if (metricKind == MetricKind::Cycle) { + hasCycleMetric = true; + auto *cycleMetric = static_cast(metric.get()); + uint64_t deviceId = std::get( + cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceType = std::get( + cycleMetric->getValue(CycleMetric::DeviceType)); + updateDeviceIdMask(deviceType, deviceId); + } + } + }); + + // We only need these metrics for tree data + KernelMetric kernelMetric; + auto &kernelMetricDurationName = + kernelMetric.getValueName(KernelMetric::Duration); + auto &kernelMetricInvocationsName = + kernelMetric.getValueName(KernelMetric::Invocations); + auto &kernelMetricDeviceIdName = + kernelMetric.getValueName(KernelMetric::DeviceId); + auto &kernelMetricDeviceTypeName = + kernelMetric.getValueName(KernelMetric::DeviceType); + CycleMetric cycleMetric; + auto &cycleMetricDurationName = + cycleMetric.getValueName(CycleMetric::Duration); + auto &cycleMetricNormalizedDurationName = + cycleMetric.getValueName(CycleMetric::NormalizedDuration); + auto &cycleMetricDeviceIdName = + cycleMetric.getValueName(CycleMetric::DeviceId); + auto &cycleMetricDeviceTypeName = + cycleMetric.getValueName(CycleMetric::DeviceType); + std::set kernelInclusiveValueNames = { + kernelMetricDurationName, kernelMetricInvocationsName}; + std::set kernelExclusiveValueNames = { + kernelMetricDeviceIdName, kernelMetricDeviceTypeName}; + std::set cycleInclusiveValueNames = { + cycleMetricDurationName, cycleMetricNormalizedDurationName}; + std::set cycleExclusiveValueNames = {cycleMetricDeviceIdName, + cycleMetricDeviceTypeName}; + std::function packNode = + [&](TreeData::Tree::TreeNode &treeNode) { + writer.packMap(3); + + writer.packStr("frame"); + writer.packMap(2); + writer.packStr("name"); + writer.packStr(treeNode.name); + writer.packStr("type"); + writer.packStr("function"); + + writer.packStr("metrics"); + uint32_t metricEntries = 0; + for (auto &[metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + metricEntries += (treeNode.id == TreeData::Tree::TreeNode::RootId) + ? kernelInclusiveValueNames.size() + : (kernelInclusiveValueNames.size() + + kernelExclusiveValueNames.size()); + } else if (metricKind == MetricKind::PCSampling) { + metricEntries += PCSamplingMetric::Count; + } else if (metricKind == MetricKind::Cycle) { + metricEntries += (treeNode.id == TreeData::Tree::TreeNode::RootId) + ? cycleInclusiveValueNames.size() + : (cycleInclusiveValueNames.size() + + cycleExclusiveValueNames.size()); + } + } + if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + if (hasKernelMetric && treeNode.metrics.find(MetricKind::Kernel) == + treeNode.metrics.end()) { + metricEntries += + static_cast(kernelInclusiveValueNames.size()); + } + if (hasPCSamplingMetric && + treeNode.metrics.find(MetricKind::PCSampling) == + treeNode.metrics.end()) { + metricEntries += PCSamplingMetric::Count; + } + if (hasCycleMetric && treeNode.metrics.find(MetricKind::Cycle) == + treeNode.metrics.end()) { + metricEntries += + static_cast(cycleInclusiveValueNames.size()); + } + } + metricEntries += static_cast(treeNode.flexibleMetrics.size()); + writer.packMap(metricEntries); + + for (auto &[metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + writer.packStr(kernelMetricDurationName); + writer.packUInt(0); + writer.packStr(kernelMetricInvocationsName); + writer.packUInt(0); + continue; + } + + auto *kernelMetric = static_cast(metric.get()); + uint64_t duration = std::get( + kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = std::get( + kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + const auto &deviceTypeName = + kDeviceTypeNames[static_cast(deviceType)]; + writer.packStr(kernelMetricDurationName); + writer.packUInt(duration); + writer.packStr(kernelMetricInvocationsName); + writer.packUInt(invocations); + writer.packStr(kernelMetricDeviceIdName); + writer.packStr(std::to_string(deviceId)); + writer.packStr(kernelMetricDeviceTypeName); + writer.packStr(deviceTypeName); + } else if (metricKind == MetricKind::PCSampling) { + auto *pcSamplingMetric = + static_cast(metric.get()); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + const auto &valueName = pcSamplingMetric->getValueName(i); + writer.packStr(valueName); + if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + writer.packUInt(0); + } else { + writer.packUInt( + std::get(pcSamplingMetric->getValues()[i])); + } + } + } else if (metricKind == MetricKind::Cycle) { + if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + writer.packStr(cycleMetricDurationName); + writer.packUInt(0); + writer.packStr(cycleMetricNormalizedDurationName); + writer.packUInt(0); + continue; + } + + auto *cycleMetric = static_cast(metric.get()); + uint64_t duration = std::get( + cycleMetric->getValue(CycleMetric::Duration)); + double normalizedDuration = std::get( + cycleMetric->getValue(CycleMetric::NormalizedDuration)); + uint64_t deviceId = std::get( + cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceType = std::get( + cycleMetric->getValue(CycleMetric::DeviceType)); + + writer.packStr(cycleMetricDurationName); + writer.packUInt(duration); + writer.packStr(cycleMetricNormalizedDurationName); + writer.packDouble(normalizedDuration); + writer.packStr(cycleMetricDeviceIdName); + writer.packStr(std::to_string(deviceId)); + writer.packStr(cycleMetricDeviceTypeName); + writer.packStr(std::to_string(deviceType)); + } + } + + for (auto &[_, flexibleMetric] : treeNode.flexibleMetrics) { + const auto &valueName = flexibleMetric.getValueName(0); + writer.packStr(valueName); + std::visit( + [&](auto &&v) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + writer.packUInt(v); + } else if constexpr (std::is_same_v) { + writer.packInt(v); + } else if constexpr (std::is_same_v) { + writer.packDouble(v); + } else if constexpr (std::is_same_v) { + writer.packStr(v); + } else { + static_assert(sizeof(T) == 0, "Unsupported MetricValueType"); + } + }, + flexibleMetric.getValues()[0]); + } + + if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + if (hasKernelMetric && treeNode.metrics.find(MetricKind::Kernel) == + treeNode.metrics.end()) { + writer.packStr(kernelMetricDurationName); + writer.packUInt(0); + writer.packStr(kernelMetricInvocationsName); + writer.packUInt(0); + } + if (hasPCSamplingMetric && + treeNode.metrics.find(MetricKind::PCSampling) == + treeNode.metrics.end()) { + PCSamplingMetric pcSamplingMetric; + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + const auto &valueName = pcSamplingMetric.getValueName(i); + writer.packStr(valueName); + writer.packUInt(0); + } + } + if (hasCycleMetric && treeNode.metrics.find(MetricKind::Cycle) == + treeNode.metrics.end()) { + writer.packStr(cycleMetricDurationName); + writer.packUInt(0); + writer.packStr(cycleMetricNormalizedDurationName); + writer.packUInt(0); + } + } + + writer.packStr("children"); + writer.packArray(static_cast(treeNode.children.size())); + for (const auto &child : treeNode.children) { + packNode(tree->getNode(child.id)); + } + }; + + uint32_t deviceTypeEntries = 0; + for (size_t deviceType = 0; + deviceType < static_cast(DeviceType::COUNT); ++deviceType) { + if (deviceIdMasks[deviceType] != 0) { + ++deviceTypeEntries; + } + } + // Hatchet format: [tree, device_metadata]. Always emit 2 elements to match + // the JSON serializer, even if device_metadata is empty. + writer.packArray(2); + packNode(tree->getNode(TreeData::Tree::TreeNode::RootId)); + + auto countSetBits = [](uint32_t mask) -> uint32_t { + uint32_t count = 0; + while (mask) { + mask &= (mask - 1); + ++count; + } + return count; + }; + + writer.packMap(deviceTypeEntries); + for (size_t deviceType = 0; + deviceType < static_cast(DeviceType::COUNT); ++deviceType) { + auto mask = deviceIdMasks[deviceType]; + if (mask == 0) { + continue; + } + + const auto &deviceTypeName = kDeviceTypeNames[deviceType]; + writer.packStr(deviceTypeName); + + writer.packMap(countSetBits(mask)); + for (uint64_t deviceId = 0; deviceId < kMaxRegisteredDeviceIds; + ++deviceId) { + if ((mask & (1u << static_cast(deviceId))) == 0) { + continue; + } + Device device = getDevice(static_cast(deviceType), deviceId); + writer.packStr(std::to_string(deviceId)); + writer.packMap(5); + writer.packStr("clock_rate"); + writer.packUInt(device.clockRate); + writer.packStr("memory_clock_rate"); + writer.packUInt(device.memoryClockRate); + writer.packStr("bus_width"); + writer.packUInt(device.busWidth); + writer.packStr("arch"); + writer.packStr(device.arch); + writer.packStr("num_sms"); + writer.packUInt(device.numSms); + } + } + + return std::move(writer).take(); +} + +void TreeData::enterScope(const Scope &scope) { + // enterOp and addMetric maybe called from different threads + std::unique_lock lock(mutex); + auto *currentTree = currentPhasePtrAs(); + std::vector contexts; + if (contextSource != nullptr) + contexts = contextSource->getContexts(); + else + contexts.push_back(scope.name); + auto contextId = currentTree->addNode(contexts); + scopeIdToContextId[scope.scopeId] = contextId; +} + +void TreeData::exitScope(const Scope &scope) { + std::unique_lock lock(mutex); + scopeIdToContextId.erase(scope.scopeId); +} + +DataEntry TreeData::addOp(const std::string &name) { + std::unique_lock lock(mutex); + auto *currentTree = currentPhasePtrAs(); + std::vector contexts; + if (contextSource != nullptr) + contexts = contextSource->getContexts(); + if (!name.empty()) + contexts.emplace_back(name); + auto contextId = currentTree->addNode(contexts); + auto &node = currentTree->getNode(contextId); + return DataEntry(contextId, currentPhase.load(std::memory_order_relaxed), + node.metrics); +} + +DataEntry TreeData::addOp(size_t phase, size_t contextId, + const std::vector &contexts) { + auto lock = lockIfCurrentPhase(phase); + auto *tree = phasePtrAs(phase); + auto newContextId = tree->addNode(contexts, contextId); + auto &node = tree->getNode(newContextId); + return DataEntry(newContextId, phase, node.metrics); +} + +void TreeData::addMetrics( + size_t scopeId, const std::map &metrics) { + std::unique_lock lock(mutex); + auto *currentTree = currentPhasePtrAs(); + auto contextId = scopeIdToContextId.at(scopeId); + for (auto [metricName, metricValue] : metrics) { + currentTree->upsertFlexibleMetric(contextId, + FlexibleMetric(metricName, metricValue)); + } +} + +void TreeData::addMetrics( + size_t phase, size_t contextId, + const std::map &metrics) { + auto lock = lockIfCurrentPhase(phase); + auto *tree = phasePtrAs(phase); + for (auto [metricName, metricValue] : metrics) { + tree->upsertFlexibleMetric(contextId, + FlexibleMetric(metricName, metricValue)); + } +} + +void TreeData::dumpHatchet(std::ostream &os, size_t phase) const { + treePhases.withPtr(phase, [&](Tree *tree) { + auto output = buildHatchetJson(tree); + os << std::endl << output.dump(4) << std::endl; + }); +} + +void TreeData::dumpHatchetMsgPack(std::ostream &os, size_t phase) const { + treePhases.withPtr(phase, [&](Tree *tree) { + auto msgPack = buildHatchetMsgPack(tree); + os.write(reinterpret_cast(msgPack.data()), + static_cast(msgPack.size())); + }); +} + +std::string TreeData::toJsonString(size_t phase) const { + return treePhases.withPtr( + phase, [&](Tree *tree) { return buildHatchetJson(tree).dump(); }); +} + +std::vector TreeData::toMsgPack(size_t phase) const { + return treePhases.withPtr( + phase, [&](Tree *tree) { return buildHatchetMsgPack(tree); }); +} + +void TreeData::doDump(std::ostream &os, OutputFormat outputFormat, + size_t phase) const { + if (outputFormat == OutputFormat::Hatchet) { + dumpHatchet(os, phase); + } else if (outputFormat == OutputFormat::HatchetMsgPack) { + dumpHatchetMsgPack(os, phase); + } else { + throw std::logic_error("Output format not supported"); + } +} + +TreeData::TreeData(const std::string &path, ContextSource *contextSource) + : Data(path, contextSource) { + initPhaseStore(treePhases); +} + +TreeData::~TreeData() {} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Driver/CMakeLists.txt new file mode 100644 index 0000000000..438f24f49e --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/CMakeLists.txt @@ -0,0 +1,9 @@ +add_proton_library(ProtonDriver + Device.cpp + GPU/CudaApi.cpp + GPU/CuptiApi.cpp + GPU/HipApi.cpp + GPU/HsaApi.cpp + GPU/RoctracerApi.cpp + GPU/NvtxApi.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Driver/Device.cpp b/third_party/mthreads/proton/csrc/lib/Driver/Device.cpp new file mode 100644 index 0000000000..24f4d16126 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/Device.cpp @@ -0,0 +1,28 @@ +#include "Device.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/HipApi.h" + +#include "Utility/Errors.h" + +namespace proton { + +Device getDevice(DeviceType type, uint64_t index) { + if (type == DeviceType::CUDA) { + return cuda::getDevice(index); + } + if (type == DeviceType::HIP) { + return hip::getDevice(index); + } + throw std::runtime_error("DeviceType not supported"); +} + +const std::string getDeviceTypeString(DeviceType type) { + if (type == DeviceType::CUDA) { + return DeviceTraits::name; + } else if (type == DeviceType::HIP) { + return DeviceTraits::name; + } + throw std::runtime_error("DeviceType not supported"); +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/GPU/CudaApi.cpp b/third_party/mthreads/proton/csrc/lib/Driver/GPU/CudaApi.cpp new file mode 100644 index 0000000000..f1f57c9b54 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/GPU/CudaApi.cpp @@ -0,0 +1,98 @@ +#include "Driver/GPU/CudaApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace cuda { + +struct ExternLibCuda : public ExternLibBase { + using RetType = CUresult; + // https://forums.developer.nvidia.com/t/wsl2-libcuda-so-and-libcuda-so-1-should-be-symlink/236301 + // On WSL, "libcuda.so" and "libcuda.so.1" may not be linked, so we use + // "libcuda.so.1" instead. + static constexpr const char *name = "libcuda.so.1"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = CUDA_SUCCESS; + static void *lib; +}; + +void *ExternLibCuda::lib = nullptr; + +DEFINE_DISPATCH(ExternLibCuda, init, cuInit, int) + +DEFINE_DISPATCH(ExternLibCuda, ctxSynchronize, cuCtxSynchronize) + +DEFINE_DISPATCH(ExternLibCuda, ctxGetCurrent, cuCtxGetCurrent, CUcontext *) + +DEFINE_DISPATCH(ExternLibCuda, ctxGetDevice, cuCtxGetDevice, CUdevice *) + +DEFINE_DISPATCH(ExternLibCuda, ctxGetStreamPriorityRange, + cuCtxGetStreamPriorityRange, int *, int *) + +DEFINE_DISPATCH(ExternLibCuda, deviceGet, cuDeviceGet, CUdevice *, int) + +DEFINE_DISPATCH(ExternLibCuda, deviceGetAttribute, cuDeviceGetAttribute, int *, + CUdevice_attribute, CUdevice) + +DEFINE_DISPATCH(ExternLibCuda, streamCreateWithPriority, + cuStreamCreateWithPriority, CUstream *, unsigned int, int) + +DEFINE_DISPATCH(ExternLibCuda, streamSynchronize, cuStreamSynchronize, CUstream) + +DEFINE_DISPATCH(ExternLibCuda, streamDestroy, cuStreamDestroy, CUstream) + +DEFINE_DISPATCH(ExternLibCuda, memcpyDToHAsync, cuMemcpyDtoHAsync, void *, + CUdeviceptr, size_t, CUstream) + +DEFINE_DISPATCH(ExternLibCuda, memsetD32Async, cuMemsetD32Async, CUdeviceptr, + unsigned int, size_t, CUstream) + +DEFINE_DISPATCH(ExternLibCuda, memAlloc, cuMemAlloc, CUdeviceptr *, size_t) + +DEFINE_DISPATCH(ExternLibCuda, memFree, cuMemFree, CUdeviceptr) + +DEFINE_DISPATCH(ExternLibCuda, memAllocHost, cuMemAllocHost, void **, size_t) + +DEFINE_DISPATCH(ExternLibCuda, memHostAlloc, cuMemHostAlloc, void **, size_t, + unsigned int) + +DEFINE_DISPATCH(ExternLibCuda, memHostGetDevicePointer, + cuMemHostGetDevicePointer, CUdeviceptr *, void *, unsigned int) + +DEFINE_DISPATCH(ExternLibCuda, memFreeHost, cuMemFreeHost, void *) + +DEFINE_DISPATCH(ExternLibCuda, launchKernel, cuLaunchKernel, CUfunction, + unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **, + void **) + +Device getDevice(uint64_t index) { + CUdevice device; + cuda::deviceGet(&device, index); + int clockRate; + cuda::deviceGetAttribute(&clockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, + device); + int memoryClockRate; + cuda::deviceGetAttribute(&memoryClockRate, + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device); + int busWidth; + cuda::deviceGetAttribute( + &busWidth, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device); + int numSms; + cuda::deviceGetAttribute( + &numSms, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device); + int major; + cuda::deviceGetAttribute( + &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); + int minor; + cuda::deviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); + std::string arch = std::to_string(major * 10 + minor); + + return Device(DeviceType::CUDA, index, clockRate, memoryClockRate, busWidth, + numSms, arch); +} + +} // namespace cuda + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/GPU/CuptiApi.cpp b/third_party/mthreads/proton/csrc/lib/Driver/GPU/CuptiApi.cpp new file mode 100644 index 0000000000..8db77e258a --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/GPU/CuptiApi.cpp @@ -0,0 +1,111 @@ +#include "Driver/GPU/CuptiApi.h" +#include "Device.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace cupti { + +DEFINE_DISPATCH(ExternLibCupti, getVersion, cuptiGetVersion, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getContextId, cuptiGetContextId, CUcontext, + uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, activityRegisterCallbacks, + cuptiActivityRegisterCallbacks, + CUpti_BuffersCallbackRequestFunc, + CUpti_BuffersCallbackCompleteFunc) + +DEFINE_DISPATCH(ExternLibCupti, subscribe, cuptiSubscribe, + CUpti_SubscriberHandle *, CUpti_CallbackFunc, void *) + +DEFINE_DISPATCH(ExternLibCupti, enableDomain, cuptiEnableDomain, uint32_t, + CUpti_SubscriberHandle, CUpti_CallbackDomain) + +DEFINE_DISPATCH(ExternLibCupti, enableCallback, cuptiEnableCallback, uint32_t, + CUpti_SubscriberHandle, CUpti_CallbackDomain, CUpti_CallbackId); + +DEFINE_DISPATCH(ExternLibCupti, activityEnable, cuptiActivityEnable, + CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityDisable, cuptiActivityDisable, + CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityEnableContext, + cuptiActivityEnableContext, CUcontext, CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityDisableContext, + cuptiActivityDisableContext, CUcontext, CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityFlushAll, cuptiActivityFlushAll, + uint32_t) + +DEFINE_DISPATCH(ExternLibCupti, activityGetNextRecord, + cuptiActivityGetNextRecord, uint8_t *, size_t, + CUpti_Activity **) + +DEFINE_DISPATCH(ExternLibCupti, activityPushExternalCorrelationId, + cuptiActivityPushExternalCorrelationId, + CUpti_ExternalCorrelationKind, uint64_t) + +DEFINE_DISPATCH(ExternLibCupti, activityPopExternalCorrelationId, + cuptiActivityPopExternalCorrelationId, + CUpti_ExternalCorrelationKind, uint64_t *) + +DEFINE_DISPATCH(ExternLibCupti, activitySetAttribute, cuptiActivitySetAttribute, + CUpti_ActivityAttribute, size_t *, void *) + +DEFINE_DISPATCH(ExternLibCupti, activityEnableHWTrace, + cuptiActivityEnableHWTrace, uint8_t) + +DEFINE_DISPATCH(ExternLibCupti, unsubscribe, cuptiUnsubscribe, + CUpti_SubscriberHandle) + +DEFINE_DISPATCH(ExternLibCupti, finalize, cuptiFinalize) + +DEFINE_DISPATCH(ExternLibCupti, getGraphExecId, cuptiGetGraphExecId, + CUgraphExec, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getGraphId, cuptiGetGraphId, CUgraph, + uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getGraphNodeId, cuptiGetGraphNodeId, + CUgraphNode, uint64_t *); + +DEFINE_DISPATCH(ExternLibCupti, getCubinCrc, cuptiGetCubinCrc, + CUpti_GetCubinCrcParams *); + +DEFINE_DISPATCH(ExternLibCupti, getSassToSourceCorrelation, + cuptiGetSassToSourceCorrelation, + CUpti_GetSassToSourceCorrelationParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetNumStallReasons, + cuptiPCSamplingGetNumStallReasons, + CUpti_PCSamplingGetNumStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetStallReasons, + cuptiPCSamplingGetStallReasons, + CUpti_PCSamplingGetStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingSetConfigurationAttribute, + cuptiPCSamplingSetConfigurationAttribute, + CUpti_PCSamplingConfigurationInfoParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingEnable, cuptiPCSamplingEnable, + CUpti_PCSamplingEnableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingDisable, cuptiPCSamplingDisable, + CUpti_PCSamplingDisableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetData, cuptiPCSamplingGetData, + CUpti_PCSamplingGetDataParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStart, cuptiPCSamplingStart, + CUpti_PCSamplingStartParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStop, cuptiPCSamplingStop, + CUpti_PCSamplingStopParams *); + +} // namespace cupti + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/GPU/HipApi.cpp b/third_party/mthreads/proton/csrc/lib/Driver/GPU/HipApi.cpp new file mode 100644 index 0000000000..cc0764c9c3 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/GPU/HipApi.cpp @@ -0,0 +1,129 @@ +#include "Driver/GPU/HipApi.h" +#include "Driver/Dispatch.h" +#include "hip/hip_runtime_api.h" +#include + +namespace proton { + +namespace hip { + +struct ExternLibHip : public ExternLibBase { + using RetType = hipError_t; + static constexpr const char *name = "libamdhip64.so"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = hipSuccess; + static void *lib; +}; + +void *ExternLibHip::lib = nullptr; + +DEFINE_DISPATCH(ExternLibHip, launchKernel, hipModuleLaunchKernel, + hipFunction_t, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, unsigned int, + hipStream_t, void **, void **) + +DEFINE_DISPATCH(ExternLibHip, deviceSynchronize, hipDeviceSynchronize) + +DEFINE_DISPATCH(ExternLibHip, deviceGetAttribute, hipDeviceGetAttribute, int *, + hipDeviceAttribute_t, int); + +DEFINE_DISPATCH(ExternLibHip, getDeviceCount, hipGetDeviceCount, int *); + +DEFINE_DISPATCH(ExternLibHip, getDeviceProperties, hipGetDeviceProperties, + hipDeviceProp_t *, int); + +DEFINE_DISPATCH(ExternLibHip, memAllocHost, hipMemAllocHost, void **, size_t) + +DEFINE_DISPATCH(ExternLibHip, memHostAlloc, hipHostAlloc, void **, size_t, + unsigned int) + +DEFINE_DISPATCH(ExternLibHip, memFreeHost, hipFreeHost, void *) + +DEFINE_DISPATCH(ExternLibHip, memHostGetDevicePointer, hipHostGetDevicePointer, + hipDeviceptr_t *, void *, unsigned int) + +DEFINE_DISPATCH(ExternLibHip, memAlloc, hipMemAlloc, hipDeviceptr_t *, size_t) + +DEFINE_DISPATCH(ExternLibHip, memFree, hipFree, hipDeviceptr_t) + +DEFINE_DISPATCH(ExternLibHip, memsetD32Async, hipMemsetD32Async, hipDeviceptr_t, + int, size_t, hipStream_t) + +DEFINE_DISPATCH(ExternLibHip, ctxGetDevice, hipCtxGetDevice, hipDevice_t *) + +DEFINE_DISPATCH(ExternLibHip, ctxGetStreamPriorityRange, + hipDeviceGetStreamPriorityRange, int *, int *) + +DEFINE_DISPATCH(ExternLibHip, streamCreateWithPriority, + hipStreamCreateWithPriority, hipStream_t *, unsigned int, int) + +DEFINE_DISPATCH(ExternLibHip, streamSynchronize, hipStreamSynchronize, + hipStream_t) + +DEFINE_DISPATCH(ExternLibHip, streamDestroy, hipStreamDestroy, hipStream_t) + +DEFINE_DISPATCH(ExternLibHip, memcpyDToHAsync, hipMemcpyDtoHAsync, void *, + hipDeviceptr_t, size_t, hipStream_t) + +Device getDevice(uint64_t index) { + int clockRate; + (void)hip::deviceGetAttribute(&clockRate, hipDeviceAttributeClockRate, + index); + int memoryClockRate; + (void)hip::deviceGetAttribute(&memoryClockRate, + hipDeviceAttributeMemoryClockRate, index); + int busWidth; + (void)hip::deviceGetAttribute(&busWidth, + hipDeviceAttributeMemoryBusWidth, index); + int smCount; + (void)hip::deviceGetAttribute( + &smCount, hipDeviceAttributeMultiprocessorCount, index); + + std::string arch = getHipArchName(index); + + return Device(DeviceType::HIP, index, clockRate, memoryClockRate, busWidth, + smCount, arch); +} + +// TODO: hipDeviceProp_t was updated to point from hipDeviceProp_tR0000 -> +// hipDeviceProp_tR0600 as part of a breaking API change in Rocm 6.0 +// https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/driver.c +// uses hipDeviceProp_tR0000 and imports the hip_deprecated.h header file to be +// be back compatible with ROCm 5.x. PyTorch stills needs to support 5.x and the +// hipDeviceProp_tR0600 symbol does not exist pre-Rocm 6.0. Calling +// hipDeviceProp_tR0000 here with Rocm 6.1 causes a stack corruption. Therefore +// were will use hipDeviceProp_t and investigate if we can unify the definitions +// in the two files. + +const std::string getHipArchName(uint64_t index) { + hipDeviceProp_t devProp; + (void)hip::getDeviceProperties(&devProp, index); + std::string gcnArchName(devProp.gcnArchName); + std::string hipArch = gcnArchName.substr(0, 6); + return hipArch; +} + +const char *getKernelNameRef(const hipFunction_t f) { + typedef const char *(*hipKernelNameRef_t)(const hipFunction_t); + static hipKernelNameRef_t func = nullptr; + Dispatch::init(ExternLibHip::name, &ExternLibHip::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHip::lib, "hipKernelNameRef")); + return (func ? func(f) : NULL); +} + +const char *getKernelNameRefByPtr(const void *hostFunction, + hipStream_t stream) { + typedef const char *(*hipKernelNameRefByPtr_t)(const void *, hipStream_t); + static hipKernelNameRefByPtr_t func = nullptr; + Dispatch::init(ExternLibHip::name, &ExternLibHip::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHip::lib, "hipKernelNameRefByPtr")); + return (func ? func(hostFunction, stream) : NULL); +} + +} // namespace hip + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/GPU/HsaApi.cpp b/third_party/mthreads/proton/csrc/lib/Driver/GPU/HsaApi.cpp new file mode 100644 index 0000000000..7c607b4b99 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/GPU/HsaApi.cpp @@ -0,0 +1,36 @@ +#include "Driver/GPU/HsaApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace hsa { + +struct ExternLibHsa : public ExternLibBase { + using RetType = hsa_status_t; + static constexpr const char *name = "libhsa-runtime64.so"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = HSA_STATUS_SUCCESS; + static void *lib; +}; + +void *ExternLibHsa::lib = nullptr; + +DEFINE_DISPATCH(ExternLibHsa, agentGetInfo, hsa_agent_get_info, hsa_agent_t, + hsa_agent_info_t, void *); + +hsa_status_t iterateAgents(hsa_status_t (*callback)(hsa_agent_t agent, + void *data), + void *data) { + typedef hsa_status_t (*hsa_iterate_agents_t)( + hsa_status_t (*)(hsa_agent_t, void *), void *data); + static hsa_iterate_agents_t func = nullptr; + Dispatch::init(ExternLibHsa::name, &ExternLibHsa::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHsa::lib, "hsa_iterate_agents")); + return (func ? func(callback, data) : HSA_STATUS_ERROR_FATAL); +} + +} // namespace hsa + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/GPU/NvtxApi.cpp b/third_party/mthreads/proton/csrc/lib/Driver/GPU/NvtxApi.cpp new file mode 100644 index 0000000000..2c5b24d635 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/GPU/NvtxApi.cpp @@ -0,0 +1,39 @@ +#include "Driver/GPU/NvtxApi.h" +#include "Driver/GPU/CuptiApi.h" + +#include +#include + +namespace proton { + +namespace { + +// Declare nvtx function params without including the nvtx header +struct RangePushAParams { + const char *message; +}; + +} // namespace + +namespace nvtx { + +void enable() { + // Get cupti lib path and append it to NVTX_INJECTION64_PATH + const std::string cuptiLibPath = + Dispatch::getLibPath(); + if (!cuptiLibPath.empty()) { + setenv("NVTX_INJECTION64_PATH", cuptiLibPath.c_str(), 1); + } +} + +void disable() { unsetenv("NVTX_INJECTION64_PATH"); } + +std::string getMessageFromRangePushA(const void *params) { + if (const auto *p = static_cast(params)) + return std::string(p->message ? p->message : ""); + return ""; +} + +} // namespace nvtx + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp b/third_party/mthreads/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp new file mode 100644 index 0000000000..25d22b9420 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp @@ -0,0 +1,95 @@ +#include "Driver/GPU/RoctracerApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace roctracer { + +DEFINE_DISPATCH(ExternLibRoctracer, setProperties, roctracer_set_properties, + roctracer_domain_t, void *) + +DEFINE_DISPATCH(ExternLibRoctracer, getTimestamp, roctracer_get_timestamp, + roctracer_timestamp_t *) + +void start() { + typedef void (*roctracer_start_t)(); + static roctracer_start_t func = nullptr; + Dispatch::init(ExternLibRoctracer::name, + &ExternLibRoctracer::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibRoctracer::lib, "roctracer_start")); + if (func) + func(); +} + +void stop() { + typedef void (*roctracer_stop_t)(); + static roctracer_stop_t func = nullptr; + Dispatch::init(ExternLibRoctracer::name, + &ExternLibRoctracer::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibRoctracer::lib, "roctracer_stop")); + if (func) + func(); +} + +char *getOpString(uint32_t domain, uint32_t op, uint32_t kind) { + typedef char *(*roctracer_op_string_t)(uint32_t, uint32_t, uint32_t); + static roctracer_op_string_t func = nullptr; + Dispatch::init(ExternLibRoctracer::name, + &ExternLibRoctracer::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibRoctracer::lib, "roctracer_op_string")); + return (func ? func(domain, op, kind) : NULL); +} + +DEFINE_DISPATCH(ExternLibRoctracer, enableDomainCallback, + roctracer_enable_domain_callback, activity_domain_t, + activity_rtapi_callback_t, void *) + +DEFINE_DISPATCH(ExternLibRoctracer, disableDomainCallback, + roctracer_disable_domain_callback, activity_domain_t) + +DEFINE_DISPATCH(ExternLibRoctracer, enableOpCallback, + roctracer_enable_op_callback, activity_domain_t, uint32_t, + activity_rtapi_callback_t, void *) + +DEFINE_DISPATCH(ExternLibRoctracer, disableOpCallback, + roctracer_disable_op_callback, activity_domain_t, uint32_t) + +DEFINE_DISPATCH(ExternLibRoctracer, openPool, roctracer_open_pool, + const roctracer_properties_t *) + +DEFINE_DISPATCH(ExternLibRoctracer, closePool, roctracer_close_pool) + +DEFINE_DISPATCH(ExternLibRoctracer, enableOpActivity, + roctracer_enable_op_activity, activity_domain_t, uint32_t) + +DEFINE_DISPATCH(ExternLibRoctracer, enableDomainActivity, + roctracer_enable_domain_activity, activity_domain_t) + +DEFINE_DISPATCH(ExternLibRoctracer, disableOpActivity, + roctracer_disable_op_activity, activity_domain_t, uint32_t) + +DEFINE_DISPATCH(ExternLibRoctracer, disableDomainActivity, + roctracer_disable_domain_activity, activity_domain_t) + +DEFINE_DISPATCH(ExternLibRoctracer, flushActivity, roctracer_flush_activity) + +DEFINE_DISPATCH(ExternLibRoctracer, activityPushExternalCorrelationId, + roctracer_activity_push_external_correlation_id, + activity_correlation_id_t) + +DEFINE_DISPATCH(ExternLibRoctracer, activityPopExternalCorrelationId, + roctracer_activity_pop_external_correlation_id, + activity_correlation_id_t *) + +DEFINE_DISPATCH(ExternLibRoctracer, getNextRecord, roctracer_next_record, + const activity_record_t *, const activity_record_t **) + +} // namespace roctracer + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Profiler/CMakeLists.txt new file mode 100644 index 0000000000..00dcaef972 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/CMakeLists.txt @@ -0,0 +1,10 @@ +add_proton_library(ProtonProfiler + Profiler.cpp + GPUProfiler.cpp + Graph.cpp + Cupti/CuptiPCSampling.cpp + Cupti/CuptiProfiler.cpp + RocTracer/RoctracerProfiler.cpp + Instrumentation/InstrumentationProfiler.cpp + Instrumentation/Metadata.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp new file mode 100644 index 0000000000..7040231c0f --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -0,0 +1,456 @@ +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Data/Metric.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Atomic.h" +#include "Utility/Map.h" +#include "Utility/String.h" +#include +#include +#include + +namespace proton { + +namespace { + +uint64_t getCubinCrc(const char *cubin, size_t size) { + CUpti_GetCubinCrcParams cubinCrcParams = { + /*size=*/CUpti_GetCubinCrcParamsSize, + /*cubinSize=*/size, + /*cubin=*/cubin, + /*cubinCrc=*/0, + }; + cupti::getCubinCrc(&cubinCrcParams); + return cubinCrcParams.cubinCrc; +} + +size_t getNumStallReasons(CUcontext context) { + size_t numStallReasons = 0; + CUpti_PCSamplingGetNumStallReasonsParams numStallReasonsParams = { + /*size=*/CUpti_PCSamplingGetNumStallReasonsParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numStallReasons=*/&numStallReasons}; + cupti::pcSamplingGetNumStallReasons(&numStallReasonsParams); + return numStallReasons; +} + +std::tuple +getSassToSourceCorrelation(const char *functionName, uint64_t pcOffset, + const char *cubin, size_t cubinSize) { + CUpti_GetSassToSourceCorrelationParams sassToSourceParams = { + /*size=*/CUpti_GetSassToSourceCorrelationParamsSize, + /*cubin=*/cubin, + /*functionName=*/functionName, + /*cubinSize=*/cubinSize, + /*lineNumber=*/0, + /*pcOffset=*/pcOffset, + /*fileName=*/NULL, + /*dirName=*/NULL, + }; + // Get source can fail if the line mapping is not available in the cubin so we + // don't check the return value + cupti::getSassToSourceCorrelation(&sassToSourceParams); + auto fileNameStr = sassToSourceParams.fileName + ? std::string(sassToSourceParams.fileName) + : ""; + auto dirNameStr = + sassToSourceParams.dirName ? std::string(sassToSourceParams.dirName) : ""; + // It's user's responsibility to free the memory + if (sassToSourceParams.fileName) + std::free(sassToSourceParams.fileName); + if (sassToSourceParams.dirName) + std::free(sassToSourceParams.dirName); + return std::make_tuple(sassToSourceParams.lineNumber, fileNameStr, + dirNameStr); +} + +std::pair +getStallReasonNamesAndIndices(CUcontext context, size_t numStallReasons) { + char **stallReasonNames = + static_cast(std::calloc(numStallReasons, sizeof(char *))); + for (size_t i = 0; i < numStallReasons; i++) { + stallReasonNames[i] = static_cast( + std::calloc(CUPTI_STALL_REASON_STRING_SIZE, sizeof(char))); + } + uint32_t *stallReasonIndices = + static_cast(std::calloc(numStallReasons, sizeof(uint32_t))); + // Initialize the names with 128 characters to avoid buffer overflow + CUpti_PCSamplingGetStallReasonsParams stallReasonsParams = { + /*size=*/CUpti_PCSamplingGetStallReasonsParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numStallReasons=*/numStallReasons, + /*stallReasonIndex=*/stallReasonIndices, + /*stallReasons=*/stallReasonNames, + }; + cupti::pcSamplingGetStallReasons(&stallReasonsParams); + return std::make_pair(stallReasonNames, stallReasonIndices); +} + +size_t matchStallReasonsToIndices( + size_t numStallReasons, char **stallReasonNames, + uint32_t *stallReasonIndices, + std::map &stallReasonIndexToMetricIndex, + std::set ¬IssuedStallReasonIndices) { + // In case there's any invalid stall reasons, we only collect valid ones. + // Invalid ones are swapped to the end of the list + std::vector validIndex(numStallReasons, false); + size_t numValidStalls = 0; + for (size_t i = 0; i < numStallReasons; i++) { + bool notIssued = std::string(stallReasonNames[i]).find("not_issued") != + std::string::npos; + std::string cuptiStallName = std::string(stallReasonNames[i]); + for (size_t j = 0; j < PCSamplingMetric::PCSamplingMetricKind::Count; j++) { + auto metricName = PCSamplingMetric().getValueName(j); + if (cuptiStallName.find(metricName) != std::string::npos) { + if (notIssued) + notIssuedStallReasonIndices.insert(stallReasonIndices[i]); + stallReasonIndexToMetricIndex[stallReasonIndices[i]] = j; + validIndex[i] = true; + numValidStalls++; + break; + } + } + } + int invalidIndex = -1; + for (size_t i = 0; i < numStallReasons; i++) { + if (invalidIndex == -1 && !validIndex[i]) { + invalidIndex = i; + } else if (invalidIndex != -1 && validIndex[i]) { + std::swap(stallReasonIndices[invalidIndex], stallReasonIndices[i]); + std::swap(stallReasonNames[invalidIndex], stallReasonNames[i]); + validIndex[invalidIndex] = true; + invalidIndex++; + } + } + return numValidStalls; +} + +#define CUPTI_CUDA12_4_VERSION 22 +#define CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE sizeof(uint32_t) + +CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, + size_t numValidStallReasons) { + uint32_t libVersion = 0; + cupti::getVersion(&libVersion); + size_t pcDataSize = sizeof(CUpti_PCSamplingPCData); + // Since CUPTI 12.4, a new field (i.e., correlationId) is added to + // CUpti_PCSamplingPCData, which breaks the ABI compatibility. + // Instead of using workarounds, we emit an error message and exit the + // application. + if ((libVersion < CUPTI_CUDA12_4_VERSION && + CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) || + (libVersion >= CUPTI_CUDA12_4_VERSION && + CUPTI_API_VERSION < CUPTI_CUDA12_4_VERSION)) { + throw std::runtime_error( + "[PROTON] CUPTI API version: " + std::to_string(CUPTI_API_VERSION) + + " and CUPTI driver version: " + std::to_string(libVersion) + + " are not compatible. Please set the environment variable " + " TRITON_CUPTI_INCLUDE_PATH and TRITON_CUPTI_LIB_PATH to resolve the " + "problem."); + } + CUpti_PCSamplingData pcSamplingData{ + /*size=*/sizeof(CUpti_PCSamplingData), + /*collectNumPcs=*/collectNumPCs, + /*totalSamples=*/0, + /*droppedSamples=*/0, + /*totalNumPcs=*/0, + /*remainingNumPcs=*/0, + /*rangeId=*/0, + /*pPcData=*/ + static_cast( + std::calloc(collectNumPCs, sizeof(CUpti_PCSamplingPCData)))}; + for (size_t i = 0; i < collectNumPCs; ++i) { + pcSamplingData.pPcData[i].stallReason = + static_cast(std::calloc( + numValidStallReasons, sizeof(CUpti_PCSamplingStallReason))); + } + return pcSamplingData; +} + +void enablePCSampling(CUcontext context) { + CUpti_PCSamplingEnableParams params = { + /*size=*/CUpti_PCSamplingEnableParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingEnable(¶ms); +} + +void disablePCSampling(CUcontext context) { + CUpti_PCSamplingDisableParams params = { + /*size=*/CUpti_PCSamplingDisableParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingDisable(¶ms); +} + +void startPCSampling(CUcontext context) { + CUpti_PCSamplingStartParams params = { + /*size=*/CUpti_PCSamplingStartParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingStart(¶ms); +} + +void stopPCSampling(CUcontext context) { + CUpti_PCSamplingStopParams params = { + /*size=*/CUpti_PCSamplingStopParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingStop(¶ms); +} + +void getPCSamplingData(CUcontext context, + CUpti_PCSamplingData *pcSamplingData) { + CUpti_PCSamplingGetDataParams params = { + /*size=*/CUpti_PCSamplingGetDataParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*pcSamplingData=*/pcSamplingData, + }; + cupti::pcSamplingGetData(¶ms); +} + +void setConfigurationAttribute( + CUcontext context, + std::vector &configurationInfos) { + CUpti_PCSamplingConfigurationInfoParams infoParams = { + /*size=*/CUpti_PCSamplingConfigurationInfoParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numAttributes=*/configurationInfos.size(), + /*pPCSamplingConfigurationInfo=*/configurationInfos.data(), + }; + cupti::pcSamplingSetConfigurationAttribute(&infoParams); +} + +} // namespace + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStallReasons() { + numStallReasons = getNumStallReasons(context); + std::tie(this->stallReasonNames, this->stallReasonIndices) = + getStallReasonNamesAndIndices(context, numStallReasons); + numValidStallReasons = matchStallReasonsToIndices( + numStallReasons, stallReasonNames, stallReasonIndices, + stallReasonIndexToMetricIndex, notIssuedStallReasonIndices); + CUpti_PCSamplingConfigurationInfo stallReasonInfo{}; + stallReasonInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON; + stallReasonInfo.attributeData.stallReasonData.stallReasonCount = + numValidStallReasons; + stallReasonInfo.attributeData.stallReasonData.pStallReasonIndex = + stallReasonIndices; + return stallReasonInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingPeriod() { + CUpti_PCSamplingConfigurationInfo samplingPeriodInfo{}; + samplingPeriodInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD; + samplingPeriodInfo.attributeData.samplingPeriodData.samplingPeriod = + DefaultFrequency; + return samplingPeriodInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingBuffer() { + CUpti_PCSamplingConfigurationInfo samplingBufferInfo{}; + samplingBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER; + this->pcSamplingData = + allocPCSamplingData(DataBufferPCCount, numValidStallReasons); + samplingBufferInfo.attributeData.samplingDataBufferData.samplingDataBuffer = + &this->pcSamplingData; + return samplingBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureScratchBuffer() { + CUpti_PCSamplingConfigurationInfo scratchBufferInfo{}; + scratchBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE; + scratchBufferInfo.attributeData.scratchBufferSizeData.scratchBufferSize = + ScratchBufferSize; + return scratchBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureHardwareBufferSize() { + CUpti_PCSamplingConfigurationInfo hardwareBufferInfo{}; + hardwareBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE; + hardwareBufferInfo.attributeData.hardwareBufferSizeData.hardwareBufferSize = + HardwareBufferSize; + return hardwareBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStartStopControl() { + CUpti_PCSamplingConfigurationInfo startStopControlInfo{}; + startStopControlInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL; + startStopControlInfo.attributeData.enableStartStopControlData + .enableStartStopControl = true; + return startStopControlInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureCollectionMode() { + CUpti_PCSamplingConfigurationInfo collectionModeInfo{}; + collectionModeInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE; + collectionModeInfo.attributeData.collectionModeData.collectionMode = + CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS; + return collectionModeInfo; +} + +void ConfigureData::initialize(CUcontext context) { + this->context = context; + cupti::getContextId(context, &contextId); + configurationInfos.emplace_back(configureStallReasons()); + configurationInfos.emplace_back(configureSamplingPeriod()); + configurationInfos.emplace_back(configureHardwareBufferSize()); + configurationInfos.emplace_back(configureScratchBuffer()); + configurationInfos.emplace_back(configureSamplingBuffer()); + configurationInfos.emplace_back(configureStartStopControl()); + configurationInfos.emplace_back(configureCollectionMode()); + setConfigurationAttribute(context, configurationInfos); +} + +ConfigureData *CuptiPCSampling::getConfigureData(uint32_t contextId) { + return &contextIdToConfigureData[contextId]; +} + +CubinData *CuptiPCSampling::getCubinData(uint64_t cubinCrc) { + return &(cubinCrcToCubinData[cubinCrc].first); +} + +void CuptiPCSampling::initialize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() { return !contextInitialized.contain(contextId); }, + contextMutex, + [&]() { + enablePCSampling(context); + getConfigureData(contextId)->initialize(context); + contextInitialized.insert(contextId); + }); +} + +void CuptiPCSampling::start(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return !pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + initialize(context); + // Ensure all previous operations are completed + cuda::ctxSynchronize(); + startPCSampling(context); + pcSamplingStarted = true; + }); +} + +void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, + const DataToEntryMap &dataToEntry) { + auto *pcSamplingData = &configureData->pcSamplingData; + auto &profiler = CuptiProfiler::instance(); + // In the first round, we need to call getPCSamplingData to get the unsynced + // data from the hardware buffer + bool firstRound = true; + while (pcSamplingData->totalNumPcs > 0 || + pcSamplingData->remainingNumPcs > 0 || firstRound) { + // Handle data + for (size_t i = 0; i < pcSamplingData->totalNumPcs; ++i) { + auto *pcData = pcSamplingData->pPcData + i; + auto *cubinData = getCubinData(pcData->cubinCrc); + auto key = + CubinData::LineInfoKey{pcData->functionIndex, pcData->pcOffset}; + if (cubinData->lineInfo.find(key) == cubinData->lineInfo.end()) { + auto [lineNumber, fileName, dirName] = + getSassToSourceCorrelation(pcData->functionName, pcData->pcOffset, + cubinData->cubin, cubinData->cubinSize); + cubinData->lineInfo.try_emplace(key, lineNumber, + std::string(pcData->functionName), + dirName, fileName); + } + auto &lineInfo = cubinData->lineInfo[key]; + for (size_t j = 0; j < pcData->stallReasonCount; ++j) { + auto *stallReason = &pcData->stallReason[j]; + if (!configureData->stallReasonIndexToMetricIndex.count( + stallReason->pcSamplingStallReasonIndex)) + throw std::runtime_error("[PROTON] Invalid stall reason index"); + for (auto [data, entry] : dataToEntry) { + if (lineInfo.fileName.size()) + entry = + data->addOp(entry.phase, entry.id, + {formatFileLineFunction( + lineInfo.dirName + "/" + lineInfo.fileName, + lineInfo.lineNumber, lineInfo.functionName)}); + auto metricKind = static_cast( + configureData->stallReasonIndexToMetricIndex + [stallReason->pcSamplingStallReasonIndex]); + auto samples = stallReason->samples; + auto stalledSamples = + configureData->notIssuedStallReasonIndices.count( + stallReason->pcSamplingStallReasonIndex) + ? 0 + : samples; + entry.upsertMetric(std::make_unique( + metricKind, samples, stalledSamples)); + } + } + } + if (pcSamplingData->remainingNumPcs > 0 || firstRound) { + getPCSamplingData(configureData->context, pcSamplingData); + firstRound = false; + } else + break; + } +} + +void CuptiPCSampling::stop(CUcontext context, + const DataToEntryMap &dataToEntry) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + auto *configureData = getConfigureData(contextId); + stopPCSampling(context); + pcSamplingStarted = false; + processPCSamplingData(configureData, dataToEntry); + }); +} + +void CuptiPCSampling::finalize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + if (!contextInitialized.contain(contextId)) + return; + auto *configureData = getConfigureData(contextId); + contextIdToConfigureData.erase(contextId); + contextInitialized.erase(contextId); + disablePCSampling(context); +} + +void CuptiPCSampling::loadModule(const char *cubin, size_t cubinSize) { + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto *cubinData = getCubinData(cubinCrc); + cubinData->cubinCrc = cubinCrc; + cubinData->cubinSize = cubinSize; + cubinData->cubin = cubin; +} + +void CuptiPCSampling::unloadModule(const char *cubin, size_t cubinSize) { + // XXX: Unload module is supposed to be called in a thread safe manner + // i.e., no two threads will be calling unload module the same time + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto count = cubinCrcToCubinData[cubinCrc].second; + if (count > 1) + cubinCrcToCubinData[cubinCrc].second = count - 1; + else + cubinCrcToCubinData.erase(cubinCrc); +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp new file mode 100644 index 0000000000..3954feb86e --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -0,0 +1,816 @@ +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Device.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Driver/GPU/NvtxApi.h" +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Profiler/Graph.h" +#include "Runtime/CudaRuntime.h" +#include "Utility/Env.h" +#include "Utility/Map.h" +#include "Utility/String.h" +#include "Utility/Vector.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState(CuptiProfiler::instance()); + +namespace { + +std::unique_ptr +convertKernelActivityToMetric(CUpti_Activity *activity) { + std::unique_ptr metric; + auto *kernel = reinterpret_cast(activity); + if (kernel->start < kernel->end) { + metric = + std::make_unique(static_cast(kernel->start), + static_cast(kernel->end), 1, + static_cast(kernel->deviceId), + static_cast(DeviceType::CUDA), + static_cast(kernel->streamId)); + } // else: not a valid kernel activity + return metric; +} + +uint32_t processActivityKernel( + CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId, + CuptiProfiler::ExternIdToStateMap &externIdToState, + std::map> + &externIdToStateCache, + std::map> &dataPhases, + CUpti_Activity *activity) { + // Support CUDA >= 11.0 + auto *kernel = reinterpret_cast(activity); + auto correlationId = kernel->correlationId; + size_t externId = 0; + if (!/*not valid*/ corrIdToExternId.withRead( + correlationId, [&externId](size_t value) { externId = value; })) { + corrIdToExternId.erase(correlationId); + return correlationId; + } + if (kernel->graphId == 0) { // XXX: This is a misnomer confirmed by NVIDIA, + // actually it refers to graphExecId + // Non-graph kernels + bool isMissingName = false; + DataToEntryMap dataToEntry; + externIdToState.withRead(externId, + [&](const CuptiProfiler::ExternIdState &state) { + isMissingName = state.isMissingName; + dataToEntry = state.dataToEntry; + }); + if (!isMissingName) { + for (auto &[data, entry] : dataToEntry) { + if (auto kernelMetric = convertKernelActivityToMetric(activity)) { + entry.upsertMetric(std::move(kernelMetric)); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + } + } else { + for (auto &[data, entry] : dataToEntry) { + if (auto kernelMetric = convertKernelActivityToMetric(activity)) { + auto childEntry = + data->addOp(entry.phase, entry.id, {Context(kernel->name)}); + childEntry.upsertMetric(std::move(kernelMetric)); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + } + } + externIdToState.erase(externId); + corrIdToExternId.erase(correlationId); + } else { + // Graph kernels + // A single graph launch can trigger multiple kernels. + // Our solution is to construct the following maps: + // --- Application threads --- + // If graph creation has been captured: + // - parentId, nodeId -> launch context + capture context + // Otherwise: + // - parentId -> launch context + // --- CUPTI thread --- + // - corrId -> numNodes + auto iter = externIdToStateCache.find(externId); + CuptiProfiler::ExternIdState *state = nullptr; + if (iter != externIdToStateCache.end()) { + state = &iter->second.get(); + } else { + // Cache miss, fetch from the main map + auto ref = externIdToState.find(externId); + // Update the cache + externIdToStateCache.emplace(externId, ref.value()); + state = &ref.value().get(); + } + auto &externState = *state; + // We have a graph creation captured + auto &graphNodeIdToState = externState.graphNodeIdToState; + auto *nodeState = graphNodeIdToState.find(kernel->graphNodeId); + if (nodeState && !nodeState->isMetricNode) { + const bool isMissingName = nodeState->isMissingName; + if (!isMissingName) { + nodeState->forEachEntry( + [activity, &dataPhases](Data *data, DataEntry &entry) { + if (auto kernelMetric = convertKernelActivityToMetric(activity)) { + entry.upsertMetric(std::move(kernelMetric)); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + }); + } else { + nodeState->forEachEntry( + [kernel, activity, &dataPhases](Data *data, DataEntry &entry) { + if (auto kernelMetric = convertKernelActivityToMetric(activity)) { + auto childEntry = + data->addOp(entry.phase, entry.id, {Context(kernel->name)}); + childEntry.upsertMetric(std::move(kernelMetric)); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + }); + } + } + // Decrease the expected kernel count + if (externState.numNodes > 0) { + externState.numNodes--; + } + // If all kernels have been processed, clean up + if (externState.numNodes == 0) { + externIdToState.erase(externId); + corrIdToExternId.erase(correlationId); + } + } + return correlationId; +} + +uint32_t processActivity( + CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId, + CuptiProfiler::ExternIdToStateMap &externIdToState, + std::map> + &externIdToStateCache, + std::map> &dataPhases, + CUpti_Activity *activity) { + auto correlationId = 0; + switch (activity->kind) { + case CUPTI_ACTIVITY_KIND_KERNEL: + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { + correlationId = + processActivityKernel(corrIdToExternId, externIdToState, + externIdToStateCache, dataPhases, activity); + break; + } + default: + break; + } + return correlationId; +} + +constexpr std::array kGraphCallbacks = { + CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch, + CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz, + CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture, + CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture_ptsz, + CUPTI_DRIVER_TRACE_CBID_cuStreamEndCapture, + CUPTI_DRIVER_TRACE_CBID_cuStreamEndCapture_ptsz, + CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture_v2, + CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture_v2_ptsz, + CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCaptureToGraph, + CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCaptureToGraph_ptsz, + CUPTI_DRIVER_TRACE_CBID_cuStreamEndCapture}; + +#define PROTON_KERNEL_CALLBACK_LIST(X) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunch) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz) \ + X(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice) + +#define PROTON_KERNEL_CB_AS_ID(cbId) cbId, +constexpr std::array kKernelCallbacks = { + PROTON_KERNEL_CALLBACK_LIST(PROTON_KERNEL_CB_AS_ID)}; +#undef PROTON_KERNEL_CB_AS_ID + +constexpr std::array kGraphResourceCallbacks = { + CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED, + CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED, + CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING, + CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED, + CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING, +}; + +constexpr std::array kResourceCallbacks = { + CUPTI_CBID_RESOURCE_MODULE_LOADED, + CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING, + CUPTI_CBID_RESOURCE_CONTEXT_CREATED, + CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING, +}; + +constexpr std::array kNvtxCallbacks = { + CUPTI_CBID_NVTX_nvtxRangePushA, + CUPTI_CBID_NVTX_nvtxRangePop, +}; + +void setLaunchCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { + for (auto cbId : kKernelCallbacks) { + cupti::enableCallback(static_cast(enable), subscriber, + CUPTI_CB_DOMAIN_DRIVER_API, cbId); + } +} + +void setGraphCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { + for (auto cbId : kGraphCallbacks) { + cupti::enableCallback(static_cast(enable), subscriber, + CUPTI_CB_DOMAIN_DRIVER_API, cbId); + } + for (auto cbId : kGraphResourceCallbacks) { + cupti::enableCallback(static_cast(enable), subscriber, + CUPTI_CB_DOMAIN_RESOURCE, cbId); + } +} + +void setResourceCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { + for (auto cbId : kResourceCallbacks) { + cupti::enableCallback(static_cast(enable), subscriber, + CUPTI_CB_DOMAIN_RESOURCE, cbId); + } +} + +void setNvtxCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { + for (auto cbId : kNvtxCallbacks) { + cupti::enableCallback(static_cast(enable), subscriber, + CUPTI_CB_DOMAIN_NVTX, cbId); + } +} + +bool isKernel(CUpti_CallbackId cbId) { + switch (cbId) { +#define PROTON_KERNEL_CB_AS_CASE(cbId) \ + case cbId: \ + return true; + PROTON_KERNEL_CALLBACK_LIST(PROTON_KERNEL_CB_AS_CASE) +#undef PROTON_KERNEL_CB_AS_CASE + default: + return false; + } +} + +bool isGraphLaunch(CUpti_CallbackId cbId) { + return cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz; +} + +bool isLaunch(CUpti_CallbackId cbId) { + return isKernel(cbId) || isGraphLaunch(cbId); +} + +#undef PROTON_KERNEL_CALLBACK_LIST + +} // namespace + +struct CuptiProfiler::CuptiProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + CuptiProfilerPimpl(CuptiProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) { + auto runtime = &CudaRuntime::instance(); + profiler.metricBuffer = + std::make_unique(1024 * 1024 * 64, runtime, + /*mapped=*/true); + profiler.pendingGraphPool = + std::make_unique(profiler.metricBuffer.get()); + } + virtual ~CuptiProfilerPimpl() = default; + + void doStart() override; + void doFlush() override; + void doStop() override; + + static void allocBuffer(uint8_t **buffer, size_t *bufferSize, + size_t *maxNumRecords); + static void completeBuffer(CUcontext context, uint32_t streamId, + uint8_t *buffer, size_t size, size_t validSize); + static void callbackFn(void *userData, CUpti_CallbackDomain domain, + CUpti_CallbackId cbId, const void *cbData); + + static constexpr size_t AlignSize = 8; + static constexpr size_t AttributeSize = sizeof(size_t); + static constexpr const char *CaptureTag = ""; + + CUpti_SubscriberHandle subscriber{}; + CuptiPCSampling pcSampling; + + ThreadSafeMap graphStates; + +private: + void handleGraphResourceCallbacks(CuptiProfiler &profiler, + CUpti_CallbackId cbId, + CUpti_GraphData *graphData); + void handleResourceCallbacks(CuptiProfiler &profiler, CUpti_CallbackId cbId, + const void *cbData); + void handleNvtxCallbacks(CUpti_CallbackId cbId, const void *cbData); + + bool handleStreamCaptureCallbacks(CUpti_CallbackId cbId); + void handleApiEnterLaunchCallbacks(CuptiProfiler &profiler, + CUpti_CallbackId cbId, + const CUpti_CallbackData *callbackData); + void handleApiExitLaunchCallbacks(CuptiProfiler &profiler, + CUpti_CallbackId cbId, + const CUpti_CallbackData *callbackData); + void handleApiCallbacks(CuptiProfiler &profiler, CUpti_CallbackId cbId, + const void *cbData); +}; + +void CuptiProfiler::CuptiProfilerPimpl::allocBuffer(uint8_t **buffer, + size_t *bufferSize, + size_t *maxNumRecords) { + const auto envBufferSize = + getIntEnv("TRITON_PROFILE_BUFFER_SIZE", 64 * 1024 * 1024); + *buffer = static_cast(aligned_alloc(AlignSize, envBufferSize)); + if (*buffer == nullptr) { + throw std::runtime_error("[PROTON] aligned_alloc failed"); + } + *bufferSize = envBufferSize; + *maxNumRecords = 0; +} + +void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, + uint32_t streamId, + uint8_t *buffer, + size_t size, + size_t validSize) { + CuptiProfiler &profiler = threadState.profiler; + uint32_t maxCorrelationId = 0; + static thread_local std::map dataFlushedPhases; + std::map> dataPhases; + CUptiResult status; + CUpti_Activity *activity = nullptr; + std::map> + externIdToStateCache; + do { + status = cupti::activityGetNextRecord(buffer, validSize, &activity); + if (status == CUPTI_SUCCESS) { + auto correlationId = + processActivity(profiler.correlation.corrIdToExternId, + profiler.correlation.externIdToState, + externIdToStateCache, dataPhases, activity); + maxCorrelationId = std::max(maxCorrelationId, correlationId); + } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + throw std::runtime_error("[PROTON] cupti::activityGetNextRecord failed"); + } + } while (true); + + std::free(buffer); + + profiler.correlation.complete(maxCorrelationId); + profiler.flushDataPhases(dataFlushedPhases, dataPhases, + profiler.pendingGraphPool.get()); +} + +void CuptiProfiler::CuptiProfilerPimpl::handleGraphResourceCallbacks( + CuptiProfiler &profiler, CUpti_CallbackId cbId, + CUpti_GraphData *graphData) { + uint32_t graphId = 0; + uint32_t graphExecId = 0; + if (graphData->graph) + cupti::getGraphId(graphData->graph, &graphId); + if (graphData->graphExec) + cupti::getGraphExecId(graphData->graphExec, &graphExecId); + if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED || + cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED) { + uint64_t nodeId = 0; + cupti::getGraphNodeId(graphData->node, &nodeId); + if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED) { + // When `cuGraphClone` or `cuGraphInstantiate` is called, CUPTI triggers + // both CREATED and CLONED callbacks for each node. So we only increase + // the numNodes in CREATED callback. + if (!graphStates.contain(graphId)) + graphStates[graphId] = GraphState(); + else + graphStates[graphId].numNodes++; + if (profiler.isOpInProgress()) { + auto &graphState = graphStates[graphId]; + auto &nodeState = graphState.nodeIdToState[nodeId]; + nodeState.nodeId = nodeId; + const auto &name = threadState.scopeStack.back().name; + if (name.empty() || (threadState.isApiExternOp && + threadState.isMetricKernelLaunching)) { + nodeState.isMissingName = true; + } + if (threadState.isMetricKernelLaunching) { + nodeState.isMetricNode = true; + graphState.metricKernelNodeIds.insert(nodeId); + } + for (auto *data : profiler.dataSet) { + auto contexts = data->getContexts(); + if (!threadState.isApiExternOp || + !threadState.isMetricKernelLaunching) + contexts.push_back(name); + nodeState.captureContexts[data] = std::move(contexts); + graphState + .dataToCallpathToNodeStates[data][nodeState.captureContexts[data]] + .push_back(std::ref(nodeState)); + } + } // else no op in progress; creation triggered by graph clone/instantiate + } else { // CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED + uint32_t originalGraphId = 0; + uint64_t originalNodeId = 0; + cupti::getGraphId(graphData->originalGraph, &originalGraphId); + cupti::getGraphNodeId(graphData->originalNode, &originalNodeId); + auto &graphState = graphStates[graphId]; + // Clone all node states. + graphState.nodeIdToState[nodeId] = + graphStates[originalGraphId].nodeIdToState[originalNodeId]; + auto &nodeState = graphState.nodeIdToState[nodeId]; + nodeState.nodeId = nodeId; + for (const auto &[data, callpath] : nodeState.captureContexts) { + graphState.dataToCallpathToNodeStates[data][callpath].push_back( + std::ref(nodeState)); + } + if (graphStates[originalGraphId].metricKernelNodeIds.find( + originalNodeId) != + graphStates[originalGraphId].metricKernelNodeIds.end()) { + graphState.metricKernelNodeIds.insert(nodeId); + } + } + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING) { + auto &numNodes = graphStates[graphId].numNodes; + numNodes--; + uint64_t nodeId = 0; + cupti::getGraphNodeId(graphData->node, &nodeId); + auto &graphState = graphStates[graphId]; + for (const auto &[data, callpath] : + graphState.nodeIdToState[nodeId].captureContexts) { + auto &nodeStates = graphState.dataToCallpathToNodeStates[data][callpath]; + nodeStates.erase( + std::remove_if(nodeStates.begin(), nodeStates.end(), + [nodeId](const GraphState::NodeStateRef &state) { + return state.get().nodeId == nodeId; + }), + nodeStates.end()); + } + graphState.nodeIdToState.erase(nodeId); + graphState.metricKernelNodeIds.erase(nodeId); + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING) { + graphStates.erase(graphId); + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING) { + graphStates.erase(graphExecId); + } +} + +void CuptiProfiler::CuptiProfilerPimpl::handleResourceCallbacks( + CuptiProfiler &profiler, CUpti_CallbackId cbId, const void *cbData) { + auto *resourceData = + static_cast(const_cast(cbData)); + if (cbId == CUPTI_CBID_RESOURCE_MODULE_LOADED) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.pcSamplingEnabled) + pcSampling.loadModule(moduleResource->pCubin, moduleResource->cubinSize); + } else if (cbId == CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.pcSamplingEnabled) + pcSampling.unloadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_CREATED) { + if (profiler.pcSamplingEnabled) + pcSampling.initialize(resourceData->context); + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING) { + if (profiler.pcSamplingEnabled) + pcSampling.finalize(resourceData->context); + } else { + auto *graphData = + static_cast(resourceData->resourceDescriptor); + handleGraphResourceCallbacks(profiler, cbId, graphData); + } +} + +void CuptiProfiler::CuptiProfilerPimpl::handleNvtxCallbacks( + CUpti_CallbackId cbId, const void *cbData) { + auto *nvtxData = static_cast(cbData); + if (cbId == CUPTI_CBID_NVTX_nvtxRangePushA) { + auto message = nvtx::getMessageFromRangePushA(nvtxData->functionParams); + threadState.enterScope(message); + } else if (cbId == CUPTI_CBID_NVTX_nvtxRangePop) { + threadState.exitScope(); + } // TODO: else handle other NVTX range functions +} + +bool CuptiProfiler::CuptiProfilerPimpl::handleStreamCaptureCallbacks( + CUpti_CallbackId cbId) { + if (cbId == CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture || + cbId == CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture_v2 || + cbId == CUPTI_DRIVER_TRACE_CBID_cuStreamBeginCapture_v2_ptsz) { + threadState.isStreamCapturing = true; + profiler.metricBuffer->reserve(); + return true; + } + if (cbId == CUPTI_DRIVER_TRACE_CBID_cuStreamEndCapture || + cbId == CUPTI_DRIVER_TRACE_CBID_cuStreamEndCapture_ptsz) { + threadState.isStreamCapturing = false; + return true; + } + return false; +} + +void CuptiProfiler::CuptiProfilerPimpl::handleApiEnterLaunchCallbacks( + CuptiProfiler &profiler, CUpti_CallbackId cbId, + const CUpti_CallbackData *callbackData) { + if (handleStreamCaptureCallbacks(cbId)) + return; + if (!isLaunch(cbId)) + return; + + size_t numNodes = 1; + if (isGraphLaunch(cbId)) { + threadState.enterOp(Scope("")); + } else { + // Symbol name is only available for kernel launch APIs. + const auto symbolName = callbackData->context && callbackData->symbolName + ? std::string(callbackData->symbolName) + : ""; + threadState.enterOp(Scope(symbolName)); + } + + const auto &scope = threadState.scopeStack.back(); + auto &dataToEntry = threadState.dataToEntry; + if (isGraphLaunch(cbId)) { + auto graphExec = + static_cast(callbackData->functionParams) + ->hGraph; + uint32_t graphExecId = 0; + cupti::getGraphExecId(graphExec, &graphExecId); + numNodes = std::numeric_limits::max(); + auto findGraph = false; + if (graphStates.contain(graphExecId)) { + numNodes = graphStates[graphExecId].numNodes; + findGraph = true; + } + if (!findGraph && !graphStates[graphExecId].captureStatusChecked) { + graphStates[graphExecId].captureStatusChecked = true; + std::cerr << "[PROTON] Cannot find graph for graphExecId: " << graphExecId + << ", and t may cause memory leak. To avoid this problem, " + "please start profiling before the graph is created." + << std::endl; + } else if (findGraph) { + auto &graphState = graphStates[graphExecId]; + + // For each unique call path, we generate an entry per data object. + auto &graphNodeIdToState = + profiler.correlation.externIdToState[scope.scopeId] + .graphNodeIdToState; + if (!graphState.nodeIdToState.empty()) { + auto minNodeId = graphState.nodeIdToState.begin()->first; + auto maxNodeId = graphState.nodeIdToState.rbegin()->first; + graphNodeIdToState.resetRange(minNodeId, maxNodeId); + } else { + graphNodeIdToState.clear(); + } + static const bool timingEnabled = + getBoolEnv("PROTON_GRAPH_LAUNCH_TIMING", false); + using Clock = std::chrono::steady_clock; + auto t0 = decltype(Clock::now()){}; + if (timingEnabled) + t0 = Clock::now(); + + for (auto &[data, callpathToNodeStates] : + graphState.dataToCallpathToNodeStates) { + auto *dataPtr = data; + auto entryIt = dataToEntry.find(dataPtr); + if (entryIt == dataToEntry.end()) + continue; + auto baseEntry = + dataPtr->addOp(entryIt->second.phase, entryIt->second.id, + {Context{GraphState::captureTag}}); + for (const auto &[callpath, nodeStates] : callpathToNodeStates) { + const auto nodeEntry = + dataPtr->addOp(baseEntry.phase, baseEntry.id, callpath); + for (const auto &nodeStateRef : nodeStates) { + const auto &nodeState = nodeStateRef.get(); + auto &graphNodeState = graphNodeIdToState.emplace(nodeState.nodeId); + graphNodeState.isMissingName = nodeState.isMissingName; + graphNodeState.isMetricNode = nodeState.isMetricNode; + graphNodeState.setEntry(data, nodeEntry); + } + } + } + if (timingEnabled) { + auto t1 = Clock::now(); + auto elapsed = + std::chrono::duration_cast(t1 - t0) + .count(); + std::cerr << "[PROTON] Graph launch call path time: " << elapsed + << " us for graphExecId: " << graphExecId << std::endl; + t0 = Clock::now(); + } + + if (!graphStates[graphExecId].metricKernelNodeIds.empty()) { + auto &graphExecState = graphStates[graphExecId]; + std::map> metricNodeEntryIds; + auto phase = Data::kNoCompletePhase; + auto numNodes = graphExecState.metricKernelNodeIds.size(); + for (auto nodeId : graphExecState.metricKernelNodeIds) { + auto *nodeState = graphNodeIdToState.find(nodeId); + if (!nodeState) { + throw std::runtime_error( + "[PROTON] Missing graph node state for metric node."); + } + nodeState->forEachEntry([&](Data *data, const DataEntry &entry) { + metricNodeEntryIds[data].push_back(entry.id); + if (phase == Data::kNoCompletePhase) { + phase = entry.phase; + } else if (phase != entry.phase) { + throw std::runtime_error( + "[PROTON] Inconsistent phases in graph metric nodes"); + } + }); + } + // Check if all data contains the same number of metric nodes + for (const auto &[data, entryIds] : metricNodeEntryIds) { + if (entryIds.size() != numNodes) { + throw std::runtime_error( + "[PROTON] Inconsistent number of metric nodes in graph."); + } + } + if (callbackData->context != nullptr) + profiler.pendingGraphPool->flushIfNeeded(numNodes); + profiler.pendingGraphPool->push(phase, metricNodeEntryIds, numNodes); + } + if (timingEnabled) { + auto t1 = Clock::now(); + auto elapsed = + std::chrono::duration_cast(t1 - t0) + .count(); + std::cerr << "[PROTON] Graph launch metric time: " << elapsed + << " us for graphExecId: " << graphExecId << std::endl; + } + } + } + + profiler.correlation.correlate(callbackData->correlationId, scope.scopeId, + numNodes, scope.name.empty(), dataToEntry); + if (profiler.pcSamplingEnabled) + pcSampling.start(callbackData->context); +} + +void CuptiProfiler::CuptiProfilerPimpl::handleApiExitLaunchCallbacks( + CuptiProfiler &profiler, CUpti_CallbackId cbId, + const CUpti_CallbackData *callbackData) { + if (!isLaunch(cbId)) + return; + + if (profiler.pcSamplingEnabled) { + auto &dataToEntry = threadState.dataToEntry; + // XXX: Conservatively stop every GPU kernel for now. + pcSampling.stop(callbackData->context, dataToEntry); + } + + threadState.exitOp(); + profiler.correlation.submit(callbackData->correlationId); +} + +void CuptiProfiler::CuptiProfilerPimpl::handleApiCallbacks( + CuptiProfiler &profiler, CUpti_CallbackId cbId, const void *cbData) { + // Do not track metric kernel launches for triton ops. + // In this case, metric kernels are launched after a triton op is entered. + // We should track metric kernel launches for scopes. In this case, the metric + // kernel's stack has the same name as the scope's stack. + if (threadState.isMetricKernelLaunching && profiler.isOpInProgress()) + return; + + const CUpti_CallbackData *callbackData = + static_cast(cbData); + if (callbackData->callbackSite == CUPTI_API_ENTER) { + handleApiEnterLaunchCallbacks(profiler, cbId, callbackData); + } else if (callbackData->callbackSite == CUPTI_API_EXIT) { + handleApiExitLaunchCallbacks(profiler, cbId, callbackData); + } +} + +void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbId, + const void *cbData) { + CuptiProfiler &profiler = threadState.profiler; + auto *pImpl = dynamic_cast(profiler.pImpl.get()); + if (domain == CUPTI_CB_DOMAIN_RESOURCE) { + pImpl->handleResourceCallbacks(profiler, cbId, cbData); + } else if (domain == CUPTI_CB_DOMAIN_NVTX) { + pImpl->handleNvtxCallbacks(cbId, cbData); + } else { + pImpl->handleApiCallbacks(profiler, cbId, cbData); + } +} + +void CuptiProfiler::CuptiProfilerPimpl::doStart() { + cupti::subscribe(&subscriber, callbackFn, nullptr); + if (profiler.pcSamplingEnabled) { + setResourceCallbacks(subscriber, /*enable=*/true); + // Continuous PC sampling is not compatible with concurrent kernel profiling + cupti::activityEnable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + if (getBoolEnv("TRITON_ENABLE_HW_TRACE", true)) + cupti::activityEnableHWTrace(/*enable=*/1); + } + cupti::activityRegisterCallbacks(allocBuffer, completeBuffer); + setGraphCallbacks(subscriber, /*enable=*/true); + setLaunchCallbacks(subscriber, /*enable=*/true); + if (getBoolEnv("TRITON_ENABLE_NVTX", true)) { + nvtx::enable(); + setNvtxCallbacks(subscriber, /*enable=*/true); + } +} + +void CuptiProfiler::CuptiProfilerPimpl::doFlush() { + // cuptiActivityFlushAll returns the activity records associated with all + // contexts/streams. + // This is a blocking call but it doesn’t issue any CUDA synchronization calls + // implicitly thus it’s not guaranteed that all activities are completed on + // the underlying devices. + // We do an "opportunistic" synchronization here to try to ensure that all + // activities are completed on the current context. + // If the current context is not set, we don't do any synchronization. + CUcontext cuContext = nullptr; + cuda::ctxGetCurrent(&cuContext); + if (cuContext) { + cuda::ctxSynchronize(); + } + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepUs=*/10, + /*flush=*/[]() { + cupti::activityFlushAll( + /*flag=*/0); + }); + // CUPTI_ACTIVITY_FLAG_FLUSH_FORCED is used to ensure that even incomplete + // activities are flushed so that the next profiling session can start with + // new activities. + cupti::activityFlushAll(/*flag=*/CUPTI_ACTIVITY_FLAG_FLUSH_FORCED); + // Flush the tensor metric buffer + profiler.pendingGraphPool->flushAll(); +} + +void CuptiProfiler::CuptiProfilerPimpl::doStop() { + if (profiler.pcSamplingEnabled) { + profiler.pcSamplingEnabled = false; + CUcontext cuContext = nullptr; + cuda::ctxGetCurrent(&cuContext); + if (cuContext) + pcSampling.finalize(cuContext); + setResourceCallbacks(subscriber, /*enable=*/false); + cupti::activityDisable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + if (getBoolEnv("TRITON_ENABLE_HW_TRACE", true)) + cupti::activityEnableHWTrace(/*enable=*/0); + } + profiler.periodicFlushingEnabled = false; + profiler.periodicFlushingFormat.clear(); + setGraphCallbacks(subscriber, /*enable=*/false); + setLaunchCallbacks(subscriber, /*enable=*/false); + nvtx::disable(); + setNvtxCallbacks(subscriber, /*enable=*/false); + cupti::unsubscribe(subscriber); + cupti::finalize(); +} + +CuptiProfiler::CuptiProfiler() { + pImpl = std::make_unique(*this); +} + +CuptiProfiler::~CuptiProfiler() = default; + +void CuptiProfiler::doSetMode(const std::vector &modeAndOptions) { + auto mode = modeAndOptions[0]; + if (proton::toLower(mode) == "pcsampling") { + pcSamplingEnabled = true; + } else if (proton::toLower(mode) == "periodic_flushing") { + detail::setPeriodicFlushingMode(periodicFlushingEnabled, + periodicFlushingFormat, modeAndOptions, + "CuptiProfiler"); + } else if (!mode.empty()) { + throw std::invalid_argument("[PROTON] CuptiProfiler: unsupported mode: " + + mode); + } +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/GPUProfiler.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/GPUProfiler.cpp new file mode 100644 index 0000000000..f94e51c80d --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/GPUProfiler.cpp @@ -0,0 +1,280 @@ +#include "Profiler/GPUProfiler.h" +#include "Profiler/Graph.h" + +#include +#include +#include +#include +#include + +namespace proton { +namespace detail { + +namespace { + +struct FlushRange { + Data *data{nullptr}; + size_t minPhaseToFlush{0}; + size_t maxPhaseToFlush{0}; +}; + +std::pair, std::set> +computeFlushRangesAndPeekPhases( + std::map &dataFlushedPhases, + const std::map> + &dataPhases, + const bool peekPendingGraphs) { + std::vector flushRanges; + flushRanges.reserve(dataPhases.size()); + std::set phasesToPeek; + + for (auto [data, phase] : dataPhases) { + if (phase.second == 0) { + continue; + } + + auto flushedPhaseIt = dataFlushedPhases.find(data); + // phase.second at maximum is the current phase, which cannot be a + // "complete" phase yet. So we flush up to phase.second - 1. + const size_t endPhaseToFlush = phase.second - 1; + + size_t minPhaseToFlush = 0; + if (flushedPhaseIt == dataFlushedPhases.end() || + flushedPhaseIt->second == Data::kNoCompletePhase) { + minPhaseToFlush = 0; + } else { + const auto flushedPhase = flushedPhaseIt->second; + if (endPhaseToFlush <= flushedPhase) { + continue; + } + minPhaseToFlush = flushedPhase + 1; + } + + flushRanges.push_back(FlushRange{data, minPhaseToFlush, endPhaseToFlush}); + if (peekPendingGraphs) { + for (size_t p = minPhaseToFlush; p <= endPhaseToFlush; ++p) { + phasesToPeek.insert(p); + } + } + } + + return {std::move(flushRanges), std::move(phasesToPeek)}; +} + +struct PeriodicFlushStats { + uint64_t totalToJsonUs{0}; + uint64_t totalToMsgPackUs{0}; + uint64_t totalJsonWriteUs{0}; + uint64_t totalMsgPackWriteUs{0}; + uint64_t clearUs{0}; + size_t toJsonCalls{0}; + size_t toMsgPackCalls{0}; + size_t jsonWriteCalls{0}; + size_t msgPackWriteCalls{0}; +}; + +void periodicFlushDataPhases(Data &data, + const std::string &periodicFlushingFormat, + size_t minPhaseToFlush, size_t maxPhaseToFlush, + const bool timingEnabled, + PeriodicFlushStats &stats) { + using Clock = std::chrono::steady_clock; + const auto &path = data.getPath(); + + for (auto startPhase = minPhaseToFlush; startPhase <= maxPhaseToFlush; + startPhase++) { + auto pathWithPhase = path + ".part_" + std::to_string(startPhase) + "." + + periodicFlushingFormat; + + if (periodicFlushingFormat == "hatchet" || + periodicFlushingFormat == "chrome_trace") { + std::string jsonStr; + if (timingEnabled) { + const auto t0 = Clock::now(); + jsonStr = data.toJsonString(startPhase); + const auto t1 = Clock::now(); + stats.totalToJsonUs += + std::chrono::duration_cast(t1 - t0) + .count(); + ++stats.toJsonCalls; + } else { + jsonStr = data.toJsonString(startPhase); + } + + if (timingEnabled) { + const auto t0 = Clock::now(); + std::ofstream ofs(pathWithPhase, std::ios::out | std::ios::trunc); + ofs << jsonStr; + ofs.flush(); + const auto t1 = Clock::now(); + stats.totalJsonWriteUs += + std::chrono::duration_cast(t1 - t0) + .count(); + ++stats.jsonWriteCalls; + } else { + std::ofstream ofs(pathWithPhase, std::ios::out | std::ios::trunc); + ofs << jsonStr; + } + } else if (periodicFlushingFormat == "hatchet_msgpack") { + std::vector msgPack; + if (timingEnabled) { + const auto t0 = Clock::now(); + msgPack = data.toMsgPack(startPhase); + const auto t1 = Clock::now(); + stats.totalToMsgPackUs += + std::chrono::duration_cast(t1 - t0) + .count(); + ++stats.toMsgPackCalls; + } else { + msgPack = data.toMsgPack(startPhase); + } + + if (timingEnabled) { + const auto t0 = Clock::now(); + std::ofstream ofs(pathWithPhase, + std::ios::out | std::ios::binary | std::ios::trunc); + ofs.write(reinterpret_cast(msgPack.data()), + msgPack.size()); + ofs.flush(); + const auto t1 = Clock::now(); + stats.totalMsgPackWriteUs += + std::chrono::duration_cast(t1 - t0) + .count(); + ++stats.msgPackWriteCalls; + } else { + std::ofstream ofs(pathWithPhase, + std::ios::out | std::ios::binary | std::ios::trunc); + ofs.write(reinterpret_cast(msgPack.data()), + msgPack.size()); + } + } + } +} + +void periodicClearDataPhases(Data &data, size_t maxPhaseToFlush, + const bool timingEnabled, + PeriodicFlushStats &stats) { + using Clock = std::chrono::steady_clock; + if (!timingEnabled) { + data.clear(maxPhaseToFlush, /*clearUpToPhase=*/true); + return; + } + + const auto t0 = Clock::now(); + data.clear(maxPhaseToFlush, /*clearUpToPhase=*/true); + const auto t1 = Clock::now(); + stats.clearUs = + std::chrono::duration_cast(t1 - t0).count(); +} + +} // namespace + +void setPeriodicFlushingMode(bool &periodicFlushingEnabled, + std::string &periodicFlushingFormat, + const std::vector &modeAndOptions, + const char *profilerName) { + periodicFlushingEnabled = true; + if (modeAndOptions.size() < 2) + periodicFlushingFormat = "hatchet"; + + auto delimiterPos = modeAndOptions[1].find('='); + if (delimiterPos != std::string::npos) { + const std::string key = modeAndOptions[1].substr(0, delimiterPos); + const std::string value = modeAndOptions[1].substr(delimiterPos + 1); + if (key != "format") { + throw std::invalid_argument(std::string("[PROTON] ") + profilerName + + ": unsupported option key: " + key); + } + if (value != "hatchet_msgpack" && value != "chrome_trace" && + value != "hatchet") { + throw std::invalid_argument(std::string("[PROTON] ") + profilerName + + ": unsupported format: " + value); + } + periodicFlushingFormat = value; + } else { + periodicFlushingFormat = "hatchet"; + } +} + +void updateDataPhases(std::map> &dataPhases, + Data *data, size_t phase) { + auto it = dataPhases.find(data); + if (it == dataPhases.end()) { + dataPhases.emplace(data, std::make_pair(phase, phase)); + } else { + it->second.first = std::min(it->second.first, phase); // start phase + it->second.second = std::max(it->second.second, phase); // end phase + } +} + +void flushDataPhasesImpl( + const bool periodicFlushEnabled, const std::string &periodicFlushingFormat, + std::map &dataFlushedPhases, + const std::map> + &dataPhases, + PendingGraphPool *pendingGraphPool) { + static const bool timingEnabled = + getBoolEnv("PROTON_DATA_FLUSH_TIMING", false); + auto [flushRanges, phasesToPeek] = computeFlushRangesAndPeekPhases( + dataFlushedPhases, dataPhases, pendingGraphPool != nullptr); + if (pendingGraphPool) { + using Clock = std::chrono::steady_clock; + uint64_t totalPeekUs = 0; + size_t peekCalls = 0; + for (const auto phase : phasesToPeek) { + if (timingEnabled) { + const auto t0 = Clock::now(); + pendingGraphPool->peek(phase); + const auto t1 = Clock::now(); + totalPeekUs += + std::chrono::duration_cast(t1 - t0) + .count(); + ++peekCalls; + } else { + pendingGraphPool->peek(phase); + } + } + if (timingEnabled && peekCalls > 0) { + auto minPhase = *phasesToPeek.begin(); + auto maxPhase = *phasesToPeek.rbegin(); + std::cerr << "[PROTON] pendingGraphPool peek timing: phases=[" << minPhase + << "," << maxPhase << "] peek_us=" << totalPeekUs + << " peek_calls=" << peekCalls << std::endl; + } + } + + for (const auto &range : flushRanges) { + auto *data = range.data; + const size_t minPhaseToFlush = range.minPhaseToFlush; + const size_t maxPhaseToFlush = range.maxPhaseToFlush; + dataFlushedPhases[data] = maxPhaseToFlush; + data->completePhase(maxPhaseToFlush); + + if (!periodicFlushEnabled) + continue; + + PeriodicFlushStats stats{}; + periodicFlushDataPhases(*data, periodicFlushingFormat, minPhaseToFlush, + maxPhaseToFlush, timingEnabled, stats); + periodicClearDataPhases(*data, maxPhaseToFlush, timingEnabled, stats); + if (timingEnabled) { + std::cerr << "[PROTON] periodicFlush timing: path=" << data->getPath() + << " format=" << periodicFlushingFormat << " phases=[" + << minPhaseToFlush << "," << maxPhaseToFlush + << "] toJsonString_us=" << stats.totalToJsonUs + << " toJsonString_calls=" << stats.toJsonCalls + << " toMsgPack_us=" << stats.totalToMsgPackUs + << " toMsgPack_calls=" << stats.toMsgPackCalls + << " json_write_us=" << stats.totalJsonWriteUs + << " json_write_calls=" << stats.jsonWriteCalls + << " msgpack_write_us=" << stats.totalMsgPackWriteUs + << " msgpack_write_calls=" << stats.msgPackWriteCalls + << " clear_us=" << stats.clearUs << std::endl; + } + } +} + +} // namespace detail +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/Graph.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/Graph.cpp new file mode 100644 index 0000000000..c411158492 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/Graph.cpp @@ -0,0 +1,191 @@ +#include "Profiler/Graph.h" + +#include "Data/Data.h" +#include "Runtime/Runtime.h" + +#include +#include + +namespace proton { + +namespace { +constexpr size_t kMetricWordsPerNode = 2; + +constexpr size_t bytesForNodes(size_t numNodes) { + return numNodes * kMetricWordsPerNode * sizeof(uint64_t); +} + +void emitMetricRecords(MetricBuffer &metricBuffer, uint64_t *hostBasePtr, + const PendingGraphQueue &queue) { + const size_t phase = queue.phase; + const auto &pendingGraphs = queue.pendingGraphs; + const size_t capacityWords = metricBuffer.getCapacity() / sizeof(uint64_t); + size_t wordOffset = queue.startBufferOffset / sizeof(uint64_t); + auto readWord = [&](size_t offset) -> uint64_t { + return hostBasePtr[offset % capacityWords]; + }; + + for (const auto &pendingGraph : pendingGraphs) { + for (size_t i = 0; i < pendingGraph.numNodes; ++i) { + const uint64_t metricId = readWord(wordOffset); + const uint64_t metricValue = readWord(wordOffset + 1); + wordOffset = (wordOffset + kMetricWordsPerNode) % capacityWords; + + auto metricDesc = metricBuffer.getMetricDescriptor(metricId); + const auto &metricName = metricDesc.name; + const auto metricTypeIndex = metricDesc.typeIndex; + + for (auto &[data, entryIds] : pendingGraph.dataToEntryIds) { + const auto entryId = entryIds[i]; + switch (metricTypeIndex) { + case variant_index_v: { + uint64_t typedValue{}; + std::memcpy(&typedValue, &metricValue, sizeof(typedValue)); + data->addMetrics(phase, entryId, + {{metricName, MetricValueType{typedValue}}}); + break; + } + case variant_index_v: { + int64_t typedValue{}; + std::memcpy(&typedValue, &metricValue, sizeof(typedValue)); + data->addMetrics(phase, entryId, + {{metricName, MetricValueType{typedValue}}}); + break; + } + case variant_index_v: { + double typedValue{}; + std::memcpy(&typedValue, &metricValue, sizeof(typedValue)); + data->addMetrics(phase, entryId, + {{metricName, MetricValueType{typedValue}}}); + break; + } + default: + break; + } + } + } + } +} +} // namespace + +void PendingGraphPool::push( + size_t phase, const std::map> &dataToEntryIds, + size_t numNodes) { + const size_t requiredBytes = bytesForNodes(numNodes); + void *device = runtime->getDevice(); + std::shared_ptr slot; + { + std::lock_guard lock(mutex); + auto &devicePool = pool[device]; + auto [poolIt, inserted] = devicePool.try_emplace(phase); + if (inserted) + poolIt->second = std::make_shared(); + slot = poolIt->second; + } + { + std::lock_guard slotLock(slot->mutex); + if (slot->queue == std::nullopt) { + const auto startBufferOffset = + deviceBufferOffset.try_emplace(device, 0).first->second; + slot->queue = PendingGraphQueue(startBufferOffset, phase, device); + } + slot->queue->push(numNodes, dataToEntryIds); + } + { + std::lock_guard lock(mutex); + auto &remainingCapacity = + deviceRemainingCapacity.try_emplace(device, metricBuffer->getCapacity()) + .first->second; + auto &bufferOffset = deviceBufferOffset[device]; + bufferOffset = (bufferOffset + requiredBytes) % metricBuffer->getCapacity(); + remainingCapacity -= requiredBytes; + } +} + +void PendingGraphPool::peek(size_t phase) { + std::vector>> slots; + { + std::lock_guard lock(mutex); + for (auto &[device, devicePool] : pool) { + auto slotIt = devicePool.find(phase); + if (slotIt != devicePool.end()) { + slots.emplace_back(device, slotIt->second); + } + } + } + std::vector> deviceNumNodes; + for (auto &[device, slotPtr] : slots) { + auto numNodes = size_t{0}; + std::lock_guard slotLock(slotPtr->mutex); + if (slotPtr->queue == std::nullopt) + continue; + auto &queue = slotPtr->queue.value(); + numNodes = queue.numNodes; + metricBuffer->peek(static_cast(device), [&](uint8_t *hostPtr) { + emitMetricRecords(*metricBuffer, reinterpret_cast(hostPtr), + queue); + }); + deviceNumNodes.emplace_back(device, numNodes); + } + { + std::lock_guard lock(mutex); + for (auto &[device, numNodes] : deviceNumNodes) { + pool[device].erase(phase); + deviceRemainingCapacity[device] += bytesForNodes(numNodes); + } + } +} + +bool PendingGraphPool::flushIfNeeded(size_t numNodes) { + auto *device = runtime->getDevice(); + const size_t requiredBytes = bytesForNodes(numNodes); + { + std::lock_guard lock(mutex); + auto it = + deviceRemainingCapacity.try_emplace(device, metricBuffer->getCapacity()) + .first; + if (it->second >= requiredBytes) + return false; + } + flushAll(); + return true; +} + +bool PendingGraphPool::flushAll() { + auto poolCopy = decltype(pool){}; + { + std::lock_guard lock(mutex); + if (pool.empty()) + return false; + poolCopy.swap(pool); + } + metricBuffer->flush( + [&](void *device, uint8_t *hostPtr) { + auto deviceIt = poolCopy.find(device); + if (deviceIt == poolCopy.end()) + return; + for (auto &[_, slot] : deviceIt->second) { + std::lock_guard lock(slot->mutex); + if (slot->queue == std::nullopt) + continue; + emitMetricRecords(*metricBuffer, + reinterpret_cast(hostPtr), + *slot->queue); + } + }, + true); + { + std::lock_guard lock(mutex); + for (auto &[device, devicePool] : poolCopy) { + for (auto &[_, slot] : devicePool) { + std::lock_guard slotLock(slot->mutex); + if (slot->queue == std::nullopt) + continue; + deviceRemainingCapacity[device] += bytesForNodes(slot->queue->numNodes); + } + } + } + return true; +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp new file mode 100644 index 0000000000..10a1d29d41 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp @@ -0,0 +1,274 @@ +#include "Profiler/Instrumentation/InstrumentationProfiler.h" +#include "TraceDataIO/CircularLayoutParser.h" + +#include "Runtime/CudaRuntime.h" +#include "Runtime/HipRuntime.h" +#include "Utility/Numeric.h" +#include "Utility/String.h" +#include +#include +#include +#include +#include + +namespace proton { + +constexpr size_t DEFAULT_HOST_BUFFER_SIZE = 64 * 1024 * 1024; // 64MB +constexpr size_t MAX_HOST_BUFFER_SIZE = 4LL * 1024LL * 1024LL * 1024LL; // 4GB + +InstrumentationProfiler::~InstrumentationProfiler() {} + +void InstrumentationProfiler::doStart() { + // Start the instrumentation profiler. +} + +void InstrumentationProfiler::doFlush() { + // Flush the instrumentation profiler. +} + +void InstrumentationProfiler::doStop() { + // Stop the instrumentation profiler. + // FIXME: Also we should ensure the context is valid before releasing the + // memory + if (hostBuffer != nullptr) { + runtime->freeHostBuffer(hostBuffer); + hostBuffer = nullptr; + } + for (auto &[device, deviceStream] : deviceStreams) { + runtime->destroyStream(deviceStream); + } + deviceStreams.clear(); + // Reset mode options + modeOptions.clear(); + // Note that we don't clear function metadata and names here, as they may be + // reused when the profiler is started again. +} + +void InstrumentationProfiler::doSetMode( + const std::vector &modeAndOptions) { + if (modeAndOptions.empty()) { + throw std::runtime_error("Mode cannot be empty"); + } + if (proton::toLower(modeAndOptions[0]) == + proton::toLower(DeviceTraits::name)) { + runtime = &CudaRuntime::instance(); + } else if (proton::toLower(modeAndOptions[0]) == + proton::toLower(DeviceTraits::name)) { + runtime = &HipRuntime::instance(); + } else { + throw std::runtime_error("Unknown device type: " + modeAndOptions[0]); + } + for (size_t i = 1; i < modeAndOptions.size(); ++i) { + auto delimiterPos = modeAndOptions[i].find('='); + if (delimiterPos != std::string::npos) { + std::string key = modeAndOptions[i].substr(0, delimiterPos); + std::string value = modeAndOptions[i].substr(delimiterPos + 1); + modeOptions[key] = value; + } else { + modeOptions[modeAndOptions[i]] = ""; + } + } +} +namespace { + +std::vector +getUnitIdVector(const std::map &modeOptions, + size_t totalUnits) { + std::vector unitIdVector; + if (modeOptions.count("sampling_options") != 0) { + auto &samplingOption = modeOptions.at("sampling_options"); + auto unitIds = proton::split(samplingOption, ","); + for (auto uintId : unitIds) { + if (proton::trim(uintId).empty()) { + continue; + } + uint32_t id = std::stoi(uintId); + unitIdVector.push_back(id); + } + } + if (unitIdVector.empty()) { + unitIdVector.resize(totalUnits); + std::iota(unitIdVector.begin(), unitIdVector.end(), 0); + } + return unitIdVector; +} + +} // namespace + +std::shared_ptr +InstrumentationProfiler::getParserConfig(uint64_t functionId, + size_t bufferSize) const { + // Only support circular layout parser for now, but we will extend the support + // to other parsers in the future + auto config = std::make_shared(); + config->scratchMemSize = + functionMetadata.at(functionId).getScratchMemorySize(); + if (!(modeOptions.count("granularity") == 0 || + modeOptions.at("granularity") == "GRANULARITY.WARP")) { + throw std::runtime_error("Only warp granularity is supported for now"); + } + config->totalUnits = functionMetadata.at(functionId).getNumWarps(); + config->numBlocks = bufferSize / config->scratchMemSize; + config->uidVec = getUnitIdVector(modeOptions, config->totalUnits); + + // Check if the uidVec is valid + for (auto uid : config->uidVec) + if (uid >= config->totalUnits) { + throw std::runtime_error( + "Invalid sampling warp id: " + std::to_string(uid) + ". We have " + + std::to_string(config->totalUnits) + + " warps in total. Please check the proton sampling options."); + } + + config->device = Device(); + config->device.type = runtime->getDeviceType(); + + return config; +} + +void InstrumentationProfiler::initFunctionMetadata( + uint64_t functionId, const std::string &functionName, + const std::vector> &scopeIdPairs, + const std::vector> &scopeIdParentPairs, + const std::string &metadataPath) { + if (functionScopeIdNames.count(functionId)) { + throw std::runtime_error( + "Duplicate function id: " + std::to_string(functionId) + + " for function " + functionName); + } + functionNames[functionId] = functionName; + for (auto &pair : scopeIdPairs) { + auto scopeId = pair.first; + auto scopeName = pair.second; + if (functionScopeIdNames[functionId].count(scopeId)) { + throw std::runtime_error( + "Duplicate scope id: " + std::to_string(scopeId) + " for function " + + functionName); + } + functionScopeIdNames[functionId][scopeId] = scopeName; + } + // Synthesize the calling contexts + std::map scopeIdParentMap; + for (auto &pair : scopeIdParentPairs) { + auto scopeId = pair.first; + auto parentId = pair.second; + scopeIdParentMap[scopeId] = parentId; + } + for (auto &[scopeId, name] : functionScopeIdNames[functionId]) { + std::vector contexts = {name}; + auto currentId = scopeId; + while (scopeIdParentMap.count(currentId) > 0) { + auto parentId = scopeIdParentMap[currentId]; + auto parentName = functionScopeIdNames[functionId].at(parentId); + contexts.emplace_back(parentName); + currentId = parentId; + } + std::reverse(contexts.begin(), contexts.end()); + functionScopeIdContexts[functionId][scopeId] = contexts; + } + functionMetadata.emplace(functionId, InstrumentationMetadata(metadataPath)); +} + +void InstrumentationProfiler::enterInstrumentedOp(uint64_t streamId, + uint64_t functionId, + uint8_t *buffer, + size_t size) { + if (!hostBuffer) { + runtime->allocateHostBuffer(&hostBuffer, DEFAULT_HOST_BUFFER_SIZE); + } +} + +void InstrumentationProfiler::exitInstrumentedOp(uint64_t streamId, + uint64_t functionId, + uint8_t *buffer, size_t size) { + if (!buffer || !hostBuffer) + return; + + void *device = runtime->getDevice(); + void *&priorityStream = deviceStreams[device]; + if (!priorityStream) { + priorityStream = runtime->getPriorityStream(); + } + + if (size > MAX_HOST_BUFFER_SIZE) { + throw std::runtime_error( + "Buffer size " + std::to_string(size) + " exceeds the limit " + + std::to_string(MAX_HOST_BUFFER_SIZE) + ", not supported yet in proton"); + } else if (size > DEFAULT_HOST_BUFFER_SIZE) { + runtime->freeHostBuffer(hostBuffer); + auto newSize = nextPowerOfTwo(size); + runtime->allocateHostBuffer(&hostBuffer, newSize); + } + + const auto &functionName = functionNames[functionId]; + enterOp(Scope(functionName)); + + auto config = getParserConfig(functionId, size); + auto circularLayoutConfig = + std::dynamic_pointer_cast(config); + if (!circularLayoutConfig) { + throw std::runtime_error( + "Only circular layout parser is supported for now"); + } + + int64_t timeShiftCost = 0; + if (modeOptions.count("optimizations")) { + auto optimizations = proton::split(modeOptions.at("optimizations"), ","); + if (std::find(optimizations.begin(), optimizations.end(), "time_shift") != + optimizations.end()) + timeShiftCost = getTimeShiftCost(*circularLayoutConfig); + } + auto &scopeIdContexts = functionScopeIdContexts[functionId]; + + runtime->synchronizeStream(reinterpret_cast(streamId)); + runtime->processHostBuffer( + hostBuffer, size, buffer, size, priorityStream, + [&](uint8_t *bufferPtr, size_t size) { + ByteSpan byteSpan(bufferPtr, size); + CircularLayoutParser parser(byteSpan, *circularLayoutConfig); + parser.parse(); + for (auto &blockTrace : parser.getResult()->blockTraces) { + for (auto &trace : blockTrace.traces) { + for (auto &event : trace.profileEvents) { + auto &contexts = scopeIdContexts[event.first->scopeId]; + auto duration = event.second->cycle - event.first->cycle; + auto normalizedDuration = static_cast(duration) / + (circularLayoutConfig->totalUnits * + circularLayoutConfig->numBlocks); + for (auto [data, entry] : dataToEntryMap) { + auto kernelId = entry.id; + entry = data->addOp(entry.phase, kernelId, contexts); + entry.upsertMetric(std::make_unique( + event.first->cycle, event.second->cycle, duration, + normalizedDuration, kernelId, functionName, + blockTrace.blockId, blockTrace.procId, trace.uid, + static_cast(reinterpret_cast(device)), + static_cast(runtime->getDeviceType()), + timeShiftCost, blockTrace.initTime, blockTrace.preFinalTime, + blockTrace.postFinalTime)); + } + } + } + } + }); + + exitOp(Scope(functionName)); +} + +void InstrumentationProfiler::doAddMetrics( + size_t scopeId, const std::map &scalarMetrics, + const std::map &tensorMetrics) { + if (dataToEntryMap.empty()) { + for (auto *data : dataSet) { + data->addMetrics(scopeId, scalarMetrics); + } + } else { + for (auto [data, entry] : dataToEntryMap) { + data->addMetrics(entry.phase, entry.id, scalarMetrics); + } + } + // TODO(Keren): handle tensor metrics by making metricBuffer a member of the + // parent Profiler +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp new file mode 100644 index 0000000000..0bf9147b2c --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp @@ -0,0 +1,28 @@ +#include + +#include "Profiler/Instrumentation/Metadata.h" +#include "nlohmann/json.hpp" + +using json = nlohmann::json; + +namespace proton { + +void InstrumentationMetadata::parse() { + std::ifstream metadataFile(metadataPath); + if (!metadataFile.is_open()) { + throw std::runtime_error("Failed to open metadata file: " + metadataPath); + } + + json metadataJson; + metadataFile >> metadataJson; + + if (metadataJson.contains("profile_scratch_size")) { + scratchMemorySize = metadataJson["profile_scratch_size"].get(); + } + + if (metadataJson.contains("num_warps")) { + numWarps = metadataJson["num_warps"].get(); + } +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/Profiler.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/Profiler.cpp new file mode 100644 index 0000000000..fbb781ec37 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/Profiler.cpp @@ -0,0 +1,7 @@ +#include "Profiler/Profiler.h" + +namespace proton { +thread_local void *Profiler::tensorMetricKernel = nullptr; +thread_local void *Profiler::scalarMetricKernel = nullptr; +thread_local void *Profiler::metricKernelStream = nullptr; +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/mthreads/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp new file mode 100644 index 0000000000..7189f8ffbe --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -0,0 +1,452 @@ +#include "Profiler/Roctracer/RoctracerProfiler.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/GPU/HipApi.h" +#include "Driver/GPU/HsaApi.h" +#include "Driver/GPU/RoctracerApi.h" +#include "Runtime/HipRuntime.h" +#include "Utility/Env.h" + +#include "hip/amd_detail/hip_runtime_prof.h" +#include "roctracer/roctracer_ext.h" +#include "roctracer/roctracer_hip.h" +#include "roctracer/roctracer_roctx.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState(RoctracerProfiler::instance()); + +namespace { + +class DeviceInfo : public Singleton { +public: + DeviceInfo() = default; + int mapDeviceId(int id) { + // Lazy initialization of device offset by calling hip API. + // Otherwise on nvidia platforms, the HSA call will fail because of no + // available libraries. + std::call_once(deviceOffsetFlag, [this]() { initDeviceOffset(); }); + return id - deviceOffset; + } + +private: + void initDeviceOffset() { + int dc = 0; + auto ret = hip::getDeviceCount(&dc); + hsa::iterateAgents( + [](hsa_agent_t agent, void *data) { + auto &offset = *static_cast(data); + int nodeId; + hsa::agentGetInfo( + agent, + static_cast(HSA_AMD_AGENT_INFO_DRIVER_NODE_ID), + &nodeId); + int deviceType; + hsa::agentGetInfo( + agent, static_cast(HSA_AGENT_INFO_DEVICE), + &deviceType); + if ((nodeId < offset) && (deviceType == HSA_DEVICE_TYPE_GPU)) + offset = nodeId; + + return HSA_STATUS_SUCCESS; + }, + &deviceOffset); + } + + std::once_flag deviceOffsetFlag; + int deviceOffset = 0x7fffffff; +}; + +std::unique_ptr +convertActivityToMetric(const roctracer_record_t *activity) { + std::unique_ptr metric; + switch (activity->kind) { + case kHipVdiCommandTask: + case kHipVdiCommandKernel: { + if (activity->begin_ns < activity->end_ns) { + metric = std::make_unique( + static_cast(activity->begin_ns), + static_cast(activity->end_ns), 1, + static_cast( + DeviceInfo::instance().mapDeviceId(activity->device_id)), + static_cast(DeviceType::HIP), + static_cast(activity->queue_id)); + } + break; + } + default: + break; + } + return metric; +} + +void processActivityKernel( + RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, + RoctracerProfiler::ExternIdToStateMap &externIdToState, + ThreadSafeMap> + &corrIdToIsHipGraph, + std::map> &dataPhases, size_t externId, + const roctracer_record_t *activity) { + if (externId == Scope::DummyScopeId) + return; + bool isGraph = corrIdToIsHipGraph.contain(activity->correlation_id); + auto &state = externIdToState[externId]; + if (!isGraph) { + for (auto [data, entry] : state.dataToEntry) { + if (auto metric = convertActivityToMetric(activity)) { + if (state.isMissingName) { + auto childEntry = data->addOp(entry.phase, entry.id, + {Context(activity->kernel_name)}); + childEntry.upsertMetric(std::move(metric)); + } else { + entry.upsertMetric(std::move(metric)); + } + detail::updateDataPhases(dataPhases, data, entry.phase); + } + } + } else { + // Graph kernels + // A single graph launch can trigger multiple kernels. + // Our solution is to construct the following maps: + // --- Application threads --- + // 1. Graph -> numNodes + // 2. GraphExec -> Graph + // --- Roctracer thread --- + // 3. corrId -> numNodes + for (auto [data, entry] : state.dataToEntry) { + if (auto metric = convertActivityToMetric(activity)) { + auto childEntry = data->addOp(entry.phase, entry.id, + {Context(activity->kernel_name)}); + childEntry.upsertMetric(std::move(metric)); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + } + } + --state.numNodes; + if (state.numNodes == 0) { + corrIdToExternId.erase(activity->correlation_id); + corrIdToIsHipGraph.erase(activity->correlation_id); + externIdToState.erase(externId); + } + return; +} + +void processActivity( + RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, + RoctracerProfiler::ExternIdToStateMap &externIdToState, + ThreadSafeMap> + &corrIdToIsHipGraph, + std::map> &dataPhases, size_t parentId, + const roctracer_record_t *record) { + switch (record->kind) { + case kHipVdiCommandTask: + case kHipVdiCommandKernel: { + processActivityKernel(corrIdToExternId, externIdToState, corrIdToIsHipGraph, + dataPhases, parentId, record); + break; + } + default: + break; + } +} + +} // namespace + +namespace { + +std::tuple matchKernelCbId(uint32_t cbId) { + bool isRuntimeApi = false; + bool isDriverApi = false; + switch (cbId) { + // TODO: switch to directly subscribe the APIs + case HIP_API_ID_hipStreamBeginCapture: + case HIP_API_ID_hipStreamEndCapture: + case HIP_API_ID_hipExtLaunchKernel: + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: + case HIP_API_ID_hipExtModuleLaunchKernel: + case HIP_API_ID_hipHccModuleLaunchKernel: + case HIP_API_ID_hipLaunchCooperativeKernel: + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipLaunchKernel: + case HIP_API_ID_hipModuleLaunchKernel: + case HIP_API_ID_hipGraphLaunch: + case HIP_API_ID_hipModuleLaunchCooperativeKernel: + case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipGraphExecDestroy: + case HIP_API_ID_hipGraphInstantiateWithFlags: + case HIP_API_ID_hipGraphInstantiate: { + isRuntimeApi = true; + break; + } + default: + break; + } + return std::make_pair(isRuntimeApi, isDriverApi); +} + +} // namespace + +struct RoctracerProfiler::RoctracerProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + RoctracerProfilerPimpl(RoctracerProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) { + auto runtime = &HipRuntime::instance(); + profiler.metricBuffer = + std::make_unique(1024 * 1024 * 64, runtime); + } + virtual ~RoctracerProfilerPimpl() = default; + + void doStart() override; + void doFlush() override; + void doStop() override; + + static void apiCallback(uint32_t domain, uint32_t cid, + const void *callbackData, void *arg); + static void activityCallback(const char *begin, const char *end, void *arg); + + ThreadSafeMap> + corrIdToIsHipGraph; + + ThreadSafeMap> + graphExecToGraph; + + ThreadSafeMap> + graphToNumInstances; + + ThreadSafeMap> + streamToCaptureCount; + + ThreadSafeMap> + streamToCapture; +}; + +void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback( + uint32_t domain, uint32_t cid, const void *callbackData, void *arg) { + if (domain == ACTIVITY_DOMAIN_HIP_API) { + auto [isRuntimeAPI, isDriverAPI] = matchKernelCbId(cid); + if (!(isRuntimeAPI || isDriverAPI)) { + return; + } + auto &profiler = + dynamic_cast(RoctracerProfiler::instance()); + auto *pImpl = dynamic_cast( + profiler.pImpl.get()); + const hip_api_data_t *data = + static_cast(callbackData); + if (data->phase == ACTIVITY_API_PHASE_ENTER) { + // Valid context and outermost level of the kernel launch + // TODO: Get kernel name from hip_api_data_t + threadState.enterOp(Scope("")); + auto &dataToEntry = threadState.dataToEntry; + size_t numInstances = 1; + if (cid == HIP_API_ID_hipGraphLaunch) { + pImpl->corrIdToIsHipGraph[data->correlation_id] = true; + hipGraphExec_t GraphExec = data->args.hipGraphLaunch.graphExec; + numInstances = std::numeric_limits::max(); + bool findGraph = false; + if (pImpl->graphExecToGraph.contain(GraphExec)) { + hipGraph_t Graph = pImpl->graphExecToGraph[GraphExec]; + if (pImpl->graphToNumInstances.contain(Graph)) { + numInstances = pImpl->graphToNumInstances[Graph]; + findGraph = true; + } + } + if (!findGraph) + std::cerr + << "[PROTON] Cannot find graph and it may cause a memory leak." + "To avoid this problem, please start profiling before the " + "graph is created." + << std::endl; + } + auto &scope = threadState.scopeStack.back(); + auto isMissingName = scope.name.empty(); + profiler.correlation.correlate(data->correlation_id, scope.scopeId, + numInstances, isMissingName, dataToEntry); + } else if (data->phase == ACTIVITY_API_PHASE_EXIT) { + switch (cid) { + case HIP_API_ID_hipStreamBeginCapture: { + hipStream_t Stream = data->args.hipStreamBeginCapture.stream; + pImpl->streamToCaptureCount[Stream] = 0; + pImpl->streamToCapture[Stream] = true; + break; + } + case HIP_API_ID_hipStreamEndCapture: { + hipGraph_t Graph = *(data->args.hipStreamEndCapture.pGraph); + hipStream_t Stream = data->args.hipStreamEndCapture.stream; + // How many times did we capture a kernel launch for this stream + uint32_t StreamCaptureCount = pImpl->streamToCaptureCount[Stream]; + pImpl->graphToNumInstances[Graph] = StreamCaptureCount; + pImpl->streamToCapture.erase(Stream); + break; + } + case HIP_API_ID_hipLaunchKernel: { + hipStream_t Stream = data->args.hipLaunchKernel.stream; + if (pImpl->streamToCapture.contain(Stream)) + pImpl->streamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipExtLaunchKernel: { + hipStream_t Stream = data->args.hipExtLaunchKernel.stream; + if (pImpl->streamToCapture.contain(Stream)) + pImpl->streamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipLaunchCooperativeKernel: { + hipStream_t Stream = data->args.hipLaunchCooperativeKernel.stream; + if (pImpl->streamToCapture.contain(Stream)) + pImpl->streamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipModuleLaunchKernel: { + hipStream_t Stream = data->args.hipModuleLaunchKernel.stream; + if (pImpl->streamToCapture.contain(Stream)) + pImpl->streamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipModuleLaunchCooperativeKernel: { + hipStream_t Stream = data->args.hipModuleLaunchCooperativeKernel.stream; + if (pImpl->streamToCapture.contain(Stream)) + pImpl->streamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipGraphInstantiateWithFlags: { + hipGraph_t Graph = data->args.hipGraphInstantiateWithFlags.graph; + hipGraphExec_t GraphExec = + *(data->args.hipGraphInstantiateWithFlags.pGraphExec); + pImpl->graphExecToGraph[GraphExec] = Graph; + break; + } + case HIP_API_ID_hipGraphInstantiate: { + hipGraph_t Graph = data->args.hipGraphInstantiate.graph; + hipGraphExec_t GraphExec = *(data->args.hipGraphInstantiate.pGraphExec); + pImpl->graphExecToGraph[GraphExec] = Graph; + break; + } + } + threadState.exitOp(); + // Track outstanding op for flush + profiler.correlation.submit(data->correlation_id); + } + } else if (domain == ACTIVITY_DOMAIN_ROCTX) { + const roctx_api_data_t *data = + static_cast(callbackData); + if (cid == ROCTX_API_ID_roctxRangePushA) { + threadState.enterScope((data->args).message); + } else if (cid == ROCTX_API_ID_roctxRangePop) { + threadState.exitScope(); + } + } +} + +void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( + const char *begin, const char *end, void *arg) { + auto &profiler = + dynamic_cast(RoctracerProfiler::instance()); + auto *pImpl = dynamic_cast( + profiler.pImpl.get()); + auto &correlation = profiler.correlation; + + static thread_local std::map dataFlushedPhases; + const roctracer_record_t *record = + reinterpret_cast(begin); + const roctracer_record_t *endRecord = + reinterpret_cast(end); + uint64_t maxCorrelationId = 0; + std::map> dataPhases; + + while (record != endRecord) { + // Log latest completed correlation id. Used to ensure we have flushed all + // data on stop + maxCorrelationId = + std::max(maxCorrelationId, record->correlation_id); + auto externId = Scope::DummyScopeId; + bool hasCorrelation = correlation.corrIdToExternId.withRead( + record->correlation_id, [&](const size_t &value) { externId = value; }); + + if (hasCorrelation) { + // Track correlation ids from the same stream and erase those < + // correlationId + processActivity(correlation.corrIdToExternId, correlation.externIdToState, + pImpl->corrIdToIsHipGraph, dataPhases, externId, record); + } else { + correlation.corrIdToExternId.erase(record->correlation_id); + pImpl->corrIdToIsHipGraph.erase(record->correlation_id); + } + roctracer::getNextRecord(record, &record); + } + correlation.complete(maxCorrelationId); + profiler.flushDataPhases(dataFlushedPhases, dataPhases, + profiler.pendingGraphPool.get()); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doStart() { + if (getBoolEnv("TRITON_ENABLE_NVTX", true)) { + roctracer::enableDomainCallback(ACTIVITY_DOMAIN_ROCTX, apiCallback, + nullptr); + } + roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HIP_API, apiCallback, + nullptr); + // Activity Records + roctracer_properties_t properties{0}; + const auto envBufferSize = + getIntEnv("TRITON_PROFILE_BUFFER_SIZE", 64 * 1024 * 1024); + properties.buffer_size = envBufferSize; + properties.buffer_callback_fun = activityCallback; + roctracer::openPool(&properties); + roctracer::enableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); + roctracer::start(); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doFlush() { + // Implement reliable flushing. + // Wait for all dispatched ops to be reported. + std::ignore = hip::deviceSynchronize(); + // If flushing encounters an activity record still being written, flushing + // stops. Use a subsequent flush when the record has completed being written + // to resume the flush. + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepUs=*/10, /*flush=*/ + []() { roctracer::flushActivity(); }); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doStop() { + roctracer::stop(); + roctracer::disableDomainCallback(ACTIVITY_DOMAIN_HIP_API); + roctracer::disableDomainCallback(ACTIVITY_DOMAIN_ROCTX); + roctracer::disableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); + roctracer::closePool(); +} + +RoctracerProfiler::RoctracerProfiler() { + pImpl = std::make_unique(*this); +} + +RoctracerProfiler::~RoctracerProfiler() = default; + +void RoctracerProfiler::doSetMode( + const std::vector &modeAndOptions) { + auto mode = modeAndOptions[0]; + if (proton::toLower(mode) == "periodic_flushing") { + detail::setPeriodicFlushingMode(periodicFlushingEnabled, + periodicFlushingFormat, modeAndOptions, + "RoctracerProfiler"); + } else if (!mode.empty()) { + throw std::invalid_argument( + "[PROTON] RoctracerProfiler: unsupported mode: " + mode); + } +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Runtime/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Runtime/CMakeLists.txt new file mode 100644 index 0000000000..8c8e93e351 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Runtime/CMakeLists.txt @@ -0,0 +1,4 @@ +add_proton_library(ProtonRuntime + CudaRuntime.cpp + HipRuntime.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Runtime/CudaRuntime.cpp b/third_party/mthreads/proton/csrc/lib/Runtime/CudaRuntime.cpp new file mode 100644 index 0000000000..5cd0d468ee --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Runtime/CudaRuntime.cpp @@ -0,0 +1,120 @@ +#include "Runtime/CudaRuntime.h" + +#include "Driver/GPU/CudaApi.h" +#include +#include + +namespace proton { + +void CudaRuntime::launchKernel(void *kernel, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, void *stream, + void **kernelParams, void **extra) { + cuda::launchKernel(reinterpret_cast(kernel), gridDimX, + gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, + sharedMemBytes, reinterpret_cast(stream), + kernelParams, extra); +} + +void CudaRuntime::memset(void *devicePtr, uint32_t value, size_t size, + void *stream) { + cuda::memsetD32Async(reinterpret_cast(devicePtr), value, + size / sizeof(uint32_t), + reinterpret_cast(stream)); +} + +void CudaRuntime::allocateHostBuffer(uint8_t **buffer, size_t size, + bool mapped) { + if (mapped) { + cuda::memHostAlloc(reinterpret_cast(buffer), size, + CU_MEMHOSTALLOC_DEVICEMAP); + } else { + cuda::memAllocHost(reinterpret_cast(buffer), size); + } +} + +void CudaRuntime::getHostDevicePointer(uint8_t *hostPtr, uint8_t **devicePtr) { + CUdeviceptr devicePtrV; + cuda::memHostGetDevicePointer(&devicePtrV, hostPtr, 0); + *devicePtr = reinterpret_cast(devicePtrV); +} + +void CudaRuntime::freeHostBuffer(uint8_t *buffer) { + cuda::memFreeHost(buffer); +} + +void CudaRuntime::allocateDeviceBuffer(uint8_t **buffer, size_t size) { + CUdeviceptr devicePtr; + cuda::memAlloc(&devicePtr, size); + *buffer = reinterpret_cast(devicePtr); +} + +void CudaRuntime::freeDeviceBuffer(uint8_t *buffer) { + CUdeviceptr devicePtr = reinterpret_cast(buffer); + cuda::memFree(devicePtr); +} + +void CudaRuntime::copyDeviceToHostAsync(void *dst, const void *src, size_t size, + void *stream) { + cuda::memcpyDToHAsync(dst, reinterpret_cast(src), size, + reinterpret_cast(stream)); +} + +void *CudaRuntime::getDevice() { + CUdevice device; + cuda::ctxGetDevice(&device); + return reinterpret_cast(static_cast(device)); +} + +void *CudaRuntime::getPriorityStream() { + CUstream stream; + // TODO: Change priority + int lowestPriority, highestPriority; + cuda::ctxGetStreamPriorityRange(&lowestPriority, &highestPriority); + cuda::streamCreateWithPriority(&stream, CU_STREAM_NON_BLOCKING, + highestPriority); + return reinterpret_cast(stream); +} + +void CudaRuntime::synchronizeStream(void *stream) { + cuda::streamSynchronize(reinterpret_cast(stream)); +} + +void CudaRuntime::destroyStream(void *stream) { + cuda::streamDestroy(reinterpret_cast(stream)); +} + +void CudaRuntime::synchronizeDevice() { + CUcontext cuContext = nullptr; + cuda::ctxGetCurrent(&cuContext); + if (cuContext) { + cuda::ctxSynchronize(); + } +} + +void CudaRuntime::processHostBuffer( + uint8_t *hostBuffer, size_t hostBufferSize, uint8_t *deviceBuffer, + size_t deviceBufferSize, void *stream, + std::function callback) { + int64_t chunkSize = std::min(hostBufferSize, deviceBufferSize); + int64_t sizeLeftOnDevice = deviceBufferSize; + while (chunkSize > 0) { + cuda::memcpyDToHAsync(reinterpret_cast(hostBuffer), + reinterpret_cast(deviceBuffer), + chunkSize, reinterpret_cast(stream)); + // We should not use synchronization here in general if we want to copy + // buffer while the kernel is running. But for the sake of simplicity, we + // only copy the buffer after the kernel is finished for now. + cuda::streamSynchronize(reinterpret_cast(stream)); + callback(hostBuffer, chunkSize); + hostBuffer += chunkSize; + deviceBuffer += chunkSize; + sizeLeftOnDevice -= chunkSize; + chunkSize = + std::min(static_cast(hostBufferSize), sizeLeftOnDevice); + } +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Runtime/HipRuntime.cpp b/third_party/mthreads/proton/csrc/lib/Runtime/HipRuntime.cpp new file mode 100644 index 0000000000..2c83bdd148 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Runtime/HipRuntime.cpp @@ -0,0 +1,113 @@ +#include "Runtime/HipRuntime.h" + +#include "Driver/GPU/HipApi.h" +#include +#include + +namespace proton { + +void HipRuntime::launchKernel(void *kernel, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, void *stream, + void **kernelParams, void **extra) { + auto status = hip::launchKernel( + reinterpret_cast(kernel), gridDimX, gridDimY, gridDimZ, + blockDimX, blockDimY, blockDimZ, sharedMemBytes, + reinterpret_cast(stream), kernelParams, extra); + (void)status; +} + +void HipRuntime::memset(void *devicePtr, uint32_t value, size_t size, + void *stream) { + auto status = hip::memsetD32Async( + reinterpret_cast(devicePtr), value, + size / sizeof(uint32_t), reinterpret_cast(stream)); + (void)status; +} + +void HipRuntime::allocateHostBuffer(uint8_t **buffer, size_t size, + bool mapped) { + if (mapped) { + (void)hip::memHostAlloc(reinterpret_cast(buffer), size, + hipHostAllocMapped); + } else { + (void)hip::memAllocHost(reinterpret_cast(buffer), size); + } +} + +void HipRuntime::getHostDevicePointer(uint8_t *hostPtr, uint8_t **devicePtr) { + hipDeviceptr_t devicePtrV; + (void)hip::memHostGetDevicePointer(&devicePtrV, hostPtr, 0); + *devicePtr = reinterpret_cast(devicePtrV); +} + +void HipRuntime::freeHostBuffer(uint8_t *buffer) { + (void)hip::memFreeHost(buffer); +} + +void HipRuntime::allocateDeviceBuffer(uint8_t **buffer, size_t size) { + hipDeviceptr_t devicePtr; + (void)hip::memAlloc(reinterpret_cast(&devicePtr), size); + *buffer = reinterpret_cast(devicePtr); +} + +void HipRuntime::freeDeviceBuffer(uint8_t *buffer) { + hipDeviceptr_t devicePtr = reinterpret_cast(buffer); + (void)hip::memFree(devicePtr); +} + +void HipRuntime::copyDeviceToHostAsync(void *dst, const void *src, size_t size, + void *stream) { + (void)hip::memcpyDToHAsync( + dst, reinterpret_cast(const_cast(src)), size, + reinterpret_cast(stream)); +} + +void *HipRuntime::getDevice() { + hipDevice_t device; + (void)hip::ctxGetDevice(&device); + return reinterpret_cast(static_cast(device)); +} + +void *HipRuntime::getPriorityStream() { + hipStream_t stream; + int lowestPriority, highestPriority; + (void)hip::ctxGetStreamPriorityRange(&lowestPriority, &highestPriority); + (void)hip::streamCreateWithPriority(&stream, hipStreamNonBlocking, + highestPriority); + return reinterpret_cast(stream); +} + +void HipRuntime::synchronizeStream(void *stream) { + (void)hip::streamSynchronize(reinterpret_cast(stream)); +} + +void HipRuntime::synchronizeDevice() { (void)hip::deviceSynchronize(); } + +void HipRuntime::destroyStream(void *stream) { + (void)hip::streamDestroy(reinterpret_cast(stream)); +} + +void HipRuntime::processHostBuffer( + uint8_t *hostBuffer, size_t hostBufferSize, uint8_t *deviceBuffer, + size_t deviceBufferSize, void *stream, + std::function callback) { + int64_t chunkSize = std::min(hostBufferSize, deviceBufferSize); + int64_t sizeLeftOnDevice = deviceBufferSize; + while (chunkSize > 0) { + (void)hip::memcpyDToHAsync( + reinterpret_cast(hostBuffer), + reinterpret_cast(deviceBuffer), chunkSize, + reinterpret_cast(stream)); + (void)hip::streamSynchronize(reinterpret_cast(stream)); + callback(hostBuffer, chunkSize); + hostBuffer += chunkSize; + deviceBuffer += chunkSize; + sizeLeftOnDevice -= chunkSize; + chunkSize = + std::min(static_cast(hostBufferSize), sizeLeftOnDevice); + } +} +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Session/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Session/CMakeLists.txt new file mode 100644 index 0000000000..f84eb610a0 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Session/CMakeLists.txt @@ -0,0 +1,3 @@ +add_proton_library(ProtonSession + Session.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Session/Session.cpp b/third_party/mthreads/proton/csrc/lib/Session/Session.cpp new file mode 100644 index 0000000000..2439d57e40 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Session/Session.cpp @@ -0,0 +1,361 @@ +#include "Session/Session.h" +#include "Context/Python.h" +#include "Context/Shadow.h" +#include "Data/TraceData.h" +#include "Data/TreeData.h" +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Profiler/Instrumentation/InstrumentationProfiler.h" +#include "Profiler/Roctracer/RoctracerProfiler.h" +#include "Utility/String.h" + +namespace proton { + +namespace { + +Profiler *makeProfiler(const std::string &name) { + if (proton::toLower(name) == "cupti") { + return &CuptiProfiler::instance(); + } else if (proton::toLower(name) == "roctracer") { + return &RoctracerProfiler::instance(); + } else if (proton::toLower(name) == "instrumentation") { + return &InstrumentationProfiler::instance(); + } + throw std::runtime_error("Unknown profiler: " + name); +} + +std::unique_ptr makeData(const std::string &dataName, + const std::string &path, + ContextSource *contextSource) { + if (toLower(dataName) == "tree") { + return std::make_unique(path, contextSource); + } else if (toLower(dataName) == "trace") { + return std::make_unique(path, contextSource); + } + throw std::runtime_error("Unknown data: " + dataName); +} + +std::unique_ptr +makeContextSource(const std::string &contextSourceName) { + if (toLower(contextSourceName) == "shadow") { + return std::make_unique(); + } else if (toLower(contextSourceName) == "python") { + return std::make_unique(); + } + throw std::runtime_error("Unknown context source: " + contextSourceName); +} + +void throwIfSessionNotInitialized( + const std::map> &sessions, + size_t sessionId) { + if (!sessions.count(sessionId)) { + throw std::runtime_error("Session has not been initialized: " + + std::to_string(sessionId)); + } +} + +} // namespace + +void Session::activate() { + profiler->start(); + profiler->registerData(data.get()); +} + +void Session::deactivate(bool flushing) { + if (flushing) + profiler->flush(); + profiler->unregisterData(data.get()); +} + +void Session::finalize(const std::string &outputFormat) { + profiler->flush(); + profiler->stop(); + data->dump(outputFormat); +} + +size_t Session::getContextDepth() { return contextSource->getDepth(); } + +Profiler *SessionManager::validateAndSetProfilerMode(Profiler *profiler, + const std::string &mode) { + std::vector modeAndOptions = proton::split(mode, ":"); + for (auto &[id, session] : sessions) { + if (session->getProfiler() == profiler && + session->getProfiler()->getMode() != modeAndOptions) { + throw std::runtime_error("Cannot add a session with the same profiler " + "but a different mode than existing sessions"); + } + } + return profiler->setMode(modeAndOptions); +} + +std::unique_ptr SessionManager::makeSession( + size_t id, const std::string &path, const std::string &profilerName, + const std::string &contextSourceName, const std::string &dataName, + const std::string &mode) { + auto *profiler = makeProfiler(profilerName); + profiler = validateAndSetProfilerMode(profiler, mode); + auto contextSource = makeContextSource(contextSourceName); + auto data = makeData(dataName, path, contextSource.get()); + auto *session = new Session(id, path, profiler, std::move(contextSource), + std::move(data)); + return std::unique_ptr(session); +} + +Session *SessionManager::getSessionOrThrow(size_t sessionId) { + throwIfSessionNotInitialized(sessions, sessionId); + return sessions[sessionId].get(); +} + +void SessionManager::activateSession(size_t sessionId) { + std::lock_guard lock(mutex); + activateSessionImpl(sessionId); +} + +void SessionManager::activateAllSessions() { + std::lock_guard lock(mutex); + for (auto iter : sessionActive) { + activateSessionImpl(iter.first); + } +} + +void SessionManager::deactivateSession(size_t sessionId, bool flushing) { + std::lock_guard lock(mutex); + deActivateSessionImpl(sessionId, flushing); +} + +void SessionManager::deactivateAllSessions(bool flushing) { + std::lock_guard lock(mutex); + for (auto iter : sessionActive) { + deActivateSessionImpl(iter.first, flushing); + } +} + +void SessionManager::activateSessionImpl(size_t sessionId) { + throwIfSessionNotInitialized(sessions, sessionId); + if (sessionActive[sessionId]) + return; + sessionActive[sessionId] = true; + sessions[sessionId]->activate(); + registerInterface(sessionId, scopeInterfaceCounts); + registerInterface(sessionId, opInterfaceCounts); + registerInterface(sessionId, + instrumentationInterfaceCounts); + registerInterface(sessionId, contextSourceCounts); + registerInterface(sessionId, metricInterfaceCounts); +} + +void SessionManager::deActivateSessionImpl(size_t sessionId, bool flushing) { + throwIfSessionNotInitialized(sessions, sessionId); + if (!sessionActive[sessionId]) { + return; + } + sessionActive[sessionId] = false; + sessions[sessionId]->deactivate(flushing); + unregisterInterface(sessionId, scopeInterfaceCounts); + unregisterInterface(sessionId, opInterfaceCounts); + unregisterInterface(sessionId, + instrumentationInterfaceCounts); + unregisterInterface(sessionId, contextSourceCounts); + unregisterInterface(sessionId, metricInterfaceCounts); +} + +void SessionManager::removeSession(size_t sessionId) { + if (!hasSession(sessionId)) { + return; + } + // Context source can be safely cleared here but not deactivation. + // Context source of each session is still sort of active after deactivation, + // For example, if we have + // ```Python + // proton.deactivate_session(session0) + // with proton.scope("A"): + // proton.activate_session(session0) + // ``` + // session0 should be aware of scope "A"'s enter and exit, otherwise the + // context stack will be imbalanced. + sessions[sessionId]->contextSource->clear(); + auto path = sessions[sessionId]->path; + sessionPaths.erase(path); + sessionActive.erase(sessionId); + sessions.erase(sessionId); +} + +size_t SessionManager::addSession(const std::string &path, + const std::string &profilerName, + const std::string &contextSourceName, + const std::string &dataName, + const std::string &mode) { + std::lock_guard lock(mutex); + if (hasSession(path)) { + auto sessionId = getSessionId(path); + activateSessionImpl(sessionId); + return sessionId; + } + auto sessionId = nextSessionId++; + auto newSession = makeSession(sessionId, path, profilerName, + contextSourceName, dataName, mode); + sessionPaths[path] = sessionId; + sessions[sessionId] = std::move(newSession); + return sessionId; +} + +void SessionManager::finalizeSession(size_t sessionId, + const std::string &outputFormat) { + std::lock_guard lock(mutex); + if (!hasSession(sessionId)) { + return; + } + deActivateSessionImpl(sessionId, /*flushing=*/true); + sessions[sessionId]->finalize(outputFormat); + removeSession(sessionId); +} + +void SessionManager::finalizeAllSessions(const std::string &outputFormat) { + std::lock_guard lock(mutex); + auto sessionIds = std::vector{}; + for (auto &[sessionId, session] : sessions) { + deActivateSessionImpl(sessionId, /*flushing=*/true); + session->finalize(outputFormat); + sessionIds.push_back(sessionId); + } + for (auto sessionId : sessionIds) { + removeSession(sessionId); + } +} + +void SessionManager::enterScope(const Scope &scope) { + std::lock_guard lock(mutex); + executeInterface(scopeInterfaceCounts, [&](auto *scopeInterface) { + scopeInterface->enterScope(scope); + }); +} + +void SessionManager::exitScope(const Scope &scope) { + std::lock_guard lock(mutex); + executeInterface( + scopeInterfaceCounts, + [&](auto *scopeInterface) { scopeInterface->exitScope(scope); }, + /*isReversed=*/true); +} + +void SessionManager::enterOp(const Scope &scope) { + std::lock_guard lock(mutex); + executeInterface(opInterfaceCounts, + [&](auto *opInterface) { opInterface->enterOp(scope); }); +} + +void SessionManager::exitOp(const Scope &scope) { + std::lock_guard lock(mutex); + executeInterface( + opInterfaceCounts, [&](auto *opInterface) { opInterface->exitOp(scope); }, + /*isReversed=*/true); +} + +void SessionManager::initFunctionMetadata( + uint64_t functionId, const std::string &functionName, + const std::vector> &scopeIdNames, + const std::vector> &scopeIdParents, + const std::string &metadataPath) { + std::lock_guard lock(mutex); + executeInterface(instrumentationInterfaceCounts, + [&](auto *instrumentationInterface) { + instrumentationInterface->initFunctionMetadata( + functionId, functionName, scopeIdNames, scopeIdParents, + metadataPath); + }); +} + +void SessionManager::enterInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size) { + std::lock_guard lock(mutex); + executeInterface(instrumentationInterfaceCounts, + [&](auto *instrumentationInterface) { + instrumentationInterface->enterInstrumentedOp( + streamId, functionId, buffer, size); + }); +} + +void SessionManager::exitInstrumentedOp(uint64_t streamId, uint64_t functionId, + uint8_t *buffer, size_t size) { + std::lock_guard lock(mutex); + executeInterface( + instrumentationInterfaceCounts, + [&](auto *instrumentationInterface) { + instrumentationInterface->exitInstrumentedOp(streamId, functionId, + buffer, size); + }, + /*isReversed=*/true); +} + +void SessionManager::addMetrics( + size_t scopeId, const std::map &scalarMetrics, + const std::map &tensorMetrics) { + std::lock_guard lock(mutex); + executeInterface(metricInterfaceCounts, [&](auto *metricInterface) { + metricInterface->addMetrics(scopeId, scalarMetrics, tensorMetrics); + }); +} + +void SessionManager::setMetricKernels(void *tensorMetricKernel, + void *scalarMetricKernel, void *stream) { + std::lock_guard lock(mutex); + executeInterface(metricInterfaceCounts, [&](auto *metricInterface) { + metricInterface->setMetricKernels(tensorMetricKernel, scalarMetricKernel, + stream); + }); +} + +void SessionManager::setState(std::optional context) { + std::lock_guard lock(mutex); + for (auto iter : contextSourceCounts) { + auto [contextSource, count] = iter; + if (count > 0) { + contextSource->setState(context); + } + } +} + +size_t SessionManager::getContextDepth(size_t sessionId) { + std::lock_guard lock(mutex); + return getSessionOrThrow(sessionId)->getContextDepth(); +} + +std::vector SessionManager::getDataMsgPack(size_t sessionId, + size_t phase) { + std::lock_guard lock(mutex); + auto *session = getSessionOrThrow(sessionId); + auto *treeData = dynamic_cast(session->data.get()); + if (!treeData) { + throw std::runtime_error( + "Only TreeData is supported for getData() for now"); + } + return treeData->toMsgPack(phase); +} + +std::string SessionManager::getData(size_t sessionId, size_t phase) { + std::lock_guard lock(mutex); + auto *session = getSessionOrThrow(sessionId); + auto *treeData = dynamic_cast(session->data.get()); + if (!treeData) { + throw std::runtime_error( + "Only TreeData is supported for getData() for now"); + } + return treeData->toJsonString(phase); +} + +void SessionManager::clearData(size_t sessionId, size_t phase, + bool clearUpToPhase) { + std::lock_guard lock(mutex); + getSessionOrThrow(sessionId)->data->clear(phase, clearUpToPhase); +} + +size_t SessionManager::advanceDataPhase(size_t sessionId) { + std::lock_guard lock(mutex); + return getSessionOrThrow(sessionId)->data->advancePhase(); +} + +bool SessionManager::isDataPhaseComplete(size_t sessionId, size_t phase) { + std::lock_guard lock(mutex); + return getSessionOrThrow(sessionId)->data->getPhaseInfo().isComplete(phase); +} + +} // namespace proton diff --git a/third_party/mthreads/proton/csrc/lib/Utility/CMakeLists.txt b/third_party/mthreads/proton/csrc/lib/Utility/CMakeLists.txt new file mode 100644 index 0000000000..377664655b --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Utility/CMakeLists.txt @@ -0,0 +1,3 @@ +add_proton_library(ProtonUtility + MsgPackWriter.cpp +) diff --git a/third_party/mthreads/proton/csrc/lib/Utility/MsgPackWriter.cpp b/third_party/mthreads/proton/csrc/lib/Utility/MsgPackWriter.cpp new file mode 100644 index 0000000000..65980c9746 --- /dev/null +++ b/third_party/mthreads/proton/csrc/lib/Utility/MsgPackWriter.cpp @@ -0,0 +1,118 @@ +#include "Utility/MsgPackWriter.h" + +#include +#include +#include +#include + +namespace proton { +namespace { + +template void writeBE(std::vector &out, T value) { + using U = std::make_unsigned_t; + U u = static_cast(value); + for (int i = sizeof(U) - 1; i >= 0; --i) { + out.push_back(static_cast((u >> (i * 8)) & 0xff)); + } +} + +} // namespace + +void MsgPackWriter::reserve(size_t bytes) { out.reserve(bytes); } + +std::vector MsgPackWriter::take() && { return std::move(out); } + +void MsgPackWriter::packNil() { out.push_back(0xc0); } + +void MsgPackWriter::packBool(bool value) { out.push_back(value ? 0xc3 : 0xc2); } + +void MsgPackWriter::packUInt(uint64_t value) { + if (value <= 0x7f) { + out.push_back(static_cast(value)); + } else if (value <= 0xff) { + out.push_back(0xcc); + out.push_back(static_cast(value)); + } else if (value <= 0xffff) { + out.push_back(0xcd); + writeBE(out, static_cast(value)); + } else if (value <= 0xffffffffull) { + out.push_back(0xce); + writeBE(out, static_cast(value)); + } else { + out.push_back(0xcf); + writeBE(out, static_cast(value)); + } +} + +void MsgPackWriter::packInt(int64_t value) { + if (value >= 0) { + packUInt(static_cast(value)); + return; + } + if (value >= -32) { + out.push_back(static_cast(0xe0 | (value + 32))); + } else if (value >= std::numeric_limits::min()) { + out.push_back(0xd0); + out.push_back(static_cast(static_cast(value))); + } else if (value >= std::numeric_limits::min()) { + out.push_back(0xd1); + writeBE(out, static_cast(value)); + } else if (value >= std::numeric_limits::min()) { + out.push_back(0xd2); + writeBE(out, static_cast(value)); + } else { + out.push_back(0xd3); + writeBE(out, static_cast(value)); + } +} + +void MsgPackWriter::packDouble(double value) { + out.push_back(0xcb); + uint64_t bits{}; + static_assert(sizeof(bits) == sizeof(value)); + std::memcpy(&bits, &value, sizeof(bits)); + writeBE(out, bits); +} + +void MsgPackWriter::packStr(std::string_view value) { + const auto size = static_cast(value.size()); + if (size <= 31) { + out.push_back(static_cast(0xa0 | size)); + } else if (size <= 0xff) { + out.push_back(0xd9); + out.push_back(static_cast(size)); + } else if (size <= 0xffff) { + out.push_back(0xda); + writeBE(out, static_cast(size)); + } else { + out.push_back(0xdb); + writeBE(out, static_cast(size)); + } + out.insert(out.end(), value.begin(), value.end()); +} + +void MsgPackWriter::packArray(uint32_t size) { + if (size <= 15) { + out.push_back(static_cast(0x90 | size)); + } else if (size <= 0xffff) { + out.push_back(0xdc); + writeBE(out, static_cast(size)); + } else { + out.push_back(0xdd); + writeBE(out, static_cast(size)); + } +} + +void MsgPackWriter::packMap(uint32_t size) { + if (size <= 15) { + out.push_back(static_cast(0x80 | size)); + } else if (size <= 0xffff) { + out.push_back(0xde); + writeBE(out, static_cast(size)); + } else { + out.push_back(0xdf); + writeBE(out, static_cast(size)); + } +} + +} // namespace proton diff --git a/third_party/mthreads/proton/proton/__init__.py b/third_party/mthreads/proton/proton/__init__.py new file mode 100644 index 0000000000..2ea3e401c9 --- /dev/null +++ b/third_party/mthreads/proton/proton/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa +from .scope import scope, cpu_timed_scope, enter_scope, exit_scope +from .state import state, enter_state, exit_state, metadata_state +from .profile import ( + start, + activate, + deactivate, + finalize, + profile, + DEFAULT_PROFILE_NAME, +) +from . import context, specs, mode, data diff --git a/third_party/mthreads/proton/proton/context.py b/third_party/mthreads/proton/proton/context.py new file mode 100644 index 0000000000..f7dff1f071 --- /dev/null +++ b/third_party/mthreads/proton/proton/context.py @@ -0,0 +1,18 @@ +from typing import Optional +from triton._C.libproton import proton as libproton +from .flags import flags + + +def depth(session: Optional[int] = 0) -> Optional[int]: + """ + Get the depth of the context. + + Args: + session (int): The session ID of the profiling session. Defaults to 0. + + Returns: + depth (int or None): The depth of the context. If profiling is off, returns None. + """ + if not flags.profiling_on: + return None + return libproton.get_context_depth(session) diff --git a/third_party/mthreads/proton/proton/data.py b/third_party/mthreads/proton/proton/data.py new file mode 100644 index 0000000000..db4b26de7b --- /dev/null +++ b/third_party/mthreads/proton/proton/data.py @@ -0,0 +1,96 @@ +from typing import Optional +from triton._C.libproton import proton as libproton # type: ignore +import json as json +from .flags import flags + + +def get(session: Optional[int] = 0, phase: int = 0): + """ + Retrieves profiling data for a given session. + + Args: + session (Optional[int]): The session ID of the profiling session, or None if profiling is inactive. + Returns: + str: The profiling data in JSON format. + """ + if session is None: + return None + if flags.command_line and session != 0: + raise ValueError("Only one session can be retrieved when running from the command line.") + return json.loads(libproton.get_data(session, phase)) + + +def get_msgpack(session: Optional[int] = 0, phase: int = 0): + """ + Retrieves profiling data for a given session encoded with MessagePack. + + Args: + session (Optional[int]): The session ID of the profiling session, or None if profiling is inactive. + + Returns: + bytes: The profiling data encoded with MessagePack. + """ + if session is None: + return None + if flags.command_line and session != 0: + raise ValueError("Only one session can be retrieved when running from the command line.") + return libproton.get_data_msgpack(session, phase) + + +def advance_phase(session: Optional[int] = 0) -> Optional[int]: + """ + Advances the profiling phase for a given session. + + Args: + session (Optional[int]): The session ID of the profiling session, or None if profiling is inactive. + + Returns: + Optional[int]: The next phase number after advancing. + """ + if session is None: + return None + if flags.command_line and session != 0: + raise ValueError("Only one session can advance phase when running from the command line.") + return libproton.advance_data_phase(session) + + +def is_phase_complete(session: Optional[int] = 0, phase: int = 0) -> bool: + """ + Checks if the profiling data for a given session and phase is complete. + + A "complete" phase is safe to read/clear because all device-side records for + the phase have been flushed to the host and the phase will no longer receive + new records. + + Args: + session (Optional[int]): The session ID of the profiling session, or None if profiling is inactive. + phase (int): The phase number to check. Defaults to 0. + + Returns: + bool: True if the phase data is complete, False otherwise. + """ + if session is None: + return False + if flags.command_line and session != 0: + raise ValueError("Only one session can check phase completion status when running from the command line.") + return libproton.is_data_phase_complete(session, phase) + + +def clear( + session: Optional[int] = 0, + phase: int = 0, + clear_up_to_phase: bool = False, +) -> None: + """ + Clears profiling data for a given session. + + Args: + session (Optional[int]): The session ID of the profiling session, or None if profiling is inactive. + phase (int): The phase number to clear. Defaults to 0. + clear_up_to_phase (bool): If True, clear all phases up to and including `phase`. + """ + if session is None: + return + if flags.command_line and session != 0: + raise ValueError("Only one session can be cleared when running from the command line.") + libproton.clear_data(session, phase, clear_up_to_phase) diff --git a/third_party/mthreads/proton/proton/flags.py b/third_party/mthreads/proton/proton/flags.py new file mode 100644 index 0000000000..bef7621014 --- /dev/null +++ b/third_party/mthreads/proton/proton/flags.py @@ -0,0 +1,28 @@ +""" +Centralized, process-local flags with a minimal interface (no environment variables). + +Usage: + from triton.profiler.flags import flags + + # Toggle + flags.profiling_on = True + flags.instrumentation_on = False + + # Check + if flags.command_line: + ... +""" +from dataclasses import dataclass + + +@dataclass +class ProfilerFlags: + # Whether profiling is enabled. Default is False. + profiling_on: bool = False + # Whether instrumentation is enabled. Default is False. + instrumentation_on: bool = False + # Whether the script is run from the command line. Default is False. + command_line: bool = False + + +flags = ProfilerFlags() diff --git a/third_party/mthreads/proton/proton/hooks/__init__.py b/third_party/mthreads/proton/proton/hooks/__init__.py new file mode 100644 index 0000000000..5ba3ff5395 --- /dev/null +++ b/third_party/mthreads/proton/proton/hooks/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa +from .hook import HookManager +from .instrumentation import InstrumentationHook +from .launch import LaunchHook diff --git a/third_party/mthreads/proton/proton/hooks/hook.py b/third_party/mthreads/proton/proton/hooks/hook.py new file mode 100644 index 0000000000..a672722a12 --- /dev/null +++ b/third_party/mthreads/proton/proton/hooks/hook.py @@ -0,0 +1,128 @@ +from triton.compiler import LazyDict +from abc import abstractmethod +from typing import Dict, Any, Optional +from collections import defaultdict +import triton.knobs as knobs + + +class Hook: + priority: int = 0 + + @abstractmethod + def init_handle(self, module: Any, function: Any, name: str, metadata_group: Dict[str, str], + hash: str) -> None: # noqa: D401 + raise NotImplementedError + + @abstractmethod + def enter(self, metadata: LazyDict) -> None: + raise NotImplementedError + + @abstractmethod + def exit(self, metadata: LazyDict) -> None: + raise NotImplementedError + + @abstractmethod + def activate(self) -> None: + raise NotImplementedError + + @abstractmethod + def deactivate(self) -> None: + raise NotImplementedError + + +class HookManager: + # active hooks + active_hooks: list[Hook] = [] + # session_id -> (hook_type -> active) + session_hooks: Dict[int, Dict[Hook, bool]] = defaultdict(lambda: defaultdict(bool)) + + @staticmethod + def init_handle(module: Any, function: Any, name: str, metadata_group: Dict[str, str], hash: str) -> None: + for hook in HookManager.active_hooks: + hook.init_handle(module, function, name, metadata_group, hash) + + @staticmethod + def enter(metadata: LazyDict) -> None: + for hook in HookManager.active_hooks: + hook.enter(metadata) + + @staticmethod + def exit(metadata: LazyDict) -> None: + # It's important to reverse the order of hooks so that we keep the first in last out order + for hook in reversed(HookManager.active_hooks): + hook.exit(metadata) + + @staticmethod + def activate(session: Optional[int] = None) -> None: + if session is None: + sessions = HookManager.session_hooks.keys() + else: + sessions = [session] + + for session in sessions: + for hook in HookManager.session_hooks[session]: + if hook not in HookManager.active_hooks: + hook.activate() + HookManager.active_hooks.append(hook) + HookManager.session_hooks[session][hook] = True + # Sort active_hooks by priority + HookManager.active_hooks.sort(key=lambda x: x.priority, reverse=True) + + @staticmethod + def deactivate(session: Optional[int] = None) -> None: + if session is None: + sessions = HookManager.session_hooks.keys() + else: + sessions = [session] + + deactivated_hooks = set() + for session in sessions: + for hook in HookManager.session_hooks[session]: + if hook in HookManager.active_hooks: + deactivated_hooks.add(hook) + HookManager.session_hooks[session][hook] = False + + # Check if any other sessions rely on this hook + for hook in deactivated_hooks: + if not any(session_hooks[hook] for session_hooks in HookManager.session_hooks.values()): + hook.deactivate() + HookManager.active_hooks.remove(hook) + + @staticmethod + def register(hook: Hook, session: int) -> None: + HookManager.session_hooks[session][hook] = True + if hook not in HookManager.active_hooks: + hook.activate() + HookManager.active_hooks.append(hook) + # Sort active_hooks by priority + HookManager.active_hooks.sort(key=lambda x: x.priority, reverse=True) + + # Register the heads + knobs.runtime.kernel_load_end_hook.add(HookManager.init_handle) + knobs.runtime.launch_enter_hook.add(HookManager.enter) + knobs.runtime.launch_exit_hook.add(HookManager.exit) + + @staticmethod + def unregister(session: Optional[int] = None) -> None: + if session is not None and session not in HookManager.session_hooks: + return + + if session is None: + for hook in HookManager.active_hooks: + hook.deactivate() + HookManager.active_hooks.clear() + HookManager.session_hooks.clear() + else: + popped_hooks = HookManager.session_hooks.pop(session) + # Deactivate hooks that are not used by any other session + for hook, active in popped_hooks.items(): + if not active: + continue + if not any(session_hooks[hook] for session_hooks in HookManager.session_hooks.values()): + hook.deactivate() + HookManager.active_hooks.remove(hook) + # Unregister the heads + if not HookManager.active_hooks: + knobs.runtime.kernel_load_end_hook.remove(HookManager.init_handle) + knobs.runtime.launch_enter_hook.remove(HookManager.enter) + knobs.runtime.launch_exit_hook.remove(HookManager.exit) diff --git a/third_party/mthreads/proton/proton/hooks/instrumentation.py b/third_party/mthreads/proton/proton/hooks/instrumentation.py new file mode 100644 index 0000000000..aac27ad188 --- /dev/null +++ b/third_party/mthreads/proton/proton/hooks/instrumentation.py @@ -0,0 +1,348 @@ +from typing import Dict, Optional, Union, Any + +import triton +from triton._C.libtriton import ir as triton_ir +from triton._C.libtriton import proton as triton_proton +from triton._C.libtriton import amd as triton_amd +from triton._C.libtriton import nvidia as triton_nvidia +from triton._C.libtriton import passes as triton_passes +from triton._C.libproton import proton as libproton +from triton.compiler import LazyDict +from triton.runtime._allocation import set_profile_allocator, NullAllocator +from triton.backends import backends + +from .hook import Hook +from ..flags import flags +from .. import mode + +# TODO(fywkevin): add support for major.minor +VERSION = 1 + + +class CudaAllocator: + + def __init__(self, instrumentation_hook): + self.instrumentation_hook = instrumentation_hook + + def __call__(self, size: int, alignment: int, stream: Optional[int]): + if alignment != self.instrumentation_hook.profile_buffer_alignment: + raise RuntimeError( + f"Alignment mismatch: {alignment} != {self.instrumentation_hook.profile_buffer_alignment}") + aligned_size = (size + alignment - 1) // alignment * alignment + # Note: profile_buffer_size may be smaller than the aligned size if the kernel launches many blocks + # and the host CPU cannot store all profiling data in memory. This streaming mode is not yet implemented. + # In the future, we should support copying data incrementally from device to host to enable + # more efficient profiling data processing, rather than relying solely on post-processing. + aligned_size = max(aligned_size, self.instrumentation_hook.profile_buffer_size) + + # Create the buffer + import torch + buffer = torch.empty((aligned_size, ), dtype=torch.uint8, device="cuda") + self.instrumentation_hook.buffer = buffer + return buffer + + +class Instrumentation: + + def __init__(self, ir_map: Dict[str, Any]): + self.manager = ir_map + + def register(self, ir: str, func): + if ir in self.manager: + raise RuntimeError(f"IR already registered: {ir}") + self.manager[ir] = func + + def patch(self, ir: str, pm, context): + self.load_dialects(context) + if ir in self.manager: + self.manager[ir](pm) + + def load_dialects(self, ctx): + triton_proton.load_dialects(ctx) + + +def _interpret_mode(mode_obj: Union[str, mode.InstrumentationMode]) -> mode.InstrumentationMode: + if isinstance(mode_obj, mode.InstrumentationMode): + return mode_obj + elif not mode_obj: + mode_obj = "default" + + parts = mode_obj.split(":") + mode_name = parts[0] + opts: Dict[str, str] = {} + for opt in parts[1:]: + if "=" in opt: + key, val = opt.split("=", 1) + opts[key] = val + else: + raise ValueError(f"Malformed instrumentation option: '{opt}'") + + # Get option values or empty strings + options = { + "metric_type": opts.get("metric_type", "cycle"), "buffer_type": opts.get("buffer_type", "shared"), + "buffer_strategy": opts.get("buffer_strategy", "circular"), "buffer_size": int(opts.get("buffer_size", "0")), + "granularity": opts.get("granularity", "warp"), "sampling_strategy": opts.get("sampling_strategy", "none"), + "sampling_options": opts.get("sampling_options", ""), "optimizations": opts.get("optimizations", "") + } + + # Helper function to validate and map options to their enum values + def get_option_value(opt_name, mapping): + value = options[opt_name] + if value and value not in mapping: + raise ValueError(f"Unknown {opt_name}: {value}") + return mapping[value] if value else value + + # Look up enum values for each option + options["metric_type"] = get_option_value("metric_type", mode.metric_types) + options["buffer_type"] = get_option_value("buffer_type", mode.buffer_types) + options["buffer_strategy"] = get_option_value("buffer_strategy", mode.buffer_strategies) + options["granularity"] = get_option_value("granularity", mode.granularities) + options["sampling_strategy"] = get_option_value("sampling_strategy", mode.sampling_strategies) + + values = ([value.strip() + for value in options["optimizations"].split(",")] if len(options["optimizations"]) > 0 else []) + for value in values: + if value not in mode.optimizations: + raise ValueError(f"Unknown optimization: {value}") + options["optimizations"] = [mode.optimizations[value] for value in values] + + # Create the appropriate mode instance + if mode_name == "default": + return mode.Default(**options) + elif mode_name == "mma": + return mode.MMA(**options) + else: + raise ValueError(f"Unknown mode: {mode_obj}") + + +def _get_backend_name() -> str: + backend = triton.runtime.driver.active.get_current_target().backend + if backend == "cuda": + return "nvidia" + elif backend == "hip": + return "amd" + else: + raise RuntimeError(f"Unsupported backend: {backend}") + + +class InstrumentationHook(Hook): + priority: int = 0 + # It's important to note that only one instance of the instrumentation hook can be active at a time. + active_count: int = 0 + enable_host_buffer: bool = False + host_buffer: Optional[Any] = None + # FIXME(fywkevin): change to a more reasonable value after we have support for periodic buffer dumping. + profile_buffer_size: int = 1 + profile_buffer_alignment: int = 128 + + def __init__(self, mode_obj: Union[None, str, mode.InstrumentationMode]): + # Mapping of function objects to their scope ID pairs + self.mode: mode.InstrumentationMode = _interpret_mode(mode_obj) + + self.allocator = CudaAllocator(self) + self.buffer = None + self.metadata_path: Dict[Any, Optional[str]] = {} + + def activate(self): + if InstrumentationHook.active_count > 0: + raise RuntimeError("Only one instance of the instrumentation hook can be active at a time.") + + InstrumentationHook.active_count += 1 + + flags.instrumentation_on = True + + device = triton.runtime.driver.active.get_current_device() + max_shared_mem = triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"] + backend_name = _get_backend_name() + + def to_llvmir_passes(pm): + is_long_clk = False if mode.Optimize.CLOCK32 in self.mode.optimizations else True + triton_proton.add_convert_proton_to_protongpu(pm, self.mode.metric_type, self.mode.sampling_strategy, + self.mode.sampling_options, self.mode.granularity, + self.mode.buffer_strategy, self.mode.buffer_type, + self.mode.buffer_size, max_shared_mem, + self.profile_buffer_size, self.profile_buffer_alignment, + is_long_clk) + triton_passes.common.add_cse(pm) + + if mode.Optimize.SCHED_STORES in self.mode.optimizations: + triton_proton.add_schedule_buffer_store(pm) + + triton_proton.add_allocate_proton_shared_memory(pm) + + if mode.Optimize.SCHED_BARRIERS in self.mode.optimizations and backend_name == "amd": + triton_proton.add_sched_barriers(pm) + + def to_llvm_passes(pm): + triton_proton.add_allocate_proton_global_scratch_buffer(pm) + if backend_name == "nvidia": + triton_proton.add_convert_proton_nvidia_gpu_to_llvm(pm) + elif backend_name == "amd": + arch = triton.runtime.driver.active.utils.get_device_properties(device)["arch"].split(":")[0] + triton_proton.add_convert_proton_amd_gpu_to_llvm(pm, arch) + + backends[backend_name].compiler.instrumentation = Instrumentation({ + "ttgpuir_to_llvmir": + lambda pm: to_llvmir_passes(pm), + "llvmir_to_llvm": + lambda pm: to_llvm_passes(pm), + }) + + # Set up the profiling allocator + set_profile_allocator(self.allocator) + + # Set the instrumentation mode + triton.knobs.compilation.instrumentation_mode = str(self.mode) + + def deactivate(self): + if InstrumentationHook.active_count == 0: + return + + InstrumentationHook.active_count -= 1 + + backend_name = _get_backend_name() + + # No instrumentation passes are registered anymore + backends[backend_name].compiler.instrumentation = {} + + # No runtime instrumentation hook is active anymore + flags.instrumentation_on = False + + # Restore the instrumentation mode + triton.knobs.compilation.instrumentation_mode = "" + + # Reset profile allocator + set_profile_allocator(NullAllocator()) + + # Reset host memory for external processing + InstrumentationHook.host_buffer = None + + # Reset the buffer reference + self.buffer = None + + def init_handle(self, module: Any, function: Any, name: str, metadata_group: Dict[str, str], hash: str) -> None: + if not function: + return + + # Find the IR path in metadata + ir_path = next((path for key, path in metadata_group.items() if key.endswith(("ttgir"))), None) + metadata_path = next((path for key, path in metadata_group.items() if key.endswith(("json"))), None) + self.metadata_path[function] = metadata_path + + if ir_path: + context = triton_ir.context() + triton_ir.load_dialects(context) + backend_name = _get_backend_name() + if backend_name == "nvidia": + triton_nvidia.load_dialects(context) + elif backend_name == "amd": + triton_amd.load_dialects(context) + triton_proton.load_dialects(context) + module = triton_ir.parse_mlir_module(ir_path, context) + module.context = context + + scope_id_names = triton_proton.get_scope_id_names(module) + scope_id_parents = triton_proton.get_scope_id_parents(module) + libproton.init_function_metadata(function, name, scope_id_names, scope_id_parents, metadata_path) + else: + raise RuntimeError(f"IR path not found in metadata for function {function}") + + def _data_ptr(self) -> int: + return 0 if self.buffer is None else self.buffer.data_ptr() + + def enter(self, metadata: LazyDict) -> None: + func = metadata.data.get("function") + stream = metadata.data.get("stream") + alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel() + libproton.enter_instrumented_op(stream, func, self._data_ptr(), alloc_size) + if InstrumentationHook.enable_host_buffer: + InstrumentationHook.host_buffer = None + + def exit(self, metadata: LazyDict) -> None: + func = metadata.data.get("function") + stream = metadata.data.get("stream") + alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel() + libproton.exit_instrumented_op(stream, func, self._data_ptr(), alloc_size) + + if InstrumentationHook.enable_host_buffer: + self._populate_host_buffer(func) + + def _populate_host_buffer(self, function: Any) -> None: + if function and self.metadata_path[function]: + import torch + import struct + import json + + def encode_target(target: Dict[str, Any]) -> int: + #TODO(fywkevin): also account for `arch` + if target["backend"] == "cuda": + return 1 + elif target["backend"] == "hip": + return 2 + return 0 + + alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel() + sampled_warps = self.mode.sampling_options.strip().split(",") + data = {} + with open(self.metadata_path[function], 'r') as file: + data = json.load(file) + + device_type = encode_target(data["target"]) + scratch_mem_size = data["profile_scratch_size"] + total_unit = data["num_warps"] + uid_num = total_unit if self.mode.sampling_strategy == triton_proton.SAMPLING_STRATEGY.NONE else len( + sampled_warps) + block_num = int(alloc_size / scratch_mem_size) + + # Binary trace layout: + # +------------------+ + # | version | 4 bytes + # +------------------+ + # | header_offset | 4 bytes + # +------------------+ + # | header_size | 4 bytes + # +------------------+ + # | payload_offset | 4 bytes + # +------------------+ + # | payload_size | 4 bytes + # +------------------+ + # | device_type | 4 bytes + # +------------------+ + # | block_num | 4 bytes + # +------------------+ + # | total_unit | 4 bytes + # +------------------+ + # | scratch_mem_size | 4 bytes + # +------------------+ + # | uid_num | 4 bytes + # +------------------+ + # | | + # | uid_vec | uid_num * 4 bytes + # | | + # +------------------+ + # | | + # | payload | size_payload bytes + # | | + # +------------------+ + + is_all_warps = self.mode.sampling_options == "" and self.mode.granularity == triton_proton.GRANULARITY.WARP + if is_all_warps: + uid_vec = [i for i in range(total_unit)] + else: + uid_vec = [int(i) for i in sampled_warps] + + header_size = 40 + uid_num * 4 + header_offset = 4 + payload_offset = header_size + payload_size = alloc_size + header_values = [ + VERSION, header_offset, header_size, payload_offset, payload_size, device_type, block_num, total_unit, + scratch_mem_size, uid_num, *uid_vec + ] + header_bytes = struct.pack("I" * len(header_values), *header_values) + + InstrumentationHook.host_buffer = torch.empty(header_size + alloc_size, dtype=torch.uint8, device="cpu") + config_portion = InstrumentationHook.host_buffer[:header_size] + config_portion.copy_(torch.tensor(list(header_bytes), dtype=torch.uint8)) + data_portion = InstrumentationHook.host_buffer[header_size:].view_as(self.buffer) + data_portion.copy_(self.buffer.cpu()) diff --git a/third_party/mthreads/proton/proton/hooks/launch.py b/third_party/mthreads/proton/proton/hooks/launch.py new file mode 100644 index 0000000000..e7d0fd08d6 --- /dev/null +++ b/third_party/mthreads/proton/proton/hooks/launch.py @@ -0,0 +1,121 @@ +from ..state import enter_state, exit_state, COMPUTE_METADATA_SCOPE_NAME +from ..metric import transform_tensor_metrics, set_metric_kernels +from triton.compiler import LazyDict +from .hook import Hook +from triton._C.libproton import proton as libproton +from contextvars import ContextVar +from numbers import Number +import re +from typing import Optional + +op_name = ContextVar("op_name", default=None) +id = ContextVar("id", default=None) +enabled = ContextVar("enabled", default=False) + + +class LaunchHook(Hook): + # Highest priority + priority = 100 + flops_width = [8, 16, 32, 64] + # Historical/derived metrics (e.g., used by viewer utilization computations). + # Launch metadata can carry *additional* metrics; see _extract_metrics(). + metrics = [f"flops{width}" for width in flops_width] + ["bytes"] + ["flops"] + + # Reserved keys that Triton’s runtime always attaches to launch_metadata. + # We never treat these as metrics. + _reserved_metadata_keys = {"name", "function", "stream"} + + # LaunchHook is intended to be a process-wide singleton. HookManager dedupes + # by identity (object instance), so we must ensure repeated LaunchHook() + # constructions return the same instance to avoid double registration. + _instance = None + + def configure(self, *, include: Optional[str] = None, exclude: Optional[str] = None) -> None: + # Regexes over the compiled kernel name (metadata.data["name"]). + self._include_pattern = include + self._exclude_pattern = exclude + self._include_re = re.compile(include) if include else None + self._exclude_re = re.compile(exclude) if exclude else None + + def _matches_kernel_name(self, kernel_name: str) -> bool: + if self._include_re is not None and self._include_re.match(kernel_name) is None: + return False + if self._exclude_re is not None and self._exclude_re.match(kernel_name) is not None: + return False + return True + + @staticmethod + def _is_supported_metric_value(value) -> bool: + # Supported scalar: Python/numpy number-like (bools are allowed but not very useful). + # Supported tensor: objects with a data_ptr() method (e.g., torch.Tensor). + if value is None: + return False + if hasattr(value, "data_ptr"): + return True + return isinstance(value, Number) + + @staticmethod + def _extract_metrics(lazy_metadata: dict) -> dict: + # Accept arbitrary metrics from launch_metadata while filtering out reserved fields + # and unsupported values (e.g., objects/functions). + return { + k: v + for k, v in lazy_metadata.items() + if k not in LaunchHook._reserved_metadata_keys and LaunchHook._is_supported_metric_value(v) + } + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + # Singleton: __init__ is invoked on every construction even when __new__ + # returns an existing instance. + if getattr(self, "_initialized", False): + return + # Ensure filter state is always initialized even if configure() isn't called. + self.configure(include=None, exclude=None) + self._initialized = True + + def init_handle(self, module, function, name: str, metadata_group: dict, hash: str) -> None: + pass + + def activate(self): + pass + + def deactivate(self): + pass + + def enter(self, metadata: LazyDict) -> None: + # Fast path: if the kernel name is already available without evaluating launch_metadata, + # apply include/exclude filters and potentially skip metadata evaluation entirely. + kernel_name = metadata.data.get("name") + if not self._matches_kernel_name(kernel_name): + enabled.set(False) + return + enter_state(COMPUTE_METADATA_SCOPE_NAME) + lazy_metadata = metadata.get() + exit_state() + + kernel_name = lazy_metadata["name"] + # If name wasn't available (or changed), apply filters using the evaluated name. + if not self._matches_kernel_name(kernel_name): + enabled.set(False) + return + + enabled.set(True) + fn_metrics = LaunchHook._extract_metrics(lazy_metadata) + op_name.set(kernel_name) + id.set(libproton.record_scope()) + if fn_metrics: + set_metric_kernels() + scalar_metrics, tensor_metrics = transform_tensor_metrics(fn_metrics) + libproton.enter_op(id.get(), lazy_metadata["name"]) + libproton.add_metrics(id.get(), scalar_metrics, tensor_metrics) + + def exit(self, metadata: LazyDict) -> None: + if not enabled.get(): + return + libproton.exit_op(id.get(), op_name.get()) diff --git a/third_party/mthreads/proton/proton/language.py b/third_party/mthreads/proton/proton/language.py new file mode 100644 index 0000000000..2785938194 --- /dev/null +++ b/third_party/mthreads/proton/proton/language.py @@ -0,0 +1,65 @@ +from triton.language import core as tl +from triton.language.core import builtin +from triton._C.libtriton import proton as triton_proton +from triton.language.semantic import TritonSemantic +from triton.experimental.gluon.language._semantic import GluonSemantic + +from .flags import flags + +_ALL_SEMANTICS = { + "triton": TritonSemantic, + "gluon": GluonSemantic, +} +""" +By default **only Gluon** semantic is enabled. +Instrumenting kernels written in Triton DSL is disable because Triton's higher-level IR undergoes +aggressive compiler rewrites (loop pipelining, instruction re-ordering, IR duplication, etc.). +These transformations can invalidate naïve instrumentation and lead to misleading results. +""" +_SEMANTICS = {_ALL_SEMANTICS["gluon"]} + + +def _check_supported_semantic(semantic): + if not isinstance(semantic, tuple(_SEMANTICS)): + raise TypeError(f"Unsupported semantic type: {type(semantic)}. " + f"Supported semantics are: {_SEMANTICS}") + + +def enable_semantic(semantic_name: str): + _SEMANTICS.add(_ALL_SEMANTICS[semantic_name]) + + +def disable_semantic(semantic_name: str): + _SEMANTICS.remove(_ALL_SEMANTICS[semantic_name]) + + +def record(is_start: tl.constexpr, scope_name: tl.constexpr, semantic): + if not flags.instrumentation_on: + return + _check_supported_semantic(semantic) + is_start = tl._unwrap_if_constexpr(is_start) + scope_name = tl._unwrap_if_constexpr(scope_name) + return tl.tensor(triton_proton.create_proton_record(semantic.builder, is_start, scope_name), tl.void) + + +@builtin +def enter_scope(name: tl.constexpr, _semantic=None): + record(is_start=True, scope_name=name, semantic=_semantic) + + +@builtin +def exit_scope(name: tl.constexpr, _semantic=None): + record(is_start=False, scope_name=name, semantic=_semantic) + + +class scope: + + def __init__(self, name: str, _semantic=None): + self.name = name + self.semantic = _semantic + + def __enter__(self): + enter_scope(self.name, _semantic=self.semantic) + + def __exit__(self, exc_type, exc_value, traceback): + exit_scope(self.name, _semantic=self.semantic) diff --git a/third_party/mthreads/proton/proton/metric.py b/third_party/mthreads/proton/proton/metric.py new file mode 100644 index 0000000000..c99dc95a34 --- /dev/null +++ b/third_party/mthreads/proton/proton/metric.py @@ -0,0 +1,91 @@ +from typing import Any +from triton._C.libproton import proton as libproton +import triton.runtime.driver as driver +import triton.language as tl +import triton +from triton import MockTensor +from .state import exit_state, enter_state, COMPUTE_METADATA_SCOPE_NAME + + +@triton.jit +def tensor_metric_kernel(device_ptr, device_offset_ptr, size: tl.uint64, metric_id: tl.uint64, metric_value_ptr): + device_offset = tl.load(device_offset_ptr) + metric_value = tl.load(metric_value_ptr) + tl.store(device_ptr + device_offset, metric_id) + device_offset = (device_offset + 1) % size + tl.store(device_ptr + device_offset, metric_value) + device_offset = (device_offset + 1) % size + tl.debug_barrier() + tl.store(device_offset_ptr, device_offset) + + +@triton.jit +def scalar_metric_kernel(device_ptr, device_offset_ptr, size: tl.uint64, metric_id: tl.uint64, metric_value: tl.uint64): + device_offset = tl.load(device_offset_ptr) + tl.store(device_ptr + device_offset, metric_id) + device_offset = (device_offset + 1) % size + tl.store(device_ptr + device_offset, metric_value) + device_offset = (device_offset + 1) % size + tl.debug_barrier() + tl.store(device_offset_ptr, device_offset) + + +def _get_kernel(kernel_fn, *args): + kernel = kernel_fn.warmup(*args, grid=(1, ), num_warps=1) + kernel._init_handles() + return kernel.function + + +def set_metric_kernels(): + mock_ptr = MockTensor(tl.uint64) + mock_metric_id = 0 + mock_size = 1 + tensor_metric_kernel_fn = _get_kernel( + tensor_metric_kernel, + mock_ptr, + mock_ptr, + mock_size, + mock_metric_id, + mock_ptr, + ) + scalar_metric_kernel_fn = _get_kernel( + scalar_metric_kernel, + mock_ptr, + mock_ptr, + mock_size, + mock_metric_id, + mock_metric_id, + ) + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + libproton.set_metric_kernels(tensor_metric_kernel_fn, scalar_metric_kernel_fn, stream) + + +class _TensorMetric(libproton.TensorMetric): + # Hold a reference to the backing tensor so its device memory stays alive. + def __init__(self, value, metric_index): + super().__init__(value.data_ptr(), metric_index) + self._value = value + + +def transform_tensor_metrics(metrics: dict[str, Any]) -> tuple[dict[str, Any], dict[str, libproton.TensorMetric]]: + tensor_metrics = {} + scalar_metrics: dict[str, Any] = {} + for key, value in metrics.items(): + if hasattr(value, "data_ptr"): # tensor + if value.device.type == "cpu": + scalar_metrics[key] = value + else: # device tensor + enter_state(COMPUTE_METADATA_SCOPE_NAME) + # implicit casting to double or int64 tensors + if value.is_floating_point(): + value = value.double() + metric_index = libproton.metric_double_index + else: + value = value.long() + metric_index = libproton.metric_int64_index + exit_state() + tensor_metrics[key] = _TensorMetric(value, metric_index) + else: + scalar_metrics[key] = value + return scalar_metrics, tensor_metrics diff --git a/third_party/mthreads/proton/proton/mode.py b/third_party/mthreads/proton/proton/mode.py new file mode 100644 index 0000000000..ff41d58872 --- /dev/null +++ b/third_party/mthreads/proton/proton/mode.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass, field +from triton._C.libtriton import proton as triton_proton +from typing import List +from enum import Enum + +metric_types = {"cycle": triton_proton.METRIC_TYPE.CYCLE} + +buffer_strategies = { + "circular": triton_proton.BUFFER_STRATEGY.CIRCULAR, + "flush": triton_proton.BUFFER_STRATEGY.FLUSH, +} + +buffer_types = { + "shared": triton_proton.BUFFER_TYPE.SHARED, + "global": triton_proton.BUFFER_TYPE.GLOBAL, +} + +sampling_strategies = { + "none": triton_proton.SAMPLING_STRATEGY.NONE, + "selective": triton_proton.SAMPLING_STRATEGY.SELECTIVE, +} + +granularities = { + "cta": triton_proton.GRANULARITY.CTA, + "warp": triton_proton.GRANULARITY.WARP, + "warp_2": triton_proton.GRANULARITY.WARP_2, + "warp_4": triton_proton.GRANULARITY.WARP_4, + "warp_8": triton_proton.GRANULARITY.WARP_8, + "warp_group": triton_proton.GRANULARITY.WARP_GROUP, + "warp_group_2": triton_proton.GRANULARITY.WARP_GROUP_2, + "warp_group_4": triton_proton.GRANULARITY.WARP_GROUP_4, + "warp_group_8": triton_proton.GRANULARITY.WARP_GROUP_8, +} + + +class Optimize(Enum): + TIMESHIFT = "time_shift" + SCHED_STORES = "sched_stores" + SCHED_BARRIERS = "sched_barriers" + CLOCK32 = "clock32" + + def __str__(self): + return self.value + + +optimizations = { + "time_shift": Optimize.TIMESHIFT, + "sched_stores": Optimize.SCHED_STORES, + "sched_barriers": Optimize.SCHED_BARRIERS, + "clock32": Optimize.CLOCK32, +} + + +@dataclass(frozen=True) +class BaseMode: + name: str + + +@dataclass(frozen=True) +class PCSampling(BaseMode): + name: str = field(default="pcsampling", init=False) + interval: int = 1000 + + def __post_init__(self): + if self.interval <= 0: + raise ValueError("Interval must be a positive integer.") + + def __str__(self): + return f"{self.name}:interval={self.interval}" + + +@dataclass(frozen=True) +class InstrumentationMode(BaseMode): + """Common base class for instrumentation modes with shared configuration.""" + metric_type: triton_proton.METRIC_TYPE = triton_proton.METRIC_TYPE.CYCLE + sampling_strategy: triton_proton.SAMPLING_STRATEGY = triton_proton.SAMPLING_STRATEGY.NONE + sampling_options: str = "" + granularity: triton_proton.GRANULARITY = triton_proton.GRANULARITY.WARP + buffer_strategy: triton_proton.BUFFER_STRATEGY = triton_proton.BUFFER_STRATEGY.CIRCULAR + buffer_type: triton_proton.BUFFER_TYPE = triton_proton.BUFFER_TYPE.SHARED + buffer_size: int = 0 + optimizations: List[Optimize] = field(default_factory=list) + + def __post_init__(self): + # automatically map string inputs to enums using the global lookup dicts + mappings = [ + ("metric_type", metric_types), + ("sampling_strategy", sampling_strategies), + ("granularity", granularities), + ("buffer_strategy", buffer_strategies), + ("buffer_type", buffer_types), + ] + for field_name, lookup in mappings: + value = getattr(self, field_name) + if isinstance(value, str): + if value not in lookup: + raise ValueError(f"Unknown {field_name}: {value}") + object.__setattr__(self, field_name, lookup[value]) + + values_str = getattr(self, "optimizations") + if isinstance(values_str, str): + values = [value.strip() for value in values_str.split(",")] if len(values_str) > 0 else [] + for value in values: + if value not in optimizations: + raise ValueError(f"Unknown optimization: {value}") + object.__setattr__(self, "optimizations", [optimizations[value] for value in values]) + + def __str__(self): + optimizations_str = ",".join([str(opt) for opt in self.optimizations]) + return (f"{self.name}:metric_type={self.metric_type}:sampling_strategy={self.sampling_strategy}" + f":sampling_options={self.sampling_options}:granularity={self.granularity}" + f":buffer_strategy={self.buffer_strategy}:buffer_type={self.buffer_type}" + f":buffer_size={self.buffer_size}:optimizations={optimizations_str}") + + +@dataclass(frozen=True) +class Default(InstrumentationMode): + name: str = field(default="default", init=False) + + +@dataclass(frozen=True) +class MMA(InstrumentationMode): + name: str = field(default="mma", init=False) diff --git a/third_party/mthreads/proton/proton/profile.py b/third_party/mthreads/proton/proton/profile.py new file mode 100644 index 0000000000..29e9a5b6aa --- /dev/null +++ b/third_party/mthreads/proton/proton/profile.py @@ -0,0 +1,262 @@ +import functools +import triton + +from triton._C.libproton import proton as libproton # type: ignore +from triton._C.libtriton import getenv # type: ignore +from .flags import flags +from .hooks import HookManager, LaunchHook, InstrumentationHook +from .hooks.hook import Hook +from .mode import BaseMode +from typing import Optional, Union + +DEFAULT_PROFILE_NAME = "proton" + + +def _select_backend() -> str: + backend = triton.runtime.driver.active.get_current_target().backend + if backend == "cuda": + return "cupti" + elif backend == "hip": + return "roctracer" + else: + raise ValueError("No backend is available for the current target.") + + +def _get_mode_str(backend: str, mode: Optional[Union[str, BaseMode]]) -> str: + if backend == "instrumentation": + prefix = triton.runtime.driver.active.get_current_target().backend + return f"{prefix}:{mode}" if mode else prefix + return str(mode) if mode else "" + + +def _check_env(backend: str) -> None: + if backend == "roctracer": + hip_device_envs = ["HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"] + for env in hip_device_envs: + if getenv(env, None) is not None: + raise ValueError( + f"Proton does not work when the environment variable {env} is set on AMD GPUs. Please unset it and use `ROCR_VISIBLE_DEVICES` instead" + ) + + # Ensure default envs are set for Proton knobs if not already set by the user. + for attr, desc in triton.knobs.proton.knob_descriptors.items(): + key = desc.key + if getenv(key, None) is None: + val = getattr(triton.knobs.proton, attr) + if val is not None: + if env_val := triton.knobs.toenv(val): + triton.knobs.setenv(key, env_val[0]) + + +def start( + name: Optional[str] = None, + *, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + mode: Optional[Union[str, BaseMode]] = None, + hook: Optional[Union[str, Hook]] = None, +) -> Optional[int]: + """ + Start profiling with the given name and backend. + + Usage: + + ```python + proton.start("my_profile") + # do something + proton.finalize() + ``` + + Args: + name (str, optional): The name (with path) of the profiling session. + If not provided, the default name is "~/proton.", where suffix is the default + format according to the data type. For example, if data is "tree", the default name is "~/proton.hatchet". + context (str, optional): The context to use for profiling. + Available options are ["shadow", "python"]. + Defaults to "shadow". + data (str, optional): The data structure to use for profiling. + Available options are ["tree", "trace"]. + Defaults to "tree". + backend (str, optional): The backend to use for profiling. + Available options are [None, "cupti", "roctracer", "instrumentation"]. + Defaults to None, which automatically selects the backend matching the current active runtime. + mode (Union[str, BaseMode], optional): The "mode" to use for profiling, which is specific to the backend. + Can be a string or an instance of BaseMode (or any subclass thereof). + Defaults to None. + For "cupti", available options are [None, "pcsampling", "periodic_flushing"]. + For "roctracer", available options are ["periodic_flushing"]. + For "instrumentation", available options are [None]. + Each mode has a set of control knobs following with the mode name. + For example, "periodic_flushing" mode has a knob: + - format: The output format of the profiling results. Available options are ["hatchet", "hatchet_msgpack", "chrome_trace"]. Default is "hatchet". + The can be set via `mode="periodic_flushing:format=chrome_trace"`. + hook (Union[str, Hook], optional): The hook to use for profiling. + You may pass either: + - a string hook name, e.g. "triton" (kernel launch metadata), or + - a custom Hook instance. + Defaults to None. + Returns: + session (Optional[int]): The session ID of the profiling session, or None if profiling is disabled. + """ + if flags.command_line or triton.knobs.proton.disable: + # Ignore the start() call if the script is run from the command line or profiling is disabled. + return None + + flags.profiling_on = True + + name = DEFAULT_PROFILE_NAME if name is None else name + backend = _select_backend() if backend is None else backend + # Convert mode to its string representation for libproton's runtime + mode_str = _get_mode_str(backend, mode) + + _check_env(backend) + + session = libproton.start(name, context, data, backend, mode_str) + + if isinstance(hook, Hook): + HookManager.register(hook, session) + elif hook == "triton": + HookManager.register(LaunchHook(), session) + elif hook is not None: + raise ValueError(f"Unsupported hook: {hook!r}") + if backend == "instrumentation": + HookManager.register(InstrumentationHook(mode), session) + + return session + + +def activate(session: Optional[int] = None) -> None: + """ + Activate the specified session. + The profiling session will be active and data will be recorded. + + Args: + session (int): The session ID of the profiling session. Defaults to None (all sessions) + + Returns: + None + """ + if flags.command_line and session != 0: + raise ValueError("Only one session can be activated when running from the command line.") + + HookManager.activate(session) + + if session is None: + libproton.activate_all() + else: + libproton.activate(session) + + +def deactivate(session: Optional[int] = None, flushing: bool = False) -> None: + """ + Stop the specified session. + The profiling session's data will still be in the memory, but no more data will be recorded. + + Args: + session (int): The session ID of the profiling session. Defaults to None (all sessions) + flushing (bool): Whether to flush the profiling data before deactivating. Defaults to True. + + Returns: + None + """ + if flags.command_line and session != 0: + raise ValueError("Only one session can be deactivated when running from the command line.") + + HookManager.deactivate(session) + + if session is None: + libproton.deactivate_all(flushing) + else: + libproton.deactivate(session, flushing) + + +def finalize(session: Optional[int] = None, output_format: Optional[str] = "") -> None: + """ + Finalizes a profiling session. + Flush and write the profiling data to the file specified by the session name. + + Args: + session (int, optional): The session ID to finalize. If None, all sessions are finalized. Defaults to None. + output_format (str, optional): The output format for the profiling results. + Available options are ["hatchet", "hatchet_msgpack", "chrome_trace"]. + + Returns: + None + """ + HookManager.unregister(session) + + if session is None: + flags.profiling_on = False + libproton.finalize_all(output_format) + else: + if flags.command_line and session != 0: + raise ValueError("Only one session can be finalized when running from the command line.") + libproton.finalize(session, output_format) + + +def _profiling( + func, + name: Optional[str] = None, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + mode: Optional[str] = None, + hook: Optional[Union[str, Hook]] = None, +): + """ + Context manager for profiling. Internally use only. + + Args: + See start() for the arguments. + + Returns: + wrapper (function): The wrapped function. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = start(name, context=context, data=data, backend=backend, mode=mode, hook=hook) + ret = func(*args, **kwargs) + deactivate(session) + return ret + + return wrapper + + +def profile( + func=None, + *, + name: Optional[str] = None, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + mode: Optional[str] = None, + hook: Optional[Union[str, Hook]] = None, +): + """ + Decorator for profiling. + + Usage: + + ```python + @proton.profile + def foo(): + pass + ``` + + Args: + See start() for the arguments. + + Returns: + decorator (function): The decorator function. + """ + if func is None: + # It's being used with parentheses, so return a decorator + def decorator(f): + return _profiling(f, name=name, context=context, data=data, backend=backend, mode=mode, hook=hook) + + return decorator + else: + # It's being used without parentheses, so apply the decorator directly + return _profiling(func, name=name, context=context, data=data, backend=backend, mode=mode, hook=hook) diff --git a/third_party/mthreads/proton/proton/proton.py b/third_party/mthreads/proton/proton/proton.py new file mode 100644 index 0000000000..a7689288da --- /dev/null +++ b/third_party/mthreads/proton/proton/proton.py @@ -0,0 +1,88 @@ +import argparse +import sys +import os +import runpy +import traceback +from .profile import start, finalize, _select_backend +from .flags import flags + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="The proton command utility for profiling scripts and pytest tests.", usage=""" + proton [options] script.py [script_args] [script_options] + proton [options] pytest [pytest_args] [script_options] + python -m triton.profiler.proton [options] script.py [script_args] [script_options] +""", formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-n", "--name", type=str, help="Name of the profiling session") + parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, + choices=["cupti", "roctracer", "instrumentation"]) + parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow", + choices=["shadow", "python"]) + parser.add_argument("-m", "--mode", type=str, help="Profiling mode", default=None) + parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree", "trace"]) + parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) + parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments') + args = parser.parse_args() + return args, args.target_args + + +def is_pytest(script): + return os.path.basename(script) == 'pytest' + + +def execute_as_main(script, args): + script_path = os.path.abspath(script) + + original_argv = sys.argv + sys.argv = [script] + args + # Append the script's directory in case the script uses relative imports + sys.path.append(os.path.dirname(script_path)) + + # Execute in the isolated environment + try: + runpy.run_path(script, run_name="__main__") + except Exception as e: + print("An error occurred while executing the script:") + traceback.print_exception(e) + return 1 + except SystemExit as e: + return e.code + except KeyboardInterrupt: + return 1 + finally: + sys.argv = original_argv + return 0 + + +def do_setup_and_execute(target_args): + # Set the command line mode to avoid any `start` calls in the script. + flags.command_line = True + + script = target_args[0] + script_args = target_args[1:] if len(target_args) > 1 else [] + if is_pytest(script): + import pytest + return pytest.main(script_args) + else: + return execute_as_main(script, script_args) + + +def run_profiling(args, target_args): + backend = args.backend if args.backend else _select_backend() + + start(args.name, context=args.context, data=args.data, backend=backend, hook=args.hook) + + exitcode = do_setup_and_execute(target_args) + + finalize() + sys.exit(exitcode) + + +def main(): + args, target_args = parse_arguments() + run_profiling(args, target_args) + + +if __name__ == "__main__": + main() diff --git a/third_party/mthreads/proton/proton/scope.py b/third_party/mthreads/proton/proton/scope.py new file mode 100644 index 0000000000..881cf0935a --- /dev/null +++ b/third_party/mthreads/proton/proton/scope.py @@ -0,0 +1,133 @@ +import threading +import time +from functools import wraps +from typing import Optional, Union, Any + +from .flags import flags +from .metric import transform_tensor_metrics, set_metric_kernels +from triton._C.libproton import proton as libproton + +thread_local_scopes = threading.local() + +MetricValueType = Union[float, int] + + +class scope: + """ + A context manager and decorator for entering and exiting a scope. + + Usage: + context manager: + ```python + with proton.scope("test0", {metric_name: metric_value}): + foo[1,](x, y) + ``` + + decorator: + ```python + @proton.scope("test0", {metric_name: metric_value}) + def foo(x, y): + ... + ``` + + Args: + name (str): The name of the scope. + metrics (dict[str, float], optional): The metrics of the scope. Default is None. + """ + + def __init__(self, name: str, metrics: Optional[dict[str, Any]] = None) -> None: + self.name = name + self.metrics = metrics + self.id = None + + def _enter_scope(self): + if not flags.profiling_on: + return + self.id = libproton.record_scope() + libproton.enter_scope(self.id, self.name) + if self.metrics: + set_metric_kernels() + libproton.add_metrics(self.id, *transform_tensor_metrics(self.metrics)) + + def _exit_scope(self): + if not flags.profiling_on or self.id is None: + return + libproton.exit_scope(self.id, self.name) + + def __enter__(self): + self._enter_scope() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._exit_scope() + + def __call__(self, func): + + @wraps(func) + def wrapper(*args, **kwargs): + self._enter_scope() + try: + return func(*args, **kwargs) + finally: + self._exit_scope() + + return wrapper + + +class cpu_timed_scope(scope): + """ + A scope that measures elapsed time (cpu_time). + + Args: + name (str): The name of the scope. + metrics (dict[str, float], optional): Additional metrics to add. Default is None. + """ + + def __init__(self, name: str, metrics: Optional[dict[str, Any]] = None) -> None: + super().__init__(name, metrics) + self.start_time = None + if metrics and "cpu_time" in metrics: + raise ValueError("The metric name 'cpu_time' is reserved.") + + def _enter_scope(self): + if not flags.profiling_on: + return + self.start_time = time.time_ns() + super()._enter_scope() + + def _exit_scope(self): + if not flags.profiling_on: + return + if self.start_time is not None: + cpu_time = time.time_ns() - self.start_time + libproton.add_metrics(self.id, {"cpu_time (ns)(exc)": cpu_time}) + super()._exit_scope() + + +def enter_scope(name: str, *, metrics: Optional[dict[str, Any]] = None) -> Optional[int]: + if not flags.profiling_on: + return None + id = libproton.record_scope() + thread_local_scopes.scopes = getattr(thread_local_scopes, "scopes", []) + thread_local_scopes.scopes.append((id, name)) + libproton.enter_scope(id, name) + if metrics: + set_metric_kernels() + libproton.add_metrics(id, *transform_tensor_metrics(metrics)) + return id + + +def exit_scope(name: Optional[str] = None, *, metrics: Optional[dict[str, Any]] = None) -> Optional[int]: + # `name` is an optional argument here, only to match the counterpart in enter_scope to make the API consistent with `proton.language.exit_scope` + if not flags.profiling_on: + return None + id, popped_name = thread_local_scopes.scopes.pop() + if name and name != popped_name: + raise ValueError(f"Scope name mismatch: {name} != {popped_name}") + elif not name: + name = popped_name + if metrics: + set_metric_kernels() + libproton.add_metrics(id, *transform_tensor_metrics(metrics)) + libproton.exit_scope(id, name) + return id diff --git a/third_party/mthreads/proton/proton/specs.py b/third_party/mthreads/proton/proton/specs.py new file mode 100644 index 0000000000..b30c3416d8 --- /dev/null +++ b/third_party/mthreads/proton/proton/specs.py @@ -0,0 +1,69 @@ +flops_by_device = { + "CUDA": { + "80": + lambda width, **kwargs: 624e12 / (width / 8), + "89": + lambda width, **kwargs: (330.3 * 1e12) / (width / 8), # TODO(Keren): Implement fp16 acc-> 660.6 fp8 + "90": + lambda width, num_sms, clock_rate, **kwargs: ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / + (width / 8), + "100": + lambda width, num_sms, clock_rate, **kwargs: (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8), + } +} + +amd_bps_by_arch = { + 'gfx90a': 3.2 * 1e12, + 'gfx942': 5.3 * 1e12, + 'gfx950': 8.0 * 1e12, +} + +# FP8 Matrix Performance(FLOPS/clock/CU) +# For gfx90a we use the performance of INT8 since it doesn't support FP8 matrix operations. +amd_fp8_flops_by_arch = {'gfx90a': 1024, 'gfx942': 4096, 'gfx950': 8192} + + +def max_flops(device_type, arch, width, num_sms, clock_rate): + """ + Calculate the maximum FLOPS for a given device type and width. + + Args: + device_type (str): The type of device (e.g., "CUDA", "HIP"). + arch (str): The architecture of the device (e.g., "80", "90"). + width (int): The width in bits. + num_sms (int): The number of streaming multiprocessors. + clock_rate (float): The clock rate in GHz. + + Returns: + float: The maximum FLOPS for the given device type and width. + """ + if device_type == "HIP": + return amd_fp8_flops_by_arch[arch] * num_sms * clock_rate * 1e3 / (width / 8) + + if device_type not in flops_by_device: + raise ValueError(f"Unsupported device type: {device_type}") + + if arch not in flops_by_device[device_type]: + raise ValueError(f"Unsupported architecture: {arch}") + + flops_func = flops_by_device[device_type][arch] + + return flops_func(width, num_sms=num_sms, clock_rate=clock_rate) + + +def max_bps(device_type, arch, bus_width, memory_clock_rate): + """ + Calculate the maximum bytes per second for a given bus width and memory clock rate. + + Args: + bus_width (int): The bus width in bits. + memory_clock_rate (float): The memory clock rate in GHz. + + Returns: + float: The maximum bytes per second. + """ + if device_type == "CUDA": + return 2 * bus_width * memory_clock_rate * 1e3 / 8 + else: + assert device_type == "HIP" + return amd_bps_by_arch[arch] diff --git a/third_party/mthreads/proton/proton/state.py b/third_party/mthreads/proton/proton/state.py new file mode 100644 index 0000000000..6232f452c5 --- /dev/null +++ b/third_party/mthreads/proton/proton/state.py @@ -0,0 +1,69 @@ +from triton._C.libproton import proton as libproton +from .flags import flags +from functools import wraps + +COMPUTE_METADATA_SCOPE_NAME = "__proton_launch_metadata" + + +class state: + """ + A context manager and decorator for entering and exiting a state. + + Usage: + context manager: + ```python + with proton.state("test0"): + foo[1,](x, y) + ``` + + decorator: + ```python + @proton.state("test0") + def foo(x, y): + ... + ``` + + Args: + name (str): The name of the state. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + if not flags.profiling_on: + return self + libproton.enter_state(self.name) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if not flags.profiling_on: + return + libproton.exit_state() + + def __call__(self, func): + + @wraps(func) + def wrapper(*args, **kwargs): + if flags.profiling_on: + libproton.enter_state(self.name) + ret = func(*args, **kwargs) + if flags.profiling_on: + libproton.exit_state() + return ret + + return wrapper + + +class metadata_state(state): + + def __init__(self) -> None: + super().__init__(COMPUTE_METADATA_SCOPE_NAME) + + +def enter_state(name: str) -> None: + libproton.enter_state(name) + + +def exit_state() -> None: + libproton.exit_state() diff --git a/third_party/mthreads/proton/proton/viewer.py b/third_party/mthreads/proton/proton/viewer.py new file mode 100644 index 0000000000..fbe9f4bda2 --- /dev/null +++ b/third_party/mthreads/proton/proton/viewer.py @@ -0,0 +1,428 @@ +import argparse +from collections import namedtuple +import json +import pandas as pd + +try: + import hatchet as ht + from hatchet.query import NegationQuery +except ImportError: + raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.") +import numpy as np +from triton.profiler.state import COMPUTE_METADATA_SCOPE_NAME +from triton.profiler.hooks.launch import LaunchHook +from triton.profiler import specs + + +def match_available_metrics(metrics, inclusive_metrics, exclusive_metrics): + ret = [] + if not isinstance(metrics, list): + metrics = [metrics] + if metrics: + for metric in metrics: + metric = metric.lower() + for raw_metric in inclusive_metrics + exclusive_metrics: + suffix = " (inc)" if raw_metric in inclusive_metrics else "" + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + if metric in (raw_metric, raw_metric_no_unit): + ret.append(raw_metric + suffix) + break + if len(ret) == 0: + raise RuntimeError(f"Metric {metric} is not found. Use the --list flag to list available metrics") + return ret + + +def remove_frames(database: json): + # We first fine frames that match either one of the two conditions: + # 1. The frame name is COMPUTE_METADATA_SCOPE_NAME + # 2. The frame has no metrics and no children + # Then we go up from the located nodes and remove the parents if all children were + # metadata nodes + def remove_frame_helper(node): + if "frame" not in node: + return node + if node["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME: + return None + if len(node["metrics"]) == 0 and len(node["children"]) == 0: + return None + children = node.get("children", []) + new_children = [] + for child in children: + new_child = remove_frame_helper(child) + if new_child is not None: + new_children.append(new_child) + if len(new_children) > 0 or len(children) == 0: + node["children"] = new_children + return node + return None + + new_database = [] + for node in database: + new_node = remove_frame_helper(node) + if new_node is not None: + new_database.append(new_node) + return new_database + + +def get_raw_metrics(database) -> tuple[ht.GraphFrame, list[str], list[str], dict]: + database = remove_frames(database) + device_info = {} if len(database) < 2 else database.pop(1) + gf = ht.GraphFrame.from_literal(database) + inclusive_metrics = gf.show_metric_columns() + exclusive_metrics = [metric for metric in gf.dataframe.columns if metric not in inclusive_metrics] + return gf, inclusive_metrics, exclusive_metrics, device_info + + +def get_min_time_flops(df, device_info): + min_time_flops = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) + for device_type in device_info: + for device_index in device_info[device_type]: + arch = device_info[device_type][device_index]["arch"] + num_sms = device_info[device_type][device_index]["num_sms"] + clock_rate = device_info[device_type][device_index]["clock_rate"] + for width in LaunchHook.flops_width: + idx = df["device_id"] == device_index + device_frames = df[idx] + if f"flops{width}" not in device_frames.columns: + continue + max_flops = specs.max_flops(device_type, arch, width, num_sms, clock_rate) + min_time_flops.loc[idx, "min_time"] += device_frames[f"flops{width}"].fillna(0) / max_flops + return min_time_flops + + +def get_min_time_bytes(df, device_info): + min_time_bytes = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) + for device_type in device_info: + for device_index in device_info[device_type]: + idx = df["device_id"] == device_index + device_frames = df[idx] + device = device_info[device_type][device_index] + memory_clock_rate = device["memory_clock_rate"] # in khz + bus_width = device["bus_width"] # in bits + peak_bandwidth = specs.max_bps(device_type, device['arch'], bus_width, memory_clock_rate) + min_time_bytes.loc[idx, "min_time"] += device_frames["bytes"] / peak_bandwidth + return min_time_bytes + + +FactorDict = namedtuple("FactorDict", ["name", "factor"]) +time_factor_dict = FactorDict("time", {"time/s": 1, "time/ms": 1e-3, "time/us": 1e-6, "time/ns": 1e-9}) +avg_time_factor_dict = FactorDict("avg_time", {f"avg_{key}": value for key, value in time_factor_dict.factor.items()}) +cpu_time_factor_dict = FactorDict("cpu_time", + {"cpu_time/s": 1, "cpu_time/ms": 1e-3, "cpu_time/us": 1e-6, "cpu_time/ns": 1e-9}) +avg_cpu_time_factor_dict = FactorDict("avg_cpu_time", + {f"avg_{key}": value + for key, value in cpu_time_factor_dict.factor.items()}) +bytes_factor_dict = FactorDict("bytes", {"byte/s": 1, "gbyte/s": 1e9, "tbyte/s": 1e12}) + +derivable_metrics = { + **{key: bytes_factor_dict + for key in bytes_factor_dict.factor.keys()}, +} + +# FLOPS have a specific width to their metric +default_flop_factor_dict = {"flop/s": 1, "gflop/s": 1e9, "tflop/s": 1e12} +derivable_metrics.update( + {key: FactorDict("flops", default_flop_factor_dict) + for key in default_flop_factor_dict.keys()}) +for width in LaunchHook.flops_width: + factor_name = f"flops{width}" + factor_dict = {f"flop{width}/s": 1, f"gflop{width}/s": 1e9, f"tflop{width}/s": 1e12} + derivable_metrics.update({key: FactorDict(factor_name, factor_dict) for key in factor_dict.keys()}) + + +def derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info): + derived_metrics = [] + + def get_time_seconds(df, metric, factor_dict): + time_metric_name = match_available_metrics(metric, inclusive_metrics, exclusive_metrics)[0] + time_unit = factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0] + return df[time_metric_name] * factor_dict.factor[time_unit] + + for metric in metrics: + if metric == "util": # exclusive + min_time_bytes = get_min_time_bytes(gf.dataframe, device_info) + min_time_flops = get_min_time_flops(gf.dataframe, device_info) + time_sec = get_time_seconds(gf.dataframe, "time", time_factor_dict) + internal_frame_indices = gf.dataframe["device_id"].isna() + gf.dataframe["util"] = min_time_flops["min_time"].combine(min_time_bytes["min_time"], max) / time_sec + gf.dataframe.loc[internal_frame_indices, "util"] = np.nan + derived_metrics.append("util") + elif metric in derivable_metrics: # flop/s, byte/s, inclusive + derivable_metric = derivable_metrics[metric] + metric_name = derivable_metric.name + metric_factor_dict = derivable_metric.factor + matched_metric_name = match_available_metrics(metric_name, inclusive_metrics, exclusive_metrics)[0] + gf.dataframe[f"{metric} (inc)"] = (gf.dataframe[matched_metric_name] / + (get_time_seconds(gf.dataframe, "time", time_factor_dict)) / + metric_factor_dict[metric]) + derived_metrics.append(f"{metric} (inc)") + elif (metric in time_factor_dict.factor or metric in cpu_time_factor_dict.factor + or metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor): # inclusive + is_cpu = metric in cpu_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor + is_avg = metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor + + factor_dict = ((avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict) if is_cpu else + (avg_time_factor_dict if is_avg else time_factor_dict)) + metric_name = "cpu_time" if is_cpu else "time" + metric_time_unit = factor_dict.name + "/" + metric.split("/")[1] + + time_value = get_time_seconds(gf.dataframe, metric_name, factor_dict) + if is_avg: + time_value = time_value / gf.dataframe["count (inc)"] + + gf.dataframe[f"{metric} (inc)"] = time_value / factor_dict.factor[metric_time_unit] + derived_metrics.append(f"{metric} (inc)") + else: + metric_name_and_unit = metric.split("/") + metric_name = metric_name_and_unit[0] + if len(metric_name_and_unit) > 1: # percentage, exclusive or inclusive + metric_unit = metric_name_and_unit[1] + if metric_unit != "%": + raise ValueError(f"Unsupported unit {metric_unit}") + matched_metric_name = match_available_metrics(metric_name, inclusive_metrics, exclusive_metrics)[0] + single_frame = gf.dataframe[matched_metric_name] + suffix = "" + if "(inc)" in matched_metric_name: + suffix = " (inc)" + total = gf.dataframe[matched_metric_name].iloc[0] + else: + total = gf.dataframe[matched_metric_name].sum() + gf.dataframe[metric + suffix] = (single_frame / total) * 100.0 + derived_metrics.append(metric + suffix) + else: + matched_metric_name = match_available_metrics(metric_name, inclusive_metrics, exclusive_metrics)[0] + derived_metrics.append(matched_metric_name) + + # Update derived metrics to the graph frame + for derived_metric in derived_metrics: + if derived_metric.endswith("(inc)"): + gf.inc_metrics.append(derived_metric) + else: + gf.exc_metrics.append(derived_metric) + + return derived_metrics + + +def format_frames(gf, format): + if format == "file_function_line": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split("/")[-1]) + elif format == "function_line": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split(":")[-1]) + elif format == "file_function": + gf.dataframe["name"] = gf.dataframe["name"].apply( + lambda x: f"{x.split('/')[-1].split(':')[0]}@{x.split('@')[-1].split(':')[0]}") + return gf + + +def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None): + if include: + query = f""" +MATCH ("*")->(".", p)->("*") +WHERE p."name" =~ "{include}" +""" + gf = gf.filter(query, squash=True) + if exclude: + inclusion_query = f""" +MATCH (".", p)->("*") +WHERE p."name" =~ "{exclude}" +""" + query = NegationQuery(inclusion_query) + gf = gf.filter(query, squash=True) + if threshold: + query = ["*", {metric: f">= {threshold}"}] + gf = gf.filter(query, squash=True) + return gf + + +def emit_warnings(gf, metrics): + if "bytes (inc)" in metrics: + byte_values = gf.dataframe["bytes (inc)"].values + min_byte_value = np.nanmin(byte_values) + if min_byte_value < 0: + print("Warning: Negative byte values detected, this is usually the result of a datatype overflow\n") + + +def print_tree(gf, metrics, depth=100, format=None, print_sorted=False): + gf = format_frames(gf, format) + print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False)) + + if print_sorted: + print("Sorted kernels by metric " + metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + for row in range(1, len(sorted_df)): + kernel_name = (sorted_df.iloc[row]["name"][:100] + + "..." if len(sorted_df.iloc[row]["name"]) > 100 else sorted_df.iloc[row]["name"]) + print("{:105} {:.4}".format(kernel_name, sorted_df.iloc[row][metrics[0]])) + emit_warnings(gf, metrics) + + +def read(filename): + with open(filename, "r") as f: + database = json.load(f) + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(database) + assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file" + gf.update_inclusive_columns() + return gf, inclusive_metrics, exclusive_metrics, device_info + + +def parse(metrics, filename, include=None, exclude=None, threshold=None): + gf, inclusive_metrics, exclusive_metrics, device_info = read(filename) + metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + # TODO: generalize to support multiple metrics, not just the first one + gf = filter_frames(gf, include, exclude, threshold, metrics[0]) + return gf, metrics + + +def apply_diff_profile(gf, derived_metrics, diff_file, metrics, include, exclude, threshold): + # Compute the diff against a secondary profile while keeping derived metrics consistent. + gf2, _ = parse(metrics, diff_file, include, exclude, threshold) + + derived_inc_metrics = [metric for metric in derived_metrics if metric.endswith("(inc)")] + derived_exc_metrics = [metric for metric in derived_metrics if not metric.endswith("(inc)")] + + gf.inc_metrics = derived_inc_metrics + gf.exc_metrics = derived_exc_metrics + gf2.inc_metrics = derived_inc_metrics + gf2.exc_metrics = derived_exc_metrics + return gf.sub(gf2) + + +def show_metrics(file_name): + with open(file_name, "r") as f: + database = json.load(f) + _, inclusive_metrics, exclusive_metrics, _ = get_raw_metrics(database) + print("Available inclusive metrics:") + if inclusive_metrics: + for raw_metric in inclusive_metrics: + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + print(f"- {raw_metric_no_unit}") + print("Available exclusive metrics:") + if exclusive_metrics: + for raw_metric in exclusive_metrics: + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + print(f"- {raw_metric_no_unit}") + + +def main(): + argparser = argparse.ArgumentParser( + description="Performance data viewer for proton profiles.", + formatter_class=argparse.RawTextHelpFormatter, + ) + argparser.add_argument( + "-l", + "--list", + action="store_true", + help="""List available metrics. Metric names are case insensitive and ignore units. +Derived metrics can be created when source metrics are available. +- time/s, time/ms, time/us, time/ns: time +- avg_time/s, avg_time/ms, avg_time/us, avg_time/ns: time / count +- flop[<8/16/32/64>]/s, gflop[<8/16/32/64>]/s, tflop[<8/16/32/64>]/s: flops / time +- byte/s, gbyte/s, tbyte/s: bytes / time +- util: max(sum(flops) / peak_flops_time, sum(bytes) / peak_bandwidth_time) +- /%%: frame(metric) / sum(metric). Only available for inclusive metrics (e.g. time) +""", + ) + argparser.add_argument( + "-m", + "--metrics", + type=str, + default=None, + help="""At maximum two metrics can be specified, separated by comma. +There are two modes: +1) Choose the output metric to display. It's case insensitive and ignore units. +2) Derive a new metric from existing metrics. +""", + ) + argparser.add_argument( + "-i", + "--include", + type=str, + default=None, + help= + """Find frames that match the given regular expression and return all nodes in the paths that pass through the matching frames. +For example, the following command will display all paths that contain frames that contains "test": +``` +proton-viewer -i ".*test.*" path/to/file.json +``` +""", + ) + argparser.add_argument( + "-e", + "--exclude", + type=str, + default=None, + help="""Exclude frames that match the given regular expression and their children. +For example, the following command will exclude all paths starting from frames that contains "test": +``` +proton-viewer -e ".*test.*" path/to/file.json +``` +""", + ) + argparser.add_argument( + "-t", + "--threshold", + type=float, + default=None, + help= + "Exclude frames(kernels) whose metrics are below the given threshold. This filter only applies on the first metric.", + ) + argparser.add_argument( + "-d", + "--depth", + type=int, + default=100, + help="The depth of the tree to display", + ) + argparser.add_argument( + "-f", + "--format", + type=str, + choices=["full", "file_function_line", "function_line", "file_function"], + default="full", + help="""Formatting the frame name. +- full: include the path, file name, function name and line number. +- file_function_line: include the file name, function name and line number. +- function_line: include the function name and line number. +- file_function: include the file name and function name. +""", + ) + argparser.add_argument( + "--print-sorted", + action="store_true", + default=False, + help="Sort output by metric value instead of chronologically", + ) + argparser.add_argument( + "--diff-profile", + "-diff", + type=str, + default=None, + help="Compare two profiles. When used as 'proton-viewer -m time -diff file1.log file2.log', " + "computes the difference: file2['time'] - file1['time']", + ) + + args, target_args = argparser.parse_known_args() + assert len(target_args) == 1, "Must specify a file to read" + + file_name = target_args[0] + metrics = args.metrics.split(",") if args.metrics else None + include = args.include + exclude = args.exclude + threshold = args.threshold + depth = args.depth + format = args.format + diff = args.diff_profile + print_sorted = args.print_sorted + if include and exclude: + raise ValueError("Cannot specify both include and exclude") + if args.list: + show_metrics(file_name) + elif metrics: + gf, derived_metrics = parse(metrics, file_name, include, exclude, threshold) + if diff: + gf = apply_diff_profile(gf, derived_metrics, diff, metrics, include, exclude, threshold) + print_tree(gf, derived_metrics, depth, format, print_sorted) + + +if __name__ == "__main__": + main() diff --git a/third_party/mthreads/proton/scripts/dump_ttgir.sh b/third_party/mthreads/proton/scripts/dump_ttgir.sh new file mode 100755 index 0000000000..80c89dc05d --- /dev/null +++ b/third_party/mthreads/proton/scripts/dump_ttgir.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Usage: ./dump_ttgir.sh python + +cmd="$*" +if [ -z "$cmd" ]; then + echo "Example usage: $0 python " + exit 1 +fi + +DUMP_DIR="$PWD/ttgir_dump" +mkdir -p "$DUMP_DIR" + +TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_DUMP=1 TRITON_DUMP_DIR=$DUMP_DIR $cmd +# Iterate over all subdirectories in $DUMP_DIR and remove all except the .ttgir files +for dir in "$DUMP_DIR"/*; do + if [ -d "$dir" ]; then + find "$dir" -type f ! -name "*.ttgir" -delete + fi +done + +echo "TTGIR files dumped to $DUMP_DIR" diff --git a/third_party/mthreads/proton/test/CMakeLists.txt b/third_party/mthreads/proton/test/CMakeLists.txt new file mode 100644 index 0000000000..5b05b08944 --- /dev/null +++ b/third_party/mthreads/proton/test/CMakeLists.txt @@ -0,0 +1,3 @@ +if(TRITON_BUILD_UT) + add_subdirectory(unittest) +endif() diff --git a/third_party/mthreads/proton/test/conftest.py b/third_party/mthreads/proton/test/conftest.py new file mode 100644 index 0000000000..722f9de90d --- /dev/null +++ b/third_party/mthreads/proton/test/conftest.py @@ -0,0 +1,12 @@ +import pytest + + +@pytest.fixture +def fresh_knobs(): + from triton._internal_testing import _fresh_knobs_impl + + fresh_function, reset_function = _fresh_knobs_impl() + try: + yield fresh_function() + finally: + reset_function() diff --git a/third_party/mthreads/proton/test/examples/cuda.json b/third_party/mthreads/proton/test/examples/cuda.json new file mode 100644 index 0000000000..bcf433d605 --- /dev/null +++ b/third_party/mthreads/proton/test/examples/cuda.json @@ -0,0 +1,86 @@ +[ + { + "children": [ + { + "children": [], + "frame": { + "name": "foo0", + "type": "function" + }, + "metrics": { + "count": 10, + "device_id": "1", + "device_type": "CUDA", + "time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e8 + } + }, + { + "children": [], + "frame": { + "name": "foo1", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 204800, + "flops8": 1e10, + "bytes": 1e7 + } + }, + { + "children": [], + "frame": { + "name": "foo2", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "2", + "device_type": "CUDA", + "time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e7 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0, + "flops8": 0, + "bytes": 0 + } + }, + { + "CUDA": { + "0": { + "arch": "89", + "bus_width": 384, + "clock_rate": 2625000, + "memory_clock_rate": 10501000, + "num_sms": 128 + }, + "1": { + "arch": "90", + "bus_width": 6144, + "clock_rate": 1980000, + "memory_clock_rate": 2619000, + "num_sms": 132 + }, + "2": { + "arch": "100", + "bus_width": 6144, + "clock_rate": 1700000, + "memory_clock_rate": 2619000, + "num_sms": 148 + } + } + } +] diff --git a/third_party/mthreads/proton/test/examples/frame.json b/third_party/mthreads/proton/test/examples/frame.json new file mode 100644 index 0000000000..cd671c9dff --- /dev/null +++ b/third_party/mthreads/proton/test/examples/frame.json @@ -0,0 +1,58 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "/home/user/projects/example.py/test.py:1@foo", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800 + } + } + ], + "frame": { + "name": "test0" + }, + "metrics": {} + }, + { + "children": [], + "frame": { + "name": "test1" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/mthreads/proton/test/examples/hip.json b/third_party/mthreads/proton/test/examples/hip.json new file mode 100644 index 0000000000..fd52f96ea9 --- /dev/null +++ b/third_party/mthreads/proton/test/examples/hip.json @@ -0,0 +1,86 @@ +[ + { + "children": [ + { + "children": [], + "frame": { + "name": "foo0", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "1", + "device_type": "HIP", + "time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e8 + } + }, + { + "children": [], + "frame": { + "name": "foo1", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800, + "flops8": 1e10, + "bytes": 1e7 + } + }, + { + "children": [], + "frame": { + "name": "foo2", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "2", + "device_type": "HIP", + "time (ns)": 204800, + "flops8": 1e12, + "bytes": 1e9 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0, + "flops8": 0, + "bytes": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + }, + "1": { + "arch": "gfx942", + "bus_width": 8192, + "clock_rate": 2100000, + "memory_clock_rate": 1200000, + "num_sms": 304 + }, + "2": { + "arch": "gfx950", + "bus_width": 8192, + "clock_rate": 2200000, + "memory_clock_rate": 1900000, + "num_sms": 256 + } + } + } +] diff --git a/third_party/mthreads/proton/test/examples/leaf_nodes.json b/third_party/mthreads/proton/test/examples/leaf_nodes.json new file mode 100644 index 0000000000..5930664dd2 --- /dev/null +++ b/third_party/mthreads/proton/test/examples/leaf_nodes.json @@ -0,0 +1,168 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_2_2", + "type": "function" + }, + "metrics": { + "count": 402, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 78190414 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_3_1", + "type": "function" + }, + "metrics": { + "count": 502, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 24125138 + } + } + ], + "frame": { + "name": "kernel_1_2_1", + "type": "function" + }, + "metrics": { + "bytes": 3997237248, + "flops": 1534939103232 + } + } + ], + "frame": { + "name": "kernel_1_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_2_2", + "type": "function" + }, + "metrics": { + "count": 120, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 23174888 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_3_1", + "type": "function" + }, + "metrics": { + "count": 149, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 1040322 + } + } + ], + "frame": { + "name": "kernel_2_2_1", + "type": "function" + }, + "metrics": { + "bytes": 58589184, + "flops": 4999610368 + } + } + ], + "frame": { + "name": "kernel_2_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_2", + "type": "function" + }, + "metrics": { + "count": 480, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 93036508 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "count": 599, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 6306402 + } + } + ], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "bytes": 529956864, + "flops": 67834478592 + } + } + ], + "frame": { + "name": "kernel_3_1_1", + "type": "function" + }, + "metrics": {} + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "flops": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/mthreads/proton/test/examples/triton.json b/third_party/mthreads/proton/test/examples/triton.json new file mode 100644 index 0000000000..2a29ee358b --- /dev/null +++ b/third_party/mthreads/proton/test/examples/triton.json @@ -0,0 +1,73 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "cuda_kernel", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 4064 + } + } + ], + "frame": { + "name": "__proton_launch_metadata", + "type": "function" + }, + "metrics": {} + }, + { + "children": [], + "frame": { + "name": "triton_kernel", + "type": "function" + }, + "metrics": { + "bytes": 2.0, + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 1664 + } + } + ], + "frame": { + "name": "scope", + "type": "function" + }, + "metrics": { + "cpu_time (ns)": 12345 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "time (ns)": 0 + } + }, + { + "CUDA": { + "0": { + "arch": "86", + "bus_width": 128, + "clock_rate": 1140000, + "memory_clock_rate": 5501000, + "num_sms": 16 + } + } + } +] diff --git a/third_party/mthreads/proton/test/helper.py b/third_party/mthreads/proton/test/helper.py new file mode 100644 index 0000000000..263ee31318 --- /dev/null +++ b/third_party/mthreads/proton/test/helper.py @@ -0,0 +1,40 @@ +import triton.profiler as proton + +import torch +import sys + +from helper_kernels import custom_add, matmul_kernel + + +def main(): + a = torch.zeros(1, device="cuda") + with proton.scope("test"): + custom_add[(1, )](a) + + +def test_main(): + main() + + +def matmul(): + a = torch.randn((32, 32), device="cuda", dtype=torch.float16) + b = torch.randn((32, 32), device="cuda", dtype=torch.float16) + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + matmul_kernel[(1, )]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + 128, 256, 64, 8) + return c + + +if __name__ == "__main__": + if sys.argv[1] == "test": + main() + elif sys.argv[1] == "test_matmul": + matmul() diff --git a/third_party/mthreads/proton/test/helper_kernels.py b/third_party/mthreads/proton/test/helper_kernels.py new file mode 100644 index 0000000000..0521b95957 --- /dev/null +++ b/third_party/mthreads/proton/test/helper_kernels.py @@ -0,0 +1,45 @@ +import triton.language as tl +import triton + + +@triton.jit +def custom_add(a_ptr): + tl.store(a_ptr, 1.0) + + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) diff --git a/third_party/mthreads/proton/test/override_helper.py b/third_party/mthreads/proton/test/override_helper.py new file mode 100644 index 0000000000..7df7f72d6e --- /dev/null +++ b/third_party/mthreads/proton/test/override_helper.py @@ -0,0 +1,54 @@ +import torch + +import triton +import triton.language as tl +import triton.profiler as proton +import pathlib +import sys + +from typing import NamedTuple + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + BLOCK_SIZE = args["BLOCK_SIZE"] + return {"name": f"add_{BLOCK_SIZE}"} + + +@triton.jit(launch_metadata=metadata_fn) +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +def add(x: torch.Tensor, y: torch.Tensor, path): + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + tmp_path = pathlib.Path(path) + temp_file = tmp_path / "test_override.hatchet" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation") + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) + proton.finalize() + return output + + +size = 98432 +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) +output_torch = x + y +output_triton = add(x, y, sys.argv[-1]) diff --git a/third_party/mthreads/proton/test/test_api.py b/third_party/mthreads/proton/test/test_api.py new file mode 100644 index 0000000000..37f6c35f4d --- /dev/null +++ b/third_party/mthreads/proton/test/test_api.py @@ -0,0 +1,435 @@ +""" +Test module for proton's Python API. +No GPU kernel should be declared in this test. +Profile correctness tests involving GPU kernels should be placed in `test_profile.py`. +""" + +import pytest +import json +import triton.profiler as proton +import pathlib +from triton.profiler.hooks.hook import HookManager +from triton.profiler.hooks.launch import LaunchHook +from triton.profiler.hooks.instrumentation import InstrumentationHook +from triton._internal_testing import is_hip + + +def test_profile_single_session(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert temp_file0.exists() + + temp_file1 = tmp_path / "test_profile1.hatchet" + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + proton.activate(session_id1) + proton.deactivate(session_id1) + proton.finalize(session_id1) + assert session_id1 == session_id0 + 1 + assert temp_file1.exists() + + session_id2 = proton.start("test") + proton.activate(session_id2) + proton.deactivate(session_id2) + proton.finalize() + assert session_id2 == session_id1 + 1 + assert pathlib.Path("test.hatchet").exists() + pathlib.Path("test.hatchet").unlink() + + +def test_profile_multiple_sessions(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + proton.start(str(temp_file0.with_suffix(""))) + temp_file1 = tmp_path / "test_profile1.hatchet" + proton.start(str(temp_file1.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert temp_file0.exists() + assert temp_file1.exists() + + temp_file2 = tmp_path / "test_profile2.hatchet" + session_id2 = proton.start(str(temp_file2.with_suffix(""))) + temp_file3 = tmp_path / "test_profile3.hatchet" + session_id3 = proton.start(str(temp_file3.with_suffix(""))) + proton.deactivate(session_id2) + proton.deactivate(session_id3) + proton.finalize() + assert temp_file2.exists() + assert temp_file3.exists() + + +def test_profile_mode(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + if is_hip(): + try: + proton.start(str(temp_file0.with_suffix("")), mode="pcsampling") + except Exception as e: + assert "RoctracerProfiler: unsupported mode: pcsampling" in str(e) + finally: + proton.finalize() + else: + import os + import pytest + + if os.environ.get("PROTON_SKIP_PC_SAMPLING_TEST", "0") == "1": + pytest.skip("PC sampling test is disabled") + + # Two sessions with the same mode can coexist + proton.start(str(temp_file0.with_suffix("")), mode="pcsampling") + temp_file1 = tmp_path / "test_profile1.hatchet" + proton.start(str(temp_file1.with_suffix("")), mode="pcsampling") + proton.finalize() + assert temp_file1.exists() + + # Two sessions with different modes cannot coexist + try: + proton.start(str(temp_file0.with_suffix("")), mode="pcsampling") + proton.start(str(temp_file1.with_suffix(""))) + except Exception as e: + assert "Cannot add a session with the same profiler but a different mode than existing sessions" in str(e) + finally: + proton.finalize() + + # Two sessions with different modes cannot coexist even if the first session is deactivated. + # In proton, once we deactivate a session, its profiler is not stopped, so changing the profiler mode is not allowed + # The only way to start a session with a different mode is to finalize all existing sessions first. + try: + session_id = proton.start(str(temp_file0.with_suffix("")), mode="pcsampling") + proton.deactivate(session_id) + temp_file1 = tmp_path / "test_profile1.hatchet" + proton.start(str(temp_file1.with_suffix(""))) + except Exception as e: + assert "Cannot add a session with the same profiler but a different mode than existing sessions" in str(e) + finally: + proton.finalize() + + +def test_profile_decorator(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_profile_decorator.hatchet" + + @proton.profile(name=str(temp_file.with_suffix(""))) + def foo0(a, b): + return a + b + + foo0(1, 2) + proton.finalize() + assert temp_file.exists() + + @proton.profile + def foo1(a, b): + return a + b + + foo1(1, 2) + proton.finalize() + default_file = pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet") + assert default_file.exists() + default_file.unlink() + + +def test_scope(tmp_path: pathlib.Path): + # Scope can be annotated even when profiling is off + with proton.scope("test"): + pass + + temp_file = tmp_path / "test_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test"): + pass + + @proton.scope("test") + def foo(): + pass + + foo() + + proton.enter_scope("test") + proton.exit_scope() + + proton.enter_scope("test0") + proton.exit_scope("test0") + + proton.finalize() + assert temp_file.exists() + + +def test_hook(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_hook.hatchet" + session_id0 = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.activate(session_id0) + proton.activate(session_id0) + assert len( + HookManager.active_hooks) == 1, ("Activate a session multiple times should maintain a single instance of hook") + assert list(HookManager.session_hooks[session_id0].values())[0] is True + proton.deactivate(session_id0) + assert list(HookManager.session_hooks[session_id0].values())[0] is False + assert len(HookManager.active_hooks) == 0 + # Deactivate a session multiple times should not raise an error + proton.deactivate(session_id0) + proton.finalize(None) + assert temp_file.exists() + + +def test_hook_manager(tmp_path: pathlib.Path): + # Launch hook is a singleton + HookManager.register(LaunchHook(), 0) + HookManager.register(LaunchHook(), 0) + assert len(HookManager.active_hooks) == 1 + assert isinstance(HookManager.active_hooks[0], LaunchHook) + assert HookManager.session_hooks[0][HookManager.active_hooks[0]] is True + + # Only unregister one session + HookManager.register(LaunchHook(), 1) + HookManager.unregister(0) + assert len(HookManager.active_hooks) == 1 + HookManager.unregister(1) + assert len(HookManager.active_hooks) == 0 + + # Heterogenous hooks + HookManager.register(InstrumentationHook(""), 2) + HookManager.register(LaunchHook(), 2) + assert len(HookManager.active_hooks) == 2 + # Launch hook has a higher priority + assert isinstance(HookManager.active_hooks[0], LaunchHook) + assert isinstance(HookManager.active_hooks[1], InstrumentationHook) + assert HookManager.session_hooks[2][HookManager.active_hooks[0]] is True + assert HookManager.session_hooks[2][HookManager.active_hooks[1]] is True + HookManager.unregister() + assert len(HookManager.active_hooks) == 0 + + +def test_scope_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_metrics.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + with proton.scope("test0", {"a": 1.0}): + pass + + @proton.scope("test1", {"a": 1.0}) + def foo(): + pass + + foo() + + # After deactivation, the metrics should be ignored + proton.deactivate(session_id) + proton.enter_scope("test2", metrics={"a": 1.0}) + proton.exit_scope() + + # Metrics should be recorded again after reactivation + proton.activate(session_id) + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() + + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() + + # exit_scope can also take metrics + proton.enter_scope("test4") + proton.exit_scope(metrics={"b": 1.0}) + + proton.finalize() + + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 4 + for child in data[0]["children"]: + if child["frame"]["name"] == "test3": + assert child["metrics"]["a"] == 2.0 + elif child["frame"]["name"] == "test4": + assert child["metrics"]["b"] == 1.0 + + +def test_scope_metrics_invalid(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_metrics.hatchet" + proton.start(str(temp_file.with_suffix(""))) + + error = None + + try: + with proton.scope("test0", {"a": 1.0}): + pass + + with proton.scope("test0", {"a": 1}): + pass + except Exception as e: + error = str(e) + finally: + proton.finalize() + + assert error is not None and "Metric value type mismatch for valueId 0 (a): current=double, new=int64_t" in error + + +def test_scope_properties(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_properties.hatchet" + proton.start(str(temp_file.with_suffix(""))) + # Properties do not aggregate + proton.enter_scope("test0", metrics={"a (pty)": 1.0}) + proton.exit_scope() + + proton.enter_scope("test0", metrics={"a (pty)": 1.0}) + proton.exit_scope() + + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + for child in data[0]["children"]: + if child["frame"]["name"] == "test0": + assert child["metrics"]["a"] == 1.0 + + +def test_scope_exclusive(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_exclusive.hatchet" + proton.start(str(temp_file.with_suffix(""))) + # metric a only appears in the outermost scope + # metric b only appears in the innermost scope + # both metrics do not appear in the root scope + with proton.scope("test0", metrics={"a (exc)": 1}): + with proton.scope("test1", metrics={"b (exc)": 1}): + pass + + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + root_metrics = data[0]["metrics"] + assert len(root_metrics) == 0 + test0_frame = data[0]["children"][0] + test0_metrics = test0_frame["metrics"] + assert len(test0_metrics) == 1 + assert test0_metrics["a"] == 1 + test1_frame = test0_frame["children"][0] + test1_metrics = test1_frame["metrics"] + assert len(test1_metrics) == 1 + assert test1_metrics["b"] == 1 + + +def test_state(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_state.hatchet" + proton.start(str(temp_file.with_suffix(""))) + proton.enter_scope("test0") + proton.enter_state("state") + proton.enter_scope("test1", metrics={"a": 1.0}) + proton.exit_scope() + proton.exit_state() + proton.exit_scope() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + # test0->test1->state + assert len(data[0]["children"]) == 1 + child = data[0]["children"][0] + assert child["frame"]["name"] == "test0" + assert len(child["children"]) == 1 + child = child["children"][0] + assert child["frame"]["name"] == "test1" + assert len(child["children"]) == 1 + child = child["children"][0] + assert child["frame"]["name"] == "state" + assert child["metrics"]["a"] == 1.0 + + +def test_context_depth(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_context_depth.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + assert proton.context.depth(session_id) == 0 + proton.enter_scope("test0") + assert proton.context.depth(session_id) == 1 + proton.enter_scope("test1") + assert proton.context.depth(session_id) == 2 + proton.exit_scope() + assert proton.context.depth(session_id) == 1 + proton.exit_scope() + assert proton.context.depth(session_id) == 0 + proton.finalize() + + +def test_throw(tmp_path: pathlib.Path): + # Catch an exception thrown by c++ + session_id = 100 + temp_file = tmp_path / "test_throw.hatchet" + activate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.activate(session_id + 1) + except Exception as e: + activate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in activate_error + + deactivate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.deactivate(session_id + 1) + except Exception as e: + deactivate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error + + +@pytest.mark.parametrize("disable", [True, False]) +def test_profile_disable(disable, fresh_knobs, tmp_path: pathlib.Path): + fresh_knobs.proton.disable = disable + temp_file = tmp_path / "test_profile_disable.hatchet" + proton.start(str(temp_file.with_suffix(""))) + proton.enter_scope("test0") + proton.exit_scope() + proton.finalize() + if disable: + assert not temp_file.exists() + else: + assert temp_file.exists() + + +def test_finalize_within_scope(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_finalize_within_scope.hatchet" + session_id0 = proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0"): + assert proton.context.depth(session_id0) == 1 + proton.finalize() + assert temp_file.exists() + temp_file1 = tmp_path / "test_finalize_within_scope1.hatchet" + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + depth = proton.context.depth(session_id1) + assert depth == 0 + proton.finalize() + + +def test_data_api(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_data_api.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.enter_scope("test0") + proton.exit_scope() + proton.deactivate(session_id) + json_data = proton.data.get(session_id) + assert json_data is not None + msgpack_data = proton.data.get_msgpack(session_id) + assert isinstance(msgpack_data, bytes) + is_complete = proton.data.is_phase_complete(session_id, 0) + assert is_complete is False + next_phase = proton.data.advance_phase(session_id) + assert next_phase == 1 + is_complete = proton.data.is_phase_complete(session_id, 1) + assert is_complete is False + + # Even if a phase has no GPU activity records, flushing should still mark it + # as flushed. + proton.activate(session_id) + next_phase = proton.data.advance_phase(session_id) + assert next_phase == 2 + proton.deactivate(session_id, flushing=True) + assert proton.data.is_phase_complete(session_id, 1) is True + assert proton.data.is_phase_complete(session_id, 2) is False + + # Test clear and clear_up_to_phase + proton.data.clear(session_id, phase=0) + proton.data.clear(session_id, phase=2, clear_up_to_phase=True) + + proton.finalize() diff --git a/third_party/mthreads/proton/test/test_cmd.py b/third_party/mthreads/proton/test/test_cmd.py new file mode 100644 index 0000000000..7e1d438d3b --- /dev/null +++ b/third_party/mthreads/proton/test/test_cmd.py @@ -0,0 +1,30 @@ +import pytest +import subprocess +import json +import pathlib + + +def test_help(): + # Only check if the viewer can be invoked + subprocess.check_call(["proton", "-h"], stdout=subprocess.DEVNULL) + + +@pytest.mark.parametrize("mode", ["script", "python", "pytest"]) +def test_exec(mode, tmp_path: pathlib.Path): + file_path = __file__ + helper_file = file_path.replace("test_cmd.py", "helper.py") + temp_file = tmp_path / "test_exec.hatchet" + name = str(temp_file.with_suffix("")) + if mode == "script": + subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + elif mode == "python": + subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) + elif mode == "pytest": + subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) + with temp_file.open() as f: + data = json.load(f, ) + kernels = data[0]["children"] + assert len(kernels) == 2 + assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" diff --git a/third_party/mthreads/proton/test/test_instrumentation.py b/third_party/mthreads/proton/test/test_instrumentation.py new file mode 100644 index 0000000000..9df27ac6e4 --- /dev/null +++ b/third_party/mthreads/proton/test/test_instrumentation.py @@ -0,0 +1,1003 @@ +import json +import pathlib + +from typing import NamedTuple, Tuple, Optional + +import pytest +import torch + +import triton +import triton.language as tl +import triton.profiler as proton +import triton.profiler.language as pl +from triton._internal_testing import ( + is_cuda, + is_hip, + is_hip_cdna2, + is_hip_cdna4, + supports_tma, + supports_ws, +) +from triton.tools.tensor_descriptor import TensorDescriptor + +pl.enable_semantic("triton") + +# Skip all tests if the AMD GPU version is not supported +pytestmark = pytest.mark.skipif(is_hip_cdna2(), reason="old AMD GPUs are not supported") + +HAS_WARP_SPECIALIZE = supports_ws() and supports_tma() + + +@pytest.mark.parametrize( + "mode", + [ + "default", + "default:metric_type=cycle", + "default:metric_type=cycle:buffer_size=4096", + "mma", + ], +) +def test_mode_str(mode, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_mode_str.hatchet" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", mode=mode) + proton.finalize() + + +@pytest.mark.parametrize( + "mode", + [ + proton.mode.Default(), + proton.mode.Default(metric_type="cycle"), + proton.mode.Default(metric_type="cycle", buffer_size=4096), + proton.mode.MMA(), + ], +) +def test_mode_obj(mode, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_mode_simple.hatchet" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", mode=mode) + proton.finalize() + + +def test_jit(tmp_path): + + @triton.jit + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook_instrumentation.hatchet" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation") + foo[(1, )](x, 1, y, num_warps=4) + device = triton.runtime.driver.active.get_current_device() + assert len(foo.device_caches[device][0]) == 1, "Kernel should be cached" + proton.finalize() + foo[(1, )](x, 1, y, num_warps=4) + assert (len(foo.device_caches[device][0]) == 2), "Instrumented and uninstrumented kernels both should be cached" + + +@pytest.mark.parametrize("method", ["operator", "context_manager"]) +def test_record(method, fresh_knobs, tmp_path: pathlib.Path): + fresh_knobs.compilation.disable_line_info = False + + from contextlib import contextmanager + + @contextmanager + def instrumentation(file_path): + proton.hooks.InstrumentationHook.enable_host_buffer = True + proton.start(str(file_path.with_suffix("")), backend="instrumentation") + try: + yield + finally: + proton.hooks.InstrumentationHook.enable_host_buffer = False + proton.finalize() + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + METHOD: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + if METHOD == "operator": + pl.enter_scope("load0") + y = tl.load(y_ptr + offsets, mask=mask) + pl.exit_scope("load0") + else: + with pl.scope("load0"): + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file = tmp_path / "test_record.hatchet" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + with instrumentation(temp_file): + pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, METHOD=method) + # FIXME(fywkevin): have a dedicated place to put those decoding related constants + payload_offset = int.from_bytes( + proton.hooks.InstrumentationHook.host_buffer[12:16].numpy().tobytes(), + "little", + ) + host_buffer = proton.hooks.InstrumentationHook.host_buffer[payload_offset:] + preamble = host_buffer[0:4] + assert int.from_bytes(preamble.numpy().tobytes(), "little") == 0xDEADBEEF + header_size = 40 + metadata_size = header_size + pgm.metadata.num_warps * 4 + start_tag = host_buffer[metadata_size:metadata_size + 4] + start_clock = host_buffer[metadata_size + 4:metadata_size + 8] + end_tag = host_buffer[metadata_size + 8:metadata_size + 12] + end_clock = host_buffer[metadata_size + 12:metadata_size + 16] + assert int.from_bytes(start_tag.numpy().tobytes(), "little") & 0xFFFFF800 == 0 + assert (int.from_bytes(end_tag.numpy().tobytes(), "little") & 0xFFFFF800 == 0x80000000) + start_clock_val = int.from_bytes(start_tag.numpy().tobytes(), "little") & 0x7FF << 32 | int.from_bytes( + start_clock.numpy().tobytes(), "little") + end_clock_val = int.from_bytes(end_tag.numpy().tobytes(), "little") & 0x7FF << 32 | int.from_bytes( + end_clock.numpy().tobytes(), "little") + assert end_clock_val > start_clock_val + + # instrumentation context has finalized, now validate assembly + ttir = pgm.asm["ttir"] + assert "proton.record start" in ttir + assert "proton.record end" in ttir + + # check ttir line info + start_loc = None + end_loc = None + for line in ttir.split("\n"): + if "proton.record start" in line: + start_loc = line.split("loc(")[1].split(")")[0] + elif "proton.record end" in line: + end_loc = line.split("loc(")[1].split(")")[0] + elif start_loc and f"#loc{start_loc}" in line: + assert "test_instrumentation.py" in line + elif end_loc and f"#loc{end_loc}" in line: + assert "test_instrumentation.py" in line + + assert start_loc is not None and end_loc is not None + + # check llir line info + llir_lines = pgm.asm["llir"].splitlines() + clock_instr = "clock" if is_cuda() else "memtime" + clock_loc = None + for line in llir_lines: + if clock_instr not in line or "!dbg" not in line: + continue + suffix = line.split("!dbg ")[1] + clock_loc = suffix.split(",")[0].split()[0] + break + assert clock_loc is not None + loc_line = next( + (line for line in llir_lines if clock_loc in line and "DILocation" in line), + None, + ) + assert loc_line is not None + assert "line: " in loc_line and "line: 0" not in loc_line + + +def test_select_ids(tmp_path: pathlib.Path): + from contextlib import contextmanager + + select_ids = [0, 2] + mode = proton.mode.Default( + sampling_strategy="selective", + sampling_options=",".join(str(i) for i in select_ids), + granularity="warp", + ) + + @contextmanager + def instrumentation(file_path): + proton.hooks.InstrumentationHook.enable_host_buffer = True + proton.start( + str(file_path.with_suffix("")), + backend="instrumentation", + mode=mode, + ) + try: + yield + finally: + proton.hooks.InstrumentationHook.enable_host_buffer = False + proton.finalize() + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_ops"): + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file = tmp_path / "test_select_ids.hatchet" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + + warp_indices = [] + + with instrumentation(temp_file): + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=4) + uid_num_offset = 36 + uid_vec_offset = 40 + uid_num = int.from_bytes( + proton.hooks.InstrumentationHook.host_buffer[uid_num_offset:uid_num_offset + 4].numpy().tobytes(), + "little", + ) + assert uid_num == len(select_ids) + for i in range(uid_num): + offset = uid_vec_offset + i * 4 + warp_id = int.from_bytes( + proton.hooks.InstrumentationHook.host_buffer[offset:offset + 4].numpy().tobytes(), + "little", + ) + warp_indices.append(warp_id) + assert sorted(warp_indices) == select_ids + + +@pytest.mark.parametrize("hook", ["triton", None]) +def test_tree(tmp_path: pathlib.Path, hook): + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + BLOCK_SIZE = args["BLOCK_SIZE"] + return {"name": f"add_{BLOCK_SIZE}"} + + @triton.jit(launch_metadata=metadata_fn) + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + with pl.scope("kernel"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_ops"): + with pl.scope("load_x"): + x = tl.load(x_ptr + offsets, mask=mask) + with pl.scope("load_y"): + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file = tmp_path / "test_tree.hatchet" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", hook=hook) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) + proton.finalize() + + with open(temp_file, "rb") as f: + data = json.load(f) + if hook: + assert "add_1024" == data[0]["children"][0]["frame"]["name"] + kernel_frame = data[0]["children"][0]["children"][0] + load_ops = kernel_frame["children"][0] + assert "load_ops" in load_ops["frame"]["name"] + assert ("load_x" in load_ops["children"][0]["frame"]["name"] + or "load_x" in load_ops["children"][1]["frame"]["name"]) + assert ("load_y" in load_ops["children"][0]["frame"]["name"] + or "load_y" in load_ops["children"][1]["frame"]["name"]) + assert load_ops["children"][0]["metrics"]["cycles"] > 0 + assert load_ops["children"][0]["metrics"]["normalized_cycles"] > 0 + assert load_ops["children"][1]["metrics"]["cycles"] > 0 + assert load_ops["children"][1]["metrics"]["normalized_cycles"] > 0 + + +def test_trace(tmp_path: pathlib.Path): + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + with pl.scope("kernel"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_ops"): + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + @triton.jit + def sub_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + with pl.scope("kernel"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_ops"): + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x - y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file = tmp_path / "test_trace.chrome_trace" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", data="trace") + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) + sub_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) + proton.finalize() + + with open(temp_file, "rb") as f: + data = json.load(f) + events = data["traceEvents"] + assert events[0]["name"] == "kernel" + assert events[0]["cat"] == "add_kernel" + assert events[1]["name"] == "load_ops" + assert events[1]["cat"] == "add_kernel" + assert events[2]["name"] == "kernel" + assert events[2]["cat"] == "sub_kernel" + assert events[3]["name"] == "load_ops" + assert events[3]["cat"] == "sub_kernel" + + +def test_multi_session(tmp_path: pathlib.Path): + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_x"): + x = tl.load(x_ptr + offsets, mask=mask) + with pl.scope("load_y"): + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file_inst = tmp_path / "test_tree_inst.hatchet" + temp_file_driver = tmp_path / "test_tree_driver.hatchet" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + session_id0 = proton.start(str(temp_file_inst.with_suffix("")), backend="instrumentation") + session_id1 = proton.start(str(temp_file_driver.with_suffix(""))) + proton.deactivate(session_id0) + proton.deactivate(session_id1) + proton.activate() + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) + proton.finalize() + + temp_file_restart = tmp_path / "test_tree_restart.hatchet" + session_id0 = proton.start(str(temp_file_restart.with_suffix("")), backend="instrumentation") + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=1) + proton.finalize() + + with open(temp_file_inst, "rb") as f: + data = json.load(f) + kernel_frame = data[0]["children"][0] + assert "add_kernel" == kernel_frame["frame"]["name"] + assert "cycles" in kernel_frame["children"][0]["metrics"] + + with open(temp_file_driver, "rb") as f: + data = json.load(f) + kernel_frame = data[0]["children"][0] + assert "add_kernel" == kernel_frame["frame"]["name"] + assert "time (ns)" in kernel_frame["metrics"] + + with open(temp_file_restart, "rb") as f: + data = json.load(f) + kernel_frame = data[0]["children"][0] + assert "add_kernel" == kernel_frame["frame"]["name"] + assert "cycles" in kernel_frame["children"][0]["metrics"] + + +def test_autotune(tmp_path: pathlib.Path): + + def metadata_fn( + grid: tuple, + metadata: NamedTuple, + args: dict, + ): + BLOCK_SIZE = args["BLOCK_SIZE"] + return { + "name": f"add_{BLOCK_SIZE}", + } + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 256}, num_warps=1), + triton.Config({"BLOCK_SIZE": 512}, num_warps=1), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=1), + ], + key=["n_elements"], + ) + @triton.jit(launch_metadata=metadata_fn) + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_x"): + x = tl.load(x_ptr + offsets, mask=mask) + with pl.scope("load_y"): + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 2048 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + temp_file = tmp_path / "test_autotune.hatchet" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", hook="triton") + add_kernel[grid](x, y, output, n_elements) + proton.finalize() + + # Check all names exist in the output + with open(temp_file, "rb") as f: + data = json.load(f) + names = [frame["frame"]["name"] for frame in data[0]["children"]] + assert "add_256" in names + assert "add_512" in names + assert "add_1024" in names + + +def test_warp_spec(tmp_path: pathlib.Path): + if not supports_tma() or not supports_ws(): + pytest.skip("target backend does not support warp specialization and TMA") + + @triton.jit + def matmul_kernel_tma(a_desc, b_desc, c_desc, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + WARP_SPECIALIZE: tl.constexpr, # + ): + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in tl.range(k_tiles, warp_specialize=WARP_SPECIALIZE): + pl.enter_scope("loop") + offs_k = k * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + pl.exit_scope("loop") + + c = accumulator.to(dtype) + + offs_cm = pid_m * BLOCK_SIZE_M + offs_cn = pid_n * BLOCK_SIZE_N + c_desc.store([offs_cm, offs_cn], c) + + def matmul_tma(a, b, warp_specialize: bool): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + a_desc = TensorDescriptor(a, a.shape, a.stride(), [128, 128]) + b_desc = TensorDescriptor(b, b.shape, b.stride(), [256, 128]) + c_desc = TensorDescriptor(c, c.shape, c.stride(), [128, 256]) + + def grid(META): + BLOCK_M = 128 + BLOCK_N = 256 + return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), ) + + matmul_kernel_tma[grid]( + a_desc, + b_desc, + c_desc, # + M, + N, + K, # + BLOCK_SIZE_M=128, # + BLOCK_SIZE_N=256, # + BLOCK_SIZE_K=128, # + GROUP_SIZE_M=8, # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + WARP_SPECIALIZE=warp_specialize, # + num_stages=2, # + num_warps=8, + ) + return c + + mode = proton.mode.Default(metric_type="cycle", optimizations="clock32") + temp_file = tmp_path / "test_warpspec.hatchet" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", mode=mode) + M, N, K = 512, 512, 512 + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + b = b.T.contiguous() + + matmul_tma(a, b, warp_specialize=HAS_WARP_SPECIALIZE) + proton.finalize() + + with open(temp_file, "rb") as f: + data = json.load(f) + kernel = data[0]["children"][0] + assert kernel["children"][0]["frame"]["name"] == "loop" + assert kernel["children"][0]["metrics"]["cycles"] > 0 + assert kernel["frame"]["name"] == "matmul_kernel_tma" + + +def test_timeline(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_timeline.chrome_trace" + mode = proton.mode.Default(metric_type="cycle", optimizations="time_shift") + proton.start( + str(temp_file.with_suffix("")), + data="trace", + backend="instrumentation", + mode=mode, + ) + + @triton.jit + def foo(x, y, size: tl.constexpr): + pl.enter_scope("entire") + offs = tl.arange(0, size) + pl.enter_scope("load") + x = tl.load(x + offs) + x = x + 1 + pl.exit_scope("load") + pl.enter_scope("store") + tl.store(y + offs, x) + pl.exit_scope("store") + pl.exit_scope("entire") + + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + trace_events = data["traceEvents"] + assert len(trace_events) == 12 + assert trace_events[-1]["tid"][0:4] == "warp" + assert trace_events[-1]["args"]["call_stack"][-1] == "foo" + assert trace_events[-1]["args"]["call_stack"][-2] == "test" + + +@pytest.mark.skipif(is_hip_cdna4(), reason="nondeterministic failure") +def test_globaltime(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_globaltime.chrome_trace" + mode = proton.mode.Default( + metric_type="cycle", + optimizations="clock32,time_shift", + sampling_strategy="selective", + sampling_options="0", + ) + proton.start( + str(temp_file.with_suffix("")), + data="trace", + backend="instrumentation", + mode=mode, + ) + + @triton.jit() + def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pl.enter_scope("elementwise_add_kernel") + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + pl.exit_scope("elementwise_add_kernel") + + size = 1024 * 2000 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + output = torch.empty_like(x) + n_elements = output.numel() + BLOCK_SIZE = 1024 + grid = lambda meta: (triton.cdiv(n_elements, BLOCK_SIZE), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE, num_warps=16) + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + trace_events = data["traceEvents"] + target = sorted( + [event for event in trace_events if "Core0 " in event["pid"]], + key=lambda x: x["ts"], + ) + s = len(target) + assert s > 1 + ts_diff = target[s - 1]["ts"] - target[0]["ts"] + assert ts_diff >= target[0]["dur"] + + +@pytest.mark.skipif(is_hip(), reason="not stable overhead numbers on AMD GPUs") +def test_overhead(tmp_path: pathlib.Path): + temp_file_cycles = tmp_path / "test_overhead.hatchet" + temp_file_time = tmp_path / "test_overhead_time.hatchet" + + @triton.jit() + def kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr, LOOP: tl.constexpr): + pl.enter_scope("kernel") + for _ in range(16): + if LOOP: + pl.enter_scope("loop") + x = tl.load(x_ptr + tl.arange(0, BLOCK_SIZE)) + tl.store(y_ptr + tl.arange(0, BLOCK_SIZE), x + 1) + if LOOP: + pl.exit_scope("loop") + pl.exit_scope("kernel") + + BLOCK_SIZE = 256 + x = torch.zeros(BLOCK_SIZE, device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + + def bench(): + with proton.scope("single"): + kernel[(1024, )](x, y, BLOCK_SIZE, False) + with proton.scope("loop"): + kernel[(1024, )](x, y, BLOCK_SIZE, True) + + # warmup + bench() + + proton.start(str(temp_file_time.with_suffix("")), ) + + with proton.scope("session0"): + bench() + + proton.start(str(temp_file_cycles.with_suffix("")), backend="instrumentation", + mode=proton.mode.Default(metric_type="cycle", buffer_size=4096)) + + with proton.scope("session1"): + bench() + proton.finalize() + + with temp_file_time.open("rb") as f: + data = json.load(f) + root = data[0] + + def session_kernel_time(session_name: str) -> Tuple[int, int]: + session_node = next(child for child in root["children"] if child["frame"]["name"] == session_name) + single_node = next(child for child in session_node["children"] if child["frame"]["name"] == "single") + loop_node = next(child for child in session_node["children"] if child["frame"]["name"] == "loop") + kernel_node = single_node["children"][0] + single_time = kernel_node["metrics"]["time (ns)"] + kernel_node = loop_node["children"][0] + loop_time = kernel_node["metrics"]["time (ns)"] + return single_time, loop_time + + session0_single_time, session0_loop_time = session_kernel_time("session0") + session1_single_time, session1_loop_time = session_kernel_time("session1") + single_threshold = 1.2 if is_cuda() else 1.5 + loop_threshold = 2.0 if is_cuda() else 3.0 + assert session1_single_time / session0_single_time < single_threshold, "Simple kernel overhead too high" + assert session1_loop_time / session0_loop_time < loop_threshold, "Loop kernel overhead too high" + + +def test_gmem_buffer(tmp_path: pathlib.Path): + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + with pl.scope("kernel"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_ops"): + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 512 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file = tmp_path / "test_gmem_buffer.chrome_trace" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + mode = proton.mode.Default(buffer_type="global") + proton.start( + str(temp_file.with_suffix("")), + backend="instrumentation", + data="trace", + mode=mode, + ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=2) + proton.finalize() + + with open(temp_file, "rb") as f: + data = json.load(f) + events = data["traceEvents"] + + # Assert we have exactly 4 events (2 warps × 2 scopes) + assert len(events) == 4 + + # Assert all events have the expected common fields + for event in events: + assert "ts" in event + assert "dur" in event + assert event["dur"] > 0 + + # Assert we have 2 kernel events and 2 load_ops events + kernel_events = [e for e in events if e["name"] == "kernel"] + load_ops_events = [e for e in events if e["name"] == "load_ops"] + assert len(kernel_events) == 2 + assert len(load_ops_events) == 2 + + # Assert we have events from both warps + warp0_events = [e for e in events if "warp 0" in e["tid"]] + warp1_events = [e for e in events if "warp 1" in e["tid"]] + assert len(warp0_events) == 2 + assert len(warp1_events) == 2 + + +def test_event_args(tmp_path: pathlib.Path): + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + with pl.scope("kernel"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + temp_file = tmp_path / "test_block_metadata.chrome_trace" + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", data="trace") + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=2) + proton.finalize() + + with open(temp_file, "rb") as f: + data = json.load(f) + events = data["traceEvents"] + + # Verify we have events + assert len(events) > 0 + + # Verify each event has the required metadata in args + for event in events: + assert "args" in event + args = event["args"] + + assert "Init Time (ns)" in args + assert "Post Final Time (ns)" in args + assert "Finalization Time (ns)" in args + + # Verify timing values are reasonable + init_time = args["Init Time (ns)"] + post_final_time = args["Post Final Time (ns)"] + finalization_time = args["Finalization Time (ns)"] + + assert init_time >= 0 + assert post_final_time >= 0 + assert finalization_time >= 0 + + +def test_threaded_kernel_call(tmp_path: pathlib.Path): + + import threading + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + with pl.scope("kernel"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + size = 256 + x = torch.rand(size, device="cuda") + y = torch.rand(size, device="cuda") + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + + temp_file = tmp_path / "test_threaded.chrome_trace" + proton.start( + str(temp_file.with_suffix("")), + backend="instrumentation", + data="trace", + ) + + exception_holder = [] + + def run_kernel(): + try: + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + except Exception as e: + exception_holder.append(e) + + thread = threading.Thread(target=run_kernel) + thread.start() + thread.join() + + proton.finalize() + + assert len(exception_holder) == 0, f"Kernel raised exception: {exception_holder[0] if exception_holder else None}" + + with open(temp_file, "rb") as f: + data = json.load(f) + events = data["traceEvents"] + assert len(events) > 0 + kernel_events = [e for e in events if e["name"] == "kernel"] + assert len(kernel_events) > 0 + + +@pytest.mark.parametrize("num_ctas", [1, 2]) +def test_tensor_descriptor(num_ctas, tmp_path: pathlib.Path): + if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)): + pytest.skip("CTAs is unsupported for these cards") + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + assert desc.shape[0] == M + assert desc.shape[1] == N + assert desc.strides[0] == N + assert desc.strides[1] == 1 + assert desc.block_shape == [M_BLOCK, N_BLOCK] + pl.enter_scope("load_block") + block = desc.load([M_BLOCK, 2 * N_BLOCK]) + pl.exit_scope("load_block") + idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :] + tl.store(out_ptr + idx, block) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 * num_ctas + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + M_BLOCK = 4 + N_BLOCK = 4 + M, N = M_BLOCK * 3, N_BLOCK * 4 + inp = torch.randn((M, N), device="cuda", dtype=torch.float32) + out = inp.new_empty((M_BLOCK, N_BLOCK)) + + temp_file = tmp_path / "test_tensor_descriptor.chrome_trace" + proton.start(str(temp_file.with_suffix("")), backend="instrumentation", data="trace") + + kernel[(1, )](out, inp, M, N, M_BLOCK, N_BLOCK, num_ctas=num_ctas) + expect = inp[1 * M_BLOCK:2 * M_BLOCK, 2 * N_BLOCK:3 * N_BLOCK] + torch.testing.assert_close(expect, out) + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + trace_events = data["traceEvents"] + if num_ctas == 1: + assert len(trace_events) == 4 + num_cta0_events = sum(1 for e in trace_events if "CTA0" in e["pid"]) + assert num_cta0_events == 4 + else: + assert len(trace_events) == 8 + num_cta0_events = sum(1 for e in trace_events if "CTA0" in e["pid"]) + num_cta1_events = sum(1 for e in trace_events if "CTA1" in e["pid"]) + assert num_cta0_events == 4 + assert num_cta1_events == 4 diff --git a/third_party/mthreads/proton/test/test_lib.py b/third_party/mthreads/proton/test/test_lib.py new file mode 100644 index 0000000000..04251d85ba --- /dev/null +++ b/third_party/mthreads/proton/test/test_lib.py @@ -0,0 +1,95 @@ +""" +Test module for proton's CPP API functionality. +No GPU kernel should be declared in this test. +Python API correctness tests involving GPU kernels should be placed in `test_api.py`. +Profile correctness tests involving GPU kernels should be placed in `test_profile.py`. +""" +import pathlib +import pytest + +import triton._C.libproton.proton as libproton +from triton.profiler.profile import _select_backend + + +def test_record(): + id0 = libproton.record_scope() + id1 = libproton.record_scope() + assert id1 == id0 + 1 + + +def test_state(): + libproton.enter_state("zero") + libproton.exit_state() + + +def test_scope(): + id0 = libproton.record_scope() + libproton.enter_scope(id0, "zero") + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.exit_scope(id1, "one") + libproton.exit_scope(id0, "zero") + + +def test_op(): + id0 = libproton.record_scope() + libproton.enter_op(id0, "zero") + libproton.exit_op(id0, "zero") + + +@pytest.mark.parametrize("source", ["shadow", "python"]) +def test_context(source: str, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_context.hatchet" + session_id = libproton.start(str(temp_file.with_suffix("")), source, "tree", _select_backend()) + depth = libproton.get_context_depth(session_id) + libproton.finalize(session_id, "hatchet") + assert depth >= 0 + assert temp_file.exists() + + +def test_session(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_session.hatchet" + session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + libproton.deactivate(session_id, False) + libproton.activate(session_id) + libproton.finalize(session_id, "hatchet") + libproton.finalize_all("hatchet") + assert temp_file.exists() + + +def test_add_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_add_metrics.hatchet" + libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) + libproton.exit_scope(id1, "one") + libproton.finalize_all("hatchet") + assert temp_file.exists() + + +def test_init_function_metadata(tmp_path: pathlib.Path): + metadata_file = tmp_path / "meta.json" + metadata_file.write_text("{}") + libproton.init_function_metadata( + 0, + "dummy_fn", + [(0, "root")], + [], + str(metadata_file), + ) + + +def test_instrumented_op_entry_exit(): + libproton.enter_instrumented_op(0, 0, 0, 0) + libproton.exit_instrumented_op(0, 0, 0, 0) + + +def test_set_metric_kernels(): + libproton.set_metric_kernels(0, 0, 0) + + +def test_tensor_metric_construction(): + metric = libproton.TensorMetric(123, libproton.metric_double_index) + assert metric.ptr == 123 + assert metric.index == libproton.metric_double_index diff --git a/third_party/mthreads/proton/test/test_override.py b/third_party/mthreads/proton/test/test_override.py new file mode 100644 index 0000000000..53272aae0d --- /dev/null +++ b/third_party/mthreads/proton/test/test_override.py @@ -0,0 +1,101 @@ +import os +import subprocess +import pathlib +import json +import pytest + +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2 + +pytestmark = pytest.mark.skipif(is_hip_cdna2(), reason="old AMD GPUs are not supported") + + +def test_override(tmp_path: pathlib.Path): + dir_path = os.path.dirname(os.path.realpath(__file__)) + + # Run once to get the file dumps + first_env = os.environ.copy() + first_env["TRITON_ALWAYS_COMPILE"] = "1" + first_env["TRITON_KERNEL_DUMP"] = "1" + first_env["TRITON_DUMP_DIR"] = str(tmp_path) + + subprocess.run(["python3", dir_path + "/override_helper.py", str(tmp_path)], env=first_env) + + ttir_files = list(tmp_path.rglob("*.ttir")) + ttgir_files = list(tmp_path.rglob("*.ttgir")) + llir_files = list(tmp_path.rglob("*.llir")) + + assert len(ttir_files) == 1 + assert len(ttgir_files) == 1 + assert len(llir_files) == 1 + + os.remove(ttir_files[0]) + os.remove(llir_files[0]) + + if is_cuda(): + ptx_files = list(tmp_path.rglob("*.ptx")) + cubin_files = list(tmp_path.rglob("*.cubin")) + assert len(ptx_files) == 1 + assert len(cubin_files) == 1 + os.remove(ptx_files[0]) + os.remove(cubin_files[0]) + + if is_hip(): + gcn_files = list(tmp_path.rglob("*.amdgcn")) + hsaco_files = list(tmp_path.rglob("*.hsaco")) + assert len(hsaco_files) == 1 + assert len(gcn_files) == 1 + os.remove(gcn_files[0]) + os.remove(hsaco_files[0]) + + filename = str(list(tmp_path.rglob("*.ttgir"))[0]) + + with open(filename, "r") as infile: + file_str = infile.readlines() + + # Add ttgir instrumentation + isFirstLoad = True + with open(filename, "w") as outfile: + for line in file_str: + if "tt.get_program_id x" in line: + #insert before the line + line = ' proton.record start "kernel" loc(#loc)\n' + line + elif "arith.cmpi slt" in line: + #insert after the line + line = line + ' proton.record start "load_ops" loc(#loc)\n' + line = line + ' proton.record start "load_x" loc(#loc)\n' + elif ("tt.load" in line and isFirstLoad) or ("amdg.buffer_load" in line and isFirstLoad): + #insert after the line + line = line + ' proton.record end "load_x" loc(#loc)\n' + line = line + ' proton.record start "load_y" loc(#loc)\n' + isFirstLoad = False + elif ("tt.load" in line and not isFirstLoad) or ("amdg.buffer_load" in line and not isFirstLoad): + #insert after the line + line = line + ' proton.record end "load_y" loc(#loc)\n' + line = line + ' proton.record end "load_ops" loc(#loc)\n' + elif "tt.return" in line: + #insert before the line + line = ' proton.record end "kernel" loc(#loc)\n' + line + outfile.write(line) + + # # Run again with kernel override + second_env = os.environ.copy() + second_env["TRITON_ALWAYS_COMPILE"] = "1" + second_env["TRITON_KERNEL_OVERRIDE"] = "1" + second_env["TRITON_OVERRIDE_DIR"] = str(tmp_path) + subprocess.run(["python3", dir_path + "/override_helper.py", str(tmp_path)], env=second_env) + + temp_file = tmp_path / "test_override.hatchet" + + with open(temp_file, "rb") as f: + data = json.load(f) + kernel_frame = data[0]["children"][0]["children"][0] + load_ops = kernel_frame["children"][0] + assert "load_ops" in load_ops["frame"]["name"] + assert ("load_x" in load_ops["children"][0]["frame"]["name"] + or "load_x" in load_ops["children"][1]["frame"]["name"]) + assert ("load_y" in load_ops["children"][0]["frame"]["name"] + or "load_y" in load_ops["children"][1]["frame"]["name"]) + assert load_ops["children"][0]["metrics"]["cycles"] > 0 + assert load_ops["children"][0]["metrics"]["normalized_cycles"] > 0 + assert load_ops["children"][1]["metrics"]["cycles"] > 0 + assert load_ops["children"][1]["metrics"]["normalized_cycles"] > 0 diff --git a/third_party/mthreads/proton/test/test_profile.py b/third_party/mthreads/proton/test/test_profile.py new file mode 100644 index 0000000000..28a293036d --- /dev/null +++ b/third_party/mthreads/proton/test/test_profile.py @@ -0,0 +1,1187 @@ +""" +Reproducibility tests for Proton. +Each test should invoke one or more GPU kernels and check the validity of their profiling results. +""" + +import torch +import triton +import triton.profiler as proton +import json +import pytest +from typing import NamedTuple +import pathlib +import threading + +import triton.language as tl +from triton.profiler.hooks.launch import COMPUTE_METADATA_SCOPE_NAME +import triton.profiler.hooks.launch as proton_launch +import triton.profiler.viewer as viewer +from triton._internal_testing import is_hip, is_blackwell + + +@pytest.mark.parametrize("context", ["shadow", "python"]) +def test_torch(context, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_torch.hatchet" + proton.start(str(temp_file.with_suffix("")), context=context) + proton.enter_scope("test") + torch.ones((2, 2), device="cuda") + proton.exit_scope() + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + if context == "shadow": + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test" + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + elif context == "python": + assert len(data[0]["children"]) == 1 + # bfs search until find the "elementwise_kernel" and then check its children + queue = [data[0]] + import re + while len(queue) > 0: + parent_frame = queue.pop(0) + for child in parent_frame["children"]: + if "elementwise_kernel" in child["frame"]["name"]: + assert len(child["children"]) == 0 + # check the regex of the parent name matches + # file_name:line_number@function_name + regex = r".+:\d+@.+" + assert re.match(regex, parent_frame["frame"]["name"]) + return + queue.append(child) + + +def test_triton(tmp_path: pathlib.Path): + + @triton.jit + def foo(x, y): + tl.store(y, tl.load(x)) + + x = torch.tensor([2], device="cuda") + y = torch.zeros_like(x) + temp_file = tmp_path / "test_triton.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0"): + with proton.scope("test1"): + foo[(1, )](x, y) + with proton.scope("test2"): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 2 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert len(data[0]["children"][0]["children"]) == 1 + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" + assert data[0]["children"][1]["frame"]["name"] == "test2" + + +@pytest.mark.skipif(is_hip(), reason="HIP backend does not reliably attribute cudagraph replay launches to scopes") +def test_cudagraph(tmp_path: pathlib.Path): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + return {"name": "foo_test"} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def fn(): + a = torch.ones((2, 2), device="cuda") + b = torch.ones((2, 2), device="cuda") + c = a + b + foo[(1, )](a, b, c) + + temp_file = tmp_path / "test_cudagraph.hatchet" + proton.start(str(temp_file.with_suffix("")), context="shadow") + + # warmup + # four kernels + fn() + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(10): + with proton.scope(f"iter_{i}"): + fn() + + with proton.scope("test0"): + g.replay() + + with proton.scope("test1"): + g.replay() + + g.reset() + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + # CUDA/HIP graph may also invoke additional kernels to reset outputs + # {torch.ones, add, foo, test} + assert len(data[0]["children"]) >= 4 + # find the test frame + test0_frame = None + test1_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test0": + test0_frame = child + if child["frame"]["name"] == "test1": + test1_frame = child + assert test0_frame is not None + assert test1_frame is not None + # {torch.ones, add, foo} + if is_hip(): + assert len(test0_frame["children"]) >= 2 + assert test0_frame["children"][0]["metrics"]["time (ns)"] > 0 + else: + # cuda backend supports "" annotation + for test_frame in [test0_frame, test1_frame]: + child = test_frame["children"][0] + assert child["frame"]["name"] == "" + # 0...9 iterations + assert len(child["children"]) == 10 + # check all iterations + for i in range(10): + assert child["children"][i]["frame"]["name"] == f"iter_{i}" + assert child["children"][i]["children"][0]["metrics"]["time (ns)"] > 0 + + +@pytest.mark.skipif(is_hip(), reason="HIP backend does not support cudagraph deactivation") +def test_cudagraph_deactivate(tmp_path): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + + @triton.jit + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def fn(session): + with proton.scope("scope_a"): + a = torch.ones((2, 2), device="cuda") + proton.deactivate(session) + with proton.scope("scope_b"): + b = torch.ones((2, 2), device="cuda") + proton.activate(session) + with proton.scope("scope_c"): + c = a + b + foo[(1, )](a, b, c) + + temp_file = tmp_path / "test_cudagraph_deactivate.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow", hook="triton") + + # warmup + fn(session) + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(10): + with proton.scope(f"iter_{i}"): + fn(session) + + with proton.scope("test0"): + g.replay() + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + # scope a and c should be recorded, b should be skipped + children = data[0]["children"] + test0_frame = None + for child in children: + if child["frame"]["name"] == "test0": + test0_frame = child + break + assert test0_frame is not None + iter_frame = test0_frame["children"][0]["children"][0] + scope_a_frame = None + scope_b_frame = None + scope_c_frame = None + for child in iter_frame["children"]: + if child["frame"]["name"] == "scope_a": + scope_a_frame = child + if child["frame"]["name"] == "scope_b": + scope_b_frame = child + if child["frame"]["name"] == "scope_c": + scope_c_frame = child + assert scope_a_frame is not None + assert scope_b_frame is None + assert scope_c_frame is not None + + +def test_metrics(tmp_path: pathlib.Path): + + @triton.jit + def foo(x, y): + tl.store(y, tl.load(x)) + + x = torch.tensor([2], device="cuda") + y = torch.zeros_like(x) + temp_file = tmp_path / "test_metrics.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["metrics"]["foo"] == 1.0 + + +def test_scope_backward(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_backward.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("ones1"): + a = torch.ones((100, 100), device="cuda", requires_grad=True) + with proton.scope("plus"): + a2 = a * a * a + with proton.scope("ones2"): + loss = torch.ones_like(a2) + + # Backward triggers two kernels in a single scope + with proton.scope("backward"): + a2.backward(loss) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 4 + + +def test_cpu_timed_scope(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_cpu_timed_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.cpu_timed_scope("test0"): + with proton.cpu_timed_scope("test1"): + torch.ones((100, 100), device="cuda") + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + test0_frame = data[0]["children"][0] + assert test0_frame["metrics"]["cpu_time (ns)"] > 0 + test1_frame = test0_frame["children"][0] + assert test1_frame["metrics"]["cpu_time (ns)"] > 0 + + +def test_get_data(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_tree_json.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow") + + @triton.jit + def foo(x, y, size: tl.constexpr): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + with proton.scope("test"): + x = torch.ones((2, 2), device="cuda") + foo[(1, )](x, x, 4) + foo[(1, )](x, x, 4) + + proton.deactivate(session, flushing=True) + + database = proton.data.get(session) + gf, _, _, _ = viewer.get_raw_metrics(database) + foo_frame = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe + ones_frame = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*elementwise.*' AND c IS LEAF").dataframe + + assert len(foo_frame) == 1 + assert int(foo_frame["count"].values[0]) == 2 + assert len(ones_frame) == 1 + assert int(ones_frame["count"].values[0]) == 1 + + import msgpack + msgpack_data = proton.data.get_msgpack(session) + database_unpacked = msgpack.loads(msgpack_data, raw=False, strict_map_key=False) + assert database == database_unpacked + + proton.finalize() + + +def test_clear_data(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_clear_data.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow") + + with proton.scope("test0"): + x = torch.ones((2, 2), device="cuda") + x + x # type: ignore + + proton.deactivate(session, flushing=True) + proton.data.clear(session) + try: + database = proton.data.get(session) + except RuntimeError as e: + assert "has no data" in str(e) + + proton.activate(session) + with proton.scope("test1"): + x * x # type: ignore + proton.deactivate(session, flushing=True) + database = proton.data.get(session) + + proton.finalize() + assert len(database[0]["children"]) == 1 + assert database[0]["children"][0]["frame"]["name"] == "test1" + kernel_frame = database[0]["children"][0]["children"][0] + assert "elementwise" in kernel_frame["frame"]["name"] + + +def test_clear_data_up_to_phase(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_clear_data_up_to_phase.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow") + + with proton.scope("phase0"): + x = torch.ones((2, 2), device="cuda") + x + x # type: ignore + + phase1 = proton.data.advance_phase(session) + with proton.scope("phase1"): + x = torch.ones((2, 2), device="cuda") + x + x # type: ignore + + proton.deactivate(session, flushing=True) + + # Clear a range of phases. + proton.data.clear(session, phase=phase1, clear_up_to_phase=True) + database = proton.data.get(session, phase=phase1) + assert len(database[0]["children"]) == 0 + + proton.finalize() + + +def test_data_is_phase_complete(tmp_path: pathlib.Path): + temp_path = tmp_path / "test_data_is_phase_complete.hatchet" + session = proton.start(str(temp_path.with_suffix("")), context="shadow") + + def fn(): + with proton.scope("test0"): + x = torch.ones((2, 2), device="cuda") + x + x # type: ignore + + fn() + assert not proton.data.is_phase_complete(session, 0) + + proton.deactivate(session) + # likely the GPU has not completed the data yet + assert not proton.data.is_phase_complete(session, 0) + + proton.activate(session) + phase = proton.data.advance_phase(session) + fn() + proton.deactivate(session, flushing=True) + # session 0 is a previous phase but we have called deactivate with flushing + assert proton.data.is_phase_complete(session, 0) + # phase 1 is the current phase so cannot be a completed phase + assert not proton.data.is_phase_complete(session, phase) + proton.data.advance_phase(session) + # phase 0 should remain completed after advancing phases + assert proton.data.is_phase_complete(session, phase - 1) + + proton.finalize() + + +def test_hook_launch(tmp_path: pathlib.Path): + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + # get arg's element size + element_size = args["x"].element_size() # non-const + size = args["size"] # const + key = "flops" + str(element_size * 8) + num_ctas = metadata.num_ctas + # Return an extra metric key beyond the historical flops/bytes allowlist. + return {"name": f"foo_test_{num_ctas}ctas_{size}elems", key: 1.0, "extra_metric": 7.0} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook_triton.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" + assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 + assert data[0]["children"][0]["children"][0]["metrics"]["extra_metric"] == 7.0 + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + + +def test_hook_launch_filter(tmp_path: pathlib.Path): + + foo_metadata_invoked = False + bar_metadata_invoked = False + + def foo_metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + nonlocal foo_metadata_invoked + foo_metadata_invoked = True + return {"name": "foo_meta", "flops": 1.0} + + def bar_metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + nonlocal bar_metadata_invoked + bar_metadata_invoked = True + return {"name": "bar_meta", "flops": 2.0} + + @triton.jit(launch_metadata=foo_metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + @triton.jit(launch_metadata=bar_metadata_fn) + def bar(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook_triton_filter.hatchet" + + # Only allow kernels whose compiled name matches "foo" (via prefix regex). + launch_hook = proton_launch.LaunchHook() + launch_hook.configure(include=".*foo") + proton.start(str(temp_file.with_suffix("")), hook=launch_hook) + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + bar[(1, )](x, 1, y, num_warps=4) + proton.finalize() + # Reset singleton hook state to avoid leaking filter settings across tests. + launch_hook.configure(include=None, exclude=None) + + assert foo_metadata_invoked is True + assert bar_metadata_invoked is False + + with temp_file.open() as f: + data = json.load(f) + + # Ensure the "foo_meta" override exists and "bar_meta" does not. + all_names = set() + queue = [data[0]] + while queue: + node = queue.pop() + if "frame" in node and "name" in node["frame"]: + all_names.add(node["frame"]["name"]) + queue.extend(node.get("children", [])) + + assert "foo_meta" in all_names + assert "bar_meta" not in all_names + + +@pytest.mark.parametrize("context", ["shadow", "python"]) +def test_hook_launch_context(tmp_path: pathlib.Path, context: str): + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + x = args["x"] + # A gpu kernel, but it should be under the metadata state + return {"name": "foo_test", "bytes": x.sum().item()} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", context=context) + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + # bfs search until find the reduce kernel and then check its parent + queue = [data[0]] + while len(queue) > 0: + parent_frame = queue.pop(0) + for child in parent_frame["children"]: + if "reduce" in child["frame"]["name"]: + assert parent_frame["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME + return + queue.append(child) + + +def test_hook_with_third_party(tmp_path: pathlib.Path): + third_party_hook_invoked = False + + def third_party_hook(metadata) -> None: + nonlocal third_party_hook_invoked + third_party_hook_invoked = True + + triton.knobs.runtime.launch_enter_hook.add(third_party_hook) + + proton_hook_invoked = False + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + nonlocal proton_hook_invoked + proton_hook_invoked = True + return {"name": "foo_test"} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook_with_third_party.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + triton.knobs.runtime.launch_enter_hook.remove(third_party_hook) + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "foo_test" + assert data[0]["children"][0]["metrics"]["time (ns)"] > 0 + + +def test_hook_multiple_threads(tmp_path: pathlib.Path): + + def metadata_fn_foo(grid: tuple, metadata: NamedTuple, args: dict): + return {"name": "foo_test"} + + @triton.jit(launch_metadata=metadata_fn_foo) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + def metadata_fn_bar(grid: tuple, metadata: NamedTuple, args: dict): + return {"name": "bar_test"} + + @triton.jit(launch_metadata=metadata_fn_bar) + def bar(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x_foo = torch.tensor([2], device="cuda", dtype=torch.float32) + y_foo = torch.zeros_like(x_foo) + x_bar = torch.tensor([2], device="cuda", dtype=torch.float32) + y_bar = torch.zeros_like(x_bar) + + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + + all_ids = set() + + # start multiple threads + def invoke_foo(): + for _ in range(100): + foo[(1, )](x_foo, 1, y_foo, num_warps=4) + all_ids.add(proton_launch.id.get()) + + def invoke_bar(): + for _ in range(100): + bar[(1, )](x_bar, 1, y_bar, num_warps=4) + all_ids.add(proton_launch.id.get()) + + thread_foo = threading.Thread(target=invoke_foo) + thread_bar = threading.Thread(target=invoke_bar) + thread_foo.start() + thread_bar.start() + thread_foo.join() + thread_bar.join() + + proton.finalize() + assert len(all_ids) == 200 + + with temp_file.open() as f: + data = json.load(f) + root = data[0]["children"] + assert "foo_test" in root[0]["frame"]["name"] or root[1]["frame"]["name"] + assert "bar_test" in root[0]["frame"]["name"] or root[1]["frame"]["name"] + assert root[0]["metrics"]["count"] == 100 + assert root[1]["metrics"]["count"] == 100 + + +def test_pcsampling(tmp_path: pathlib.Path): + if is_hip(): + pytest.skip("HIP backend does not support pc sampling") + + import os + + if os.environ.get("PROTON_SKIP_PC_SAMPLING_TEST", "0") == "1": + pytest.skip("PC sampling test is disabled") + + @triton.jit + def foo(x, y, size: tl.constexpr): + offs = tl.arange(0, size) + for _ in range(1000): + tl.store(y + offs, tl.load(x + offs)) + + temp_file = tmp_path / "test_pcsampling.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti", mode="pcsampling") + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + init_frame = data[0]["children"][0] + test_frame = data[0]["children"][1] + # With line mapping + assert "foo" in test_frame["children"][0]["frame"]["name"] + assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 + assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] + # Without line mapping + assert "elementwise" in init_frame["children"][0]["frame"]["name"] + assert init_frame["children"][0]["metrics"]["num_samples"] > 0 + + +def test_deactivate(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_deactivate.hatchet" + session_id = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.deactivate(session_id) + torch.randn((10, 10), device="cuda") + proton.activate(session_id) + torch.zeros((10, 10), device="cuda") + proton.deactivate(session_id) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + # Root shouldn't have device id + assert "device_id" not in data[0]["metrics"] + assert len(data[0]["children"]) == 1 + assert "device_id" in data[0]["children"][0]["metrics"] + + +def test_multiple_sessions(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_multiple_sessions0.hatchet" + temp_file1 = tmp_path / "test_multiple_sessions1.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + with proton.scope("scope0"): + torch.randn((10, 10), device="cuda") + torch.randn((10, 10), device="cuda") + proton.deactivate(session_id0) + proton.finalize(session_id0) + with proton.scope("scope1"): + torch.randn((10, 10), device="cuda") + proton.finalize(session_id1) + # kernel has been invoked twice in session 0 and three times in session 1 + with temp_file0.open() as f: + data = json.load(f) + assert data[0]["children"][0]["frame"]["name"] == "scope0" + assert int(data[0]["children"][0]["children"][0]["metrics"]["count"]) == 2 + with temp_file1.open() as f: + data = json.load(f) + scope0_count = int(data[0]["children"][0]["children"][0]["metrics"]["count"]) + scope1_count = int(data[0]["children"][1]["children"][0]["metrics"]["count"]) + assert scope0_count + scope1_count == 3 + + +def test_trace(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_trace.chrome_trace" + proton.start(str(temp_file.with_suffix("")), data="trace") + + @triton.jit + def foo(x, y, size: tl.constexpr): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + trace_events = data["traceEvents"] + assert len(trace_events) == 3 + assert trace_events[-1]["name"] == "foo" + assert trace_events[-1]["args"]["call_stack"] == ["ROOT", "test", "foo"] + + +def test_scope_multiple_threads(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_threads.hatchet" + proton.start(str(temp_file.with_suffix(""))) + + N = 50 + thread_names = ["threadA", "threadB"] + + def worker(prefix: str): + for i in range(N): + name = f"{prefix}_{i}" + proton.enter_scope(name) + torch.ones((1, ), device="cuda") + proton.exit_scope() + + threads = [threading.Thread(target=worker, args=(tname, )) for tname in thread_names] + for t in threads: + t.start() + for t in threads: + t.join() + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + children = data[0]["children"] + assert len(children) == N * len(thread_names) + names = {c["frame"]["name"] for c in children} + expected = {f"{t}_{i}" for t in thread_names for i in range(N)} + assert names == expected + + +@pytest.mark.parametrize("enable_nvtx", [None, True, False]) +def test_nvtx_range_push_pop(enable_nvtx, fresh_knobs, tmp_path: pathlib.Path): + if enable_nvtx is not None: + fresh_knobs.proton.enable_nvtx = enable_nvtx + temp_file = tmp_path / "test_nvtx_range_push_pop.hatchet" + proton.start(str(temp_file.with_suffix(""))) + + with proton.scope("proton_scope"): + torch.cuda.nvtx.range_push("nvtx_range0") + torch.cuda.nvtx.range_push("nvtx_range1") + torch.ones((1, ), device="cuda") + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + children = data[0]["children"] + assert len(children) == 1 + proton_scope = children[0] + assert proton_scope["frame"]["name"] == "proton_scope" + assert len(proton_scope["children"]) == 1 + if enable_nvtx or enable_nvtx is None: + nvtx_range0 = proton_scope["children"][0] + assert nvtx_range0["frame"]["name"] == "nvtx_range0" + assert len(nvtx_range0["children"]) == 1 + nvtx_range1 = nvtx_range0["children"][0] + assert nvtx_range1["frame"]["name"] == "nvtx_range1" + assert len(nvtx_range1["children"]) == 1 + kernel = nvtx_range1["children"][0] + else: + kernel = proton_scope["children"][0] + assert "elementwise" in kernel["frame"]["name"] + assert kernel["metrics"]["count"] == 1 + + +def test_tensor_metrics_scope(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_tensor_metrics_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + + x = torch.ones((10, 10), device="cuda", dtype=torch.float32) + x_mean = x.mean() + x_std = x.std() + with proton.scope("test", metrics={"x_mean": x_mean, "x_std": x_std}): + torch.randn((10, 10), device="cuda") + torch.zeros_like(x) + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + children = data[0]["children"] + assert len(children) == 4 + # get the test frame + test_frame = None + for child in children: + if child["frame"]["name"] == "test": + test_frame = child + break + assert test_frame is not None + assert test_frame["metrics"]["x_mean"] == 1.0 + assert test_frame["metrics"]["x_std"] == 0.0 + + +def test_tensor_metrics_hook(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_tensor_metrics_hook.hatchet" + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + metric_value = torch.tensor(8.0, device="cuda") + return {"name": "foo_test", "flops": metric_value} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.ones((8, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + + proton.start(str(temp_file.with_suffix("")), hook="triton") + foo[(1, )](x, x.numel(), y, num_warps=4) + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + children = data[0]["children"] + # metadata scope + foo_test + assert len(children) == 2 + foo_test_frame = None + for child in children: + if child["frame"]["name"] == "foo_test": + foo_test_frame = child + break + assert foo_test_frame is not None + assert foo_test_frame["metrics"]["flops"] == 8.0 + + +@pytest.mark.skipif(is_hip(), reason="HIP backend does not support metrics profiling in cudagraphs") +def test_tensor_metrics_cudagraph(tmp_path: pathlib.Path): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + x = args["x"] + x_sum = x.sum() + return {"name": "foo_test", "bytes": x.numel() * x.element_size(), "flops": x_sum} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def fn(): + with proton.scope("scope_a", metrics={"bytes": 4 * 4}): + a = torch.ones((2, 2), device="cuda") + with proton.metadata_state(): + a_sum = a.sum() + with proton.scope("scope_b", metrics={"sum": a_sum}): + b = torch.ones((2, 2), device="cuda") + c = a + b + foo[(1, )](a, b, c) + + temp_file = tmp_path / "test_tensor_metrics_cudagraph.hatchet" + proton.start(str(temp_file.with_suffix("")), context="shadow", hook="triton") + + # warmup + # four kernels + fn() + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn() + + with proton.scope("test0"): + g.replay() + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + children = data[0]["children"] + # metadata scope + kernels + scope_a + scope_b + test0 + assert len(children) == 7 + test0_frame = None + for child in children: + if child["frame"]["name"] == "test0": + test0_frame = child + break + assert test0_frame is not None + capture_at_frame = test0_frame["children"][0] + + foo_test_frame = None + scope_a_frame = None + scope_b_frame = None + for child in capture_at_frame["children"]: + if child["frame"]["name"] == "foo_test": + foo_test_frame = child + if child["frame"]["name"] == "scope_a": + scope_a_frame = child + if child["frame"]["name"] == "scope_b": + scope_b_frame = child + assert foo_test_frame is not None + assert foo_test_frame["metrics"]["bytes"] == 160 + assert foo_test_frame["metrics"]["flops"] == 40 + assert foo_test_frame["metrics"]["count"] == 10 + assert scope_a_frame is not None + assert scope_a_frame["metrics"]["bytes"] == 160 + assert "count" not in scope_a_frame["metrics"] + assert scope_b_frame is not None + assert scope_b_frame["metrics"]["sum"] == 40.0 + assert "count" not in scope_b_frame["metrics"] + + +@pytest.mark.skipif(is_hip(), reason="HIP backend does not support metrics profiling in cudagraphs") +def test_tensor_metrics_cudagraph_deactivate(tmp_path: pathlib.Path): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + + def fn(session): + proton.deactivate(session) + with proton.scope("scope_b", metrics={"sum": 4}): + b = torch.ones((2, 2), device="cuda") + proton.activate(session) + c = b * 2 # noqa: F841 + + temp_file = tmp_path / "test_tensor_metrics_cudagraph_deactivate.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow", hook="triton") + + # warmup + fn(session) + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn(session) + + with proton.scope("test0"): + g.replay() + + proton.finalize() + + # only a single kernel b * 2 + with temp_file.open() as f: + data = json.load(f) + children = data[0]["children"] + test0_frame = None + for child in children: + if child["frame"]["name"] == "test0": + test0_frame = child + break + assert test0_frame is not None + capture_at_frame = test0_frame["children"][0] + scope_b_frame = None + c_frame = None + for child in capture_at_frame["children"]: + if child["frame"]["name"] == "scope_b": + scope_b_frame = child + if "elementwise" in child["frame"]["name"]: + c_frame = child + assert scope_b_frame is None + assert c_frame is not None + assert c_frame["metrics"]["count"] == 10 + + +@pytest.mark.skipif(is_hip(), reason="HIP backend does not support metrics profiling in cudagraphs") +def test_tensor_metrics_multi_device_cudagraph(tmp_path: pathlib.Path): + if torch.cuda.device_count() < 2: + pytest.skip("Requires at least two CUDA devices") + + devices = [torch.device(f"cuda:{i}") for i in range(2)] + streams = [] + for device in devices: + with torch.cuda.device(device): + streams.append(torch.cuda.Stream(device=device)) + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + x = args["x"] + x_sum = x.sum() + device_idx = x.device.index + return {"name": f"foo_test_{device_idx}", "bytes": x.numel() * x.element_size(), "flops": x_sum} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def run_on_device(device_id): + with proton.scope(f"scope_a_{device_id}", metrics={"bytes": 4 * 4}): + a = torch.ones((2, 2), device=f"cuda:{device_id}") + with proton.metadata_state(): + a_sum = a.sum() + with proton.scope(f"scope_b_{device_id}", metrics={"sum": a_sum}): + b = torch.ones((2, 2), device=f"cuda:{device_id}") + c = a + b + foo[(1, )](a, b, c) + + temp_file = tmp_path / "test_tensor_metrics_multi_device_cudagraph.hatchet" + proton.start(str(temp_file.with_suffix("")), context="shadow", hook="triton") + + graphs = [] + for device, stream in zip(devices, streams): + with torch.cuda.device(device): + torch.cuda.set_stream(stream) + # warmup + run_on_device(device.index) + # graph capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=stream): + for _ in range(10): + run_on_device(device.index) + graphs.append((device, stream, g)) + + for device, stream, graph in graphs: + with torch.cuda.device(device): + torch.cuda.set_stream(stream) + with proton.scope(f"test_device_{device.index}"): + graph.replay() + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + + children = data[0]["children"] + for device in devices: + device_name = f"test_device_{device.index}" + launch_frame = next((child for child in children if child["frame"]["name"] == device_name), None) + assert launch_frame is not None + capture_at_frame = launch_frame["children"][0] + assert capture_at_frame["frame"]["name"] == "" + + foo_frame = None + scope_a_frame = None + scope_b_frame = None + for child in capture_at_frame["children"]: + if child["frame"]["name"] == f"foo_test_{device.index}": + foo_frame = child + if child["frame"]["name"] == f"scope_a_{device.index}": + scope_a_frame = child + if child["frame"]["name"] == f"scope_b_{device.index}": + scope_b_frame = child + + assert foo_frame is not None + assert scope_a_frame is not None + assert scope_b_frame is not None + assert foo_frame["metrics"]["bytes"] == 160 + assert foo_frame["metrics"]["flops"] == 40 + assert foo_frame["metrics"]["device_id"] == str(device.index) + assert scope_a_frame["metrics"]["bytes"] == 160 + assert scope_b_frame["metrics"]["sum"] == 40.0 + + assert len(data) > 1 + cuda_devices = data[1].get("CUDA", {}) + assert len(cuda_devices) >= 2 + + +@pytest.mark.parametrize("buffer_size", [256 * 1024, 64 * 1024 * 1024]) +@pytest.mark.parametrize("data_format", ["hatchet_msgpack", "hatchet"]) +def test_periodic_flushing(tmp_path, fresh_knobs, data_format, buffer_size): + fresh_knobs.proton.profile_buffer_size = buffer_size + temp_file = tmp_path / f"test_periodic_flushing.{data_format}" + session = proton.start(str(temp_file.with_suffix("")), mode=f"periodic_flushing:format={data_format}") + + for i in range(10000): + if i != 0 and i % 1000 == 0: + proton.data.advance_phase(session=session) + with proton.scope(f"test_{i}", metrics={"count": 1}): + torch.zeros((100), device="cuda") + + proton.finalize(output_format=data_format) + + # Find all *.hatchet files under the directory `tmp_path` + import glob + import msgpack + hatchet_files = glob.glob(str(tmp_path / f"*.{data_format}")) + assert len(hatchet_files) == 10 + num_scopes = 0 + for hatchet_file in hatchet_files: + if data_format == "hatchet_msgpack": + with open(hatchet_file, "rb") as f: + data = msgpack.load(f, raw=False, strict_map_key=False) + else: + with open(hatchet_file, "r", encoding="utf-8") as f: + data = json.load(f) + assert len(data[0]["children"]) == 1000 + assert data[0]["children"][0]["metrics"]["count"] == 1 + assert data[0]["children"][0]["frame"]["name"].startswith("test_") + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + num_scopes += len(data[0]["children"]) + assert num_scopes == 10000 + + +@pytest.mark.skipif(is_hip(), reason="HIP backend does not support metrics profiling in cudagraphs") +@pytest.mark.parametrize("buffer_size", [256 * 1024, 64 * 1024 * 1024]) +@pytest.mark.parametrize("data_format", ["hatchet_msgpack", "hatchet"]) +def test_periodic_flushing_cudagraph(tmp_path, fresh_knobs, data_format, buffer_size): + fresh_knobs.proton.profile_buffer_size = buffer_size + temp_file = tmp_path / f"test_periodic_flushing.{data_format}" + session = proton.start(str(temp_file.with_suffix("")), mode=f"periodic_flushing:format={data_format}", + hook="triton") + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + x = args["x"] + x_sum = x.sum() + return {"name": "foo_test", "bytes": x.numel() * x.element_size(), "flops": x_sum} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def fn(): + with proton.scope("scope_a", metrics={"bytes": 4 * 4}): + a = torch.ones((2, 2), device="cuda") + c = a + a + foo[(1, )](a, a, c) + + # warmup + fn() + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + + with proton.scope("test0"): + for i in range(10000): + if i != 0 and i % 1000 == 0: + proton.data.advance_phase(session=session) + g.replay() + + proton.finalize(output_format=data_format) + + # Find all *.hatchet files under the directory `tmp_path` + import glob + import msgpack + hatchet_files = glob.glob(str(tmp_path / f"*.{data_format}")) + assert len(hatchet_files) == 10 + for hatchet_file in hatchet_files: + if data_format == "hatchet_msgpack": + with open(hatchet_file, "rb") as f: + data = msgpack.load(f, raw=False, strict_map_key=False) + else: + with open(hatchet_file, "r", encoding="utf-8") as f: + data = json.load(f) + capture_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test0": + capture_frame = child["children"][0] + break + assert capture_frame is not None + scope_a_frame = None + foo_test_frame = None + for child in capture_frame["children"]: + if child["frame"]["name"] == "scope_a": + scope_a_frame = child + if child["frame"]["name"] == "foo_test": + foo_test_frame = child + assert scope_a_frame is not None + assert foo_test_frame is not None + assert scope_a_frame["metrics"]["bytes"] == 16000 + assert foo_test_frame["metrics"]["bytes"] == 16000 + assert foo_test_frame["metrics"]["flops"] == 4000 + + +@pytest.mark.skipif(not is_blackwell(), reason="HW trace is only supported on Blackwell GPUs") +def test_hw_trace(fresh_knobs, tmp_path: pathlib.Path): + fresh_knobs.proton.enable_hw_trace = True + temp_file = tmp_path / "test_hw_trace.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) # noqa: F841 + + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + kernel_frame = data[0]["children"][0]["children"][0] + assert "elementwise" in kernel_frame["frame"]["name"] + assert kernel_frame["metrics"]["time (ns)"] > 0 diff --git a/third_party/mthreads/proton/test/test_viewer.py b/third_party/mthreads/proton/test/test_viewer.py new file mode 100644 index 0000000000..64c8524175 --- /dev/null +++ b/third_party/mthreads/proton/test/test_viewer.py @@ -0,0 +1,199 @@ +import pytest +import subprocess +from triton.profiler.viewer import get_min_time_flops, get_min_time_bytes, read, format_frames, derive_metrics, filter_frames, parse, apply_diff_profile +from triton.profiler.hooks.launch import COMPUTE_METADATA_SCOPE_NAME +import numpy as np + +file_path = __file__ +triton_example_file = file_path.replace("test_viewer.py", "examples/triton.json") +cuda_example_file = file_path.replace("test_viewer.py", "examples/cuda.json") +hip_example_file = file_path.replace("test_viewer.py", "examples/hip.json") +frame_example_file = file_path.replace("test_viewer.py", "examples/frame.json") +leaf_example_file = file_path.replace("test_viewer.py", "examples/leaf_nodes.json") + + +def test_help(): + # Only check if the viewer can be invoked + subprocess.check_call(["proton-viewer", "-h"], stdout=subprocess.DEVNULL) + + +def test_exclusive_metrics(): + gf, inclusive_metrics, exclusive_metrics, device_info = read(triton_example_file) + metrics = ["cpu_time/ns"] + metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + gf = filter_frames(gf, None, None, None, metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + actual = sorted_df.iloc[0:1]["name"].values[0] + assert actual == "scope" + + +def test_sort(): + gf, inclusive_metrics, exclusive_metrics, device_info = read(leaf_example_file) + gf = format_frames(gf, None) + metrics = ["time/s", "time/ms", "time/us", "time/ns"] + metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + gf = filter_frames(gf, None, None, None, metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + actual = sorted_df.iloc[0:5]["name"].values + expected = ["ROOT", "kernel_1_1_1", "kernel_3_1_1", "kernel_3_2_2", "kernel_1_2_2"] + assert len(actual) == len(expected) + assert all(a == b for a, b in zip(actual, expected)) + + +@pytest.mark.parametrize("option", ["full", "file_function_line", "function_line", "file_function"]) +def test_format_frames(option): + gf, _, _, _ = read(frame_example_file) + gf = format_frames(gf, option) + if option == "full": + idx = gf.dataframe["name"] == "/home/user/projects/example.py/test.py:1@foo" + elif option == "file_function_line": + idx = gf.dataframe["name"] == "test.py:1@foo" + elif option == "function_line": + idx = gf.dataframe["name"] == "1@foo" + elif option == "file_function": + idx = gf.dataframe["name"] == "test.py@foo" + assert idx.sum() == 1 + + +@pytest.mark.parametrize("option", ["include", "exclude"]) +def test_filter_frames(option): + include = "" + exclude = "" + gf, _, _, _ = read(frame_example_file) + if option == "include": + include = ".*test0.*" + elif option == "exclude": + exclude = ".*test1.*" + gf = filter_frames(gf, include=include, exclude=exclude) + idx = gf.dataframe["name"] == "test1" + assert idx.sum() == 0 + idx = gf.dataframe["name"] == "test0" + assert idx.sum() == 1 + + +def test_filter_metadata(): + gf, _, _, _ = read(triton_example_file) + assert COMPUTE_METADATA_SCOPE_NAME not in gf.dataframe["name"].tolist() + assert "cuda_kernel" not in gf.dataframe["name"].tolist() + assert "scope" in gf.dataframe["name"].tolist() + assert "triton_kernel" in gf.dataframe["name"].tolist() + + +def test_parse(): + gf, derived_metrics = parse(["time/s"], triton_example_file) + for derived_metric in derived_metrics: + assert derived_metric in gf.inc_metrics or derived_metric in gf.exc_metrics + + +def test_min_time_flops(): + gf, _, _, device_info = read(cuda_example_file) + ret = get_min_time_flops(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + device2_idx = gf.dataframe["device_id"] == "2" + # sm89 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000025]], atol=1e-5) + # sm90 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.00005]], atol=1e-5) + # sm100 + np.testing.assert_allclose(ret[device2_idx].to_numpy(), [[0.000025]], atol=1e-5) + gf, _, _, device_info = read(hip_example_file) + ret = get_min_time_flops(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + device2_idx = gf.dataframe["device_id"] == "2" + # CDNA2 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000055]], atol=1e-5) + # CDNA3 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.000038]], atol=1e-5) + # CDNA4 + np.testing.assert_allclose(ret[device2_idx].to_numpy(), [[0.000217]], atol=1e-5) + + +def test_min_time_bytes(): + gf, _, _, device_info = read(cuda_example_file) + ret = get_min_time_bytes(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + # sm89 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[9.91969e-06]], atol=1e-6) + # sm90 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[2.48584e-05]], atol=1e-6) + gf, _, _, device_info = read(hip_example_file) + ret = get_min_time_bytes(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + device2_idx = gf.dataframe["device_id"] == "2" + # CDNA2 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[3.125e-06]], atol=1e-6) + # CDNA3 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[1.93378e-05]], atol=1e-6) + # CDNA4 + np.testing.assert_allclose(ret[device2_idx].to_numpy(), [[0.000125]], atol=1e-6) + + +def test_percentage(): + pass + + +def derivation_metrics_test(metrics, expected_data, sample_file, rtol=1e-7, atol=1e-6): + gf, inclusive_metrics, exclusive_metrics, device_info = read(sample_file) + assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file" + derived_metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + for derived_metric in derived_metrics: + np.testing.assert_allclose(gf.dataframe[derived_metric].to_numpy(), expected_data[derived_metric], rtol=rtol, + atol=atol) + + +def test_avg_time_derivation(): + derivation_metrics_test( + metrics=["avg_time/s", "avg_time/ms", "avg_time/us", "avg_time/ns"], expected_data={ + "avg_time/s (inc)": [0.0000512, 0.0000205, 0.000205, + 0.000205], "avg_time/ms (inc)": [0.0512, 0.02048, 0.2048, 0.2048], "avg_time/us (inc)": + [51.2, 20.48, 204.8, 204.8], "avg_time/ns (inc)": [51200.0, 20480.0, 204800.0, 204800.0] + }, sample_file=cuda_example_file) + + +def test_util(): + derivation_metrics_test(metrics=["util"], expected_data={ + "util": [np.nan, 0.247044, 0.147830, 0.118451], + }, sample_file=cuda_example_file) + + +def test_time_derivation(): + derivation_metrics_test( + metrics=["time/s", "time/ms", "time/us", "time/ns"], expected_data={ + "time/s (inc)": [0.000614, 0.0002048, 0.0002048, 0.0002048], + "time/ms (inc)": [0.6144, 0.2048, 0.2048, 0.2048], + "time/us (inc)": [614.4, 204.8, 204.8, 204.8], + "time/ns (inc)": [614400.0, 204800.0, 204800.0, 204800.0], + "time/% (inc)": [100.0, 50.0, 50.0, 50.0], + }, sample_file=cuda_example_file) + + +def test_bytes_derivation(): + derivation_metrics_test( + metrics=["byte/s", "gbyte/s", "tbyte/s"], expected_data={ + "byte/s (inc)": [1.953125e+11, 4.88281250e+11, 4.88281250e+10, + 4.88281250e+10], "gbyte/s (inc)": [195.3125, 488.28125, 48.828125, 48.828125], + "tbyte/s (inc)": [0.195312, 0.48828125, 0.04882812, 0.04882812] + }, sample_file=cuda_example_file) + + +def test_flops_derivation(): + derivation_metrics_test( + metrics=["flop8/s", "gflop8/s", "tflop8/s"], + expected_data={ + "flop8/s (inc)": [3.417969e+14, 4.88281250e+14, 4.88281250e+13, + 4.88281250e+14], "gflop8/s (inc)": [341796.875, 488281.25, 48828.125, 488281.25], + "tflop8/s (inc)": [341.796875, 488.28125, 48.828125, 488.28125] + }, + sample_file=cuda_example_file, + ) + + +def test_diff_profile(): + gf, derived_metrics = parse(["time/s"], triton_example_file) + gf2, _ = parse(["time/s"], cuda_example_file) + gf = apply_diff_profile(gf, derived_metrics, cuda_example_file, ["time/s"], None, None, 0.0) + assert "time/s (inc)" in gf.dataframe.columns diff --git a/third_party/mthreads/proton/test/unittest/CMakeLists.txt b/third_party/mthreads/proton/test/unittest/CMakeLists.txt new file mode 100644 index 0000000000..5646a3832a --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TraceDataIO) diff --git a/third_party/mthreads/proton/test/unittest/TraceDataIO/ByteSpanTest.cpp b/third_party/mthreads/proton/test/unittest/TraceDataIO/ByteSpanTest.cpp new file mode 100644 index 0000000000..49f03d58c9 --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/TraceDataIO/ByteSpanTest.cpp @@ -0,0 +1,76 @@ +#include "TraceDataIO/ByteSpan.h" +#include +#include +#include + +using namespace proton; + +TEST(ByteSpanTest, ReadAndNavigation) { + std::vector testData = { + // int8 values (positions 0-3) + 0x00, // 0 + 0x7F, // 127 + 0x80, // -128 + 0xFF, // -1 + + // int16 values (positions 4-7) + 0x34, 0x12, // 0x1234 + 0x00, 0x80, // 0x8000 + + // int32 values (positions 8-15) + 0x78, 0x56, 0x34, 0x12, // 0x12345678 + 0x00, 0x00, 0x00, 0x80 // 0x80000000 + }; + + ByteSpan span(testData.data(), testData.size()); + + // Test initial state + EXPECT_EQ(span.position(), 0); + EXPECT_EQ(span.size(), 16); + EXPECT_EQ(span.remaining(), 16); + EXPECT_TRUE(span.hasRemaining(16)); + EXPECT_FALSE(span.hasRemaining(17)); + + // Test 8-bit reading + EXPECT_EQ(span.readInt8(), 0); + EXPECT_EQ(span.readInt8(), 127); + EXPECT_EQ(span.readUInt8(), 128); + EXPECT_EQ(span.readUInt8(), 255); + EXPECT_EQ(span.position(), 4); + + // Test navigation - seeking back + span.seek(1); + EXPECT_EQ(span.position(), 1); + EXPECT_EQ(span.readInt8(), 127); + EXPECT_EQ(span.position(), 2); + + // Test navigation - skipping + span.skip(2); + EXPECT_EQ(span.position(), 4); + + // Test 16-bit reading + EXPECT_EQ(span.readUInt16(), 0x1234); // 0x1234 + EXPECT_EQ(span.readInt16(), -32768); // 0x8000 + EXPECT_EQ(span.position(), 8); + + // Test navigation - seeking to specific position + span.seek(8); + + // Test 32-bit reading + EXPECT_EQ(span.readUInt32(), 305419896); // 0x12345678 + EXPECT_EQ(span.readInt32(), -2147483648); // 0x80000000 + EXPECT_EQ(span.position(), 16); + + // Test navigation - buffer overflow + EXPECT_THROW(span.skip(1), BufferException); + EXPECT_THROW(span.seek(17), BufferException); + + // Test navigation - at the end + EXPECT_EQ(span.remaining(), 0); + EXPECT_FALSE(span.hasRemaining(1)); +} + +int main(int argc, char *argv[]) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/mthreads/proton/test/unittest/TraceDataIO/CMakeLists.txt b/third_party/mthreads/proton/test/unittest/TraceDataIO/CMakeLists.txt new file mode 100644 index 0000000000..76a5d325c0 --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/TraceDataIO/CMakeLists.txt @@ -0,0 +1,15 @@ +set(PROTON_TEST_UTIL_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../util/") +add_compile_definitions(PROTON_TEST_UTIL_PATH="${PROTON_TEST_UTIL_PATH}") + +add_triton_ut( + NAME TraceDataIO + SRCS ByteSpanTest.cpp DecoderTest.cpp CircularLayoutParserTest.cpp ChromeTraceWriterTest.cpp + LIBS ProtonTraceDataIO +) + +target_include_directories(TraceDataIO +PRIVATE + "${JSON_INCLUDE_DIR}" + "${PROTON_COMMON_DIR}/include" + "${PROTON_SRC_DIR}/include" +) diff --git a/third_party/mthreads/proton/test/unittest/TraceDataIO/ChromeTraceWriterTest.cpp b/third_party/mthreads/proton/test/unittest/TraceDataIO/ChromeTraceWriterTest.cpp new file mode 100644 index 0000000000..f6adab1005 --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/TraceDataIO/ChromeTraceWriterTest.cpp @@ -0,0 +1,211 @@ +#include "TraceDataIO/EntryDecoder.h" +#include "TraceDataIO/TraceWriter.h" +#include "nlohmann/json.hpp" +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::json; +using namespace proton; + +class ChromeTraceWriterTest : public ::testing::Test { +public: + void SetUp() override {} + + void TearDown() override { + try { + std::filesystem::remove_all(chromeTracePath); + } catch (const std::filesystem::filesystem_error &e) { + std::cerr << "Error cleaning up test trace files: " << e.what() + << std::endl; + } + } + + void printJsonTrace(json data) { std::cout << data.dump(4) << std::endl; } + + json readJsonTrace(const std::string &path) { + std::ifstream file(path); + + if (!file.is_open()) { + std::cerr << "Failed to open chrome trace file!" << std::endl; + return json(); + } + + json data; + try { + data = json::parse(file); + } catch (json::parse_error &e) { + std::cerr << "Error parsing JSON: " << e.what() << std::endl; + data = json(); + } + file.close(); + return data; + } + + std::shared_ptr + createDefaultResult(int numBlocks, int numTraces, int numEvents) { + auto result = std::make_shared(); + result->blockTraces.resize(numBlocks); + for (int i = 0; i < numBlocks; i++) { + result->blockTraces[i].traces.resize(numTraces); + for (int j = 0; j < numTraces; j++) { + result->blockTraces[i].traces[j].profileEvents.resize(numEvents); + for (int k = 0; k < numEvents; k++) { + result->blockTraces[i].traces[j].profileEvents[k].first = + std::make_shared(); + result->blockTraces[i].traces[j].profileEvents[k].second = + std::make_shared(); + } + } + } + return result; + } + +protected: + std::string chromeTracePath = "chrome_trace.json"; +}; + +TEST_F(ChromeTraceWriterTest, SingleBlock) { + auto metadata = std::make_shared(); + metadata->kernelName = "kernel1"; + metadata->scopeName = {{1, "s1"}, {2, "s2"}}; + + auto result = createDefaultResult(1, 1, metadata->scopeName.size()); + result->blockTraces[0].blockId = 1; + result->blockTraces[0].procId = 120; + result->blockTraces[0].initTime = 0; + result->blockTraces[0].traces[0].uid = 2; + result->blockTraces[0].traces[0].profileEvents[0].first->cycle = 122; + result->blockTraces[0].traces[0].profileEvents[0].second->cycle = 162; + result->blockTraces[0].traces[0].profileEvents[0].first->scopeId = 1; + result->blockTraces[0].traces[0].profileEvents[0].second->scopeId = 1; + result->blockTraces[0].traces[0].profileEvents[1].first->cycle = 222; + result->blockTraces[0].traces[0].profileEvents[1].second->cycle = 262; + result->blockTraces[0].traces[0].profileEvents[1].first->scopeId = 7; + result->blockTraces[0].traces[0].profileEvents[1].second->scopeId = 7; + std::vector kerneltrace = {std::make_pair(result, metadata)}; + auto writer = StreamChromeTraceWriter(kerneltrace, chromeTracePath); + writer.dump(); + + auto data = readJsonTrace(chromeTracePath); + EXPECT_EQ(data.empty(), false); + EXPECT_EQ(data["displayTimeUnit"], "ns"); + EXPECT_EQ(data["traceEvents"].size(), 2); + EXPECT_EQ(data["traceEvents"][0]["name"], "s1"); + EXPECT_EQ(data["traceEvents"][1]["name"], "scope_7"); + EXPECT_DOUBLE_EQ(data["traceEvents"][0]["ts"], 0.0); + EXPECT_DOUBLE_EQ(data["traceEvents"][1]["ts"], 0.1); +} + +TEST_F(ChromeTraceWriterTest, MultiBlockMultiWarp) { + auto metadata = std::make_shared(); + metadata->kernelName = "kernel2"; + metadata->scopeName = {{1, "s1"}, {2, "s2"}, {3, "s3"}, {4, "s4"}}; + + auto result = createDefaultResult(2, 3, metadata->scopeName.size()); + + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 2; j++) { + result->blockTraces[j].blockId = 1 + j; + result->blockTraces[j].procId = 120 + j; + result->blockTraces[j].traces[i].uid = i; + result->blockTraces[j].traces[i].profileEvents[0].first->cycle = 122; + result->blockTraces[j].traces[i].profileEvents[0].second->cycle = 162; + result->blockTraces[j].traces[i].profileEvents[0].first->scopeId = 1; + result->blockTraces[j].traces[i].profileEvents[0].second->scopeId = 1; + result->blockTraces[j].traces[i].profileEvents[1].first->cycle = 142; + result->blockTraces[j].traces[i].profileEvents[1].second->cycle = 182; + result->blockTraces[j].traces[i].profileEvents[1].first->scopeId = 2; + result->blockTraces[j].traces[i].profileEvents[1].second->scopeId = 2; + result->blockTraces[j].traces[i].profileEvents[2].first->cycle = 172; + result->blockTraces[j].traces[i].profileEvents[2].second->cycle = 200; + result->blockTraces[j].traces[i].profileEvents[2].first->scopeId = 3; + result->blockTraces[j].traces[i].profileEvents[2].second->scopeId = 3; + result->blockTraces[j].traces[i].profileEvents[3].first->cycle = 183; + result->blockTraces[j].traces[i].profileEvents[3].second->cycle = 210; + result->blockTraces[j].traces[i].profileEvents[3].first->scopeId = 4; + result->blockTraces[j].traces[i].profileEvents[3].second->scopeId = 4; + } + } + std::vector kerneltrace = {std::make_pair(result, metadata)}; + auto writer = StreamChromeTraceWriter(kerneltrace, chromeTracePath); + writer.dump(); + + auto data = readJsonTrace(chromeTracePath); + + EXPECT_EQ(data.empty(), false); + EXPECT_EQ(data["traceEvents"].size(), 24); + std::map pidCount; + std::map tidCount; + for (int i = 0; i < 24; i++) { + pidCount[data["traceEvents"][i]["pid"]] += 1; + tidCount[data["traceEvents"][i]["tid"]] += 1; + } + EXPECT_EQ(pidCount["kernel2 Core121 CTA2"], 12); + EXPECT_EQ(pidCount["kernel2 Core120 CTA1"], 12); + EXPECT_EQ(tidCount["warp 0 (line 0)"], 4); + EXPECT_EQ(tidCount["warp 0 (line 1)"], 4); + EXPECT_EQ(tidCount["warp 1 (line 0)"], 4); + EXPECT_EQ(tidCount["warp 1 (line 1)"], 4); + EXPECT_EQ(tidCount["warp 2 (line 0)"], 4); + EXPECT_EQ(tidCount["warp 2 (line 1)"], 4); +} + +TEST_F(ChromeTraceWriterTest, MultiKernel) { + auto metadata1 = std::make_shared(); + metadata1->kernelName = "kernel1"; + metadata1->scopeName = {{1, "s1"}}; + auto result1 = createDefaultResult(1, 2, metadata1->scopeName.size()); + + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 1; j++) { + result1->blockTraces[j].blockId = j; + result1->blockTraces[j].procId = j; + result1->blockTraces[j].initTime = 0; + result1->blockTraces[j].traces[i].uid = i; + result1->blockTraces[j].traces[i].profileEvents[0].first->cycle = 1220000; + result1->blockTraces[j].traces[i].profileEvents[0].second->cycle = + 1620000; + result1->blockTraces[j].traces[i].profileEvents[0].first->scopeId = 1; + result1->blockTraces[j].traces[i].profileEvents[0].second->scopeId = 1; + } + } + + auto metadata2 = std::make_shared(); + metadata2->kernelName = "kernel2"; + metadata2->scopeName = {{1, "s1"}}; + auto result2 = createDefaultResult(2, 1, metadata2->scopeName.size()); + + for (int i = 0; i < 1; i++) { + for (int j = 0; j < 2; j++) { + result2->blockTraces[j].blockId = j; + result2->blockTraces[j].procId = j; + result2->blockTraces[j].initTime = 10000000; + result2->blockTraces[j].traces[i].uid = i; + result2->blockTraces[j].traces[i].profileEvents[0].first->cycle = 1220000; + result2->blockTraces[j].traces[i].profileEvents[0].second->cycle = + 1620000; + result2->blockTraces[j].traces[i].profileEvents[0].first->scopeId = 1; + result2->blockTraces[j].traces[i].profileEvents[0].second->scopeId = 1; + } + } + std::vector kerneltrace = {std::make_pair(result1, metadata1), + std::make_pair(result2, metadata2)}; + auto writer = StreamChromeTraceWriter(kerneltrace, chromeTracePath); + writer.dump(); + + auto data = readJsonTrace(chromeTracePath); + + EXPECT_EQ(data.empty(), false); + EXPECT_EQ(data["traceEvents"][0]["cat"], "kernel1"); + EXPECT_DOUBLE_EQ(data["traceEvents"][0]["ts"], 0.0); + EXPECT_DOUBLE_EQ(data["traceEvents"][0]["dur"], 400.0); + EXPECT_EQ(data["traceEvents"][1]["cat"], "kernel1"); + EXPECT_EQ(data["traceEvents"][2]["cat"], "kernel2"); + EXPECT_DOUBLE_EQ(data["traceEvents"][2]["ts"], 10000.0); + EXPECT_DOUBLE_EQ(data["traceEvents"][2]["dur"], 400.0); +} diff --git a/third_party/mthreads/proton/test/unittest/TraceDataIO/CircularLayoutParserTest.cpp b/third_party/mthreads/proton/test/unittest/TraceDataIO/CircularLayoutParserTest.cpp new file mode 100644 index 0000000000..a18102c08d --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/TraceDataIO/CircularLayoutParserTest.cpp @@ -0,0 +1,275 @@ +#include "TraceDataIO/CircularLayoutParser.h" +#include +#include +#include +#include +#include +#include + +using namespace proton; + +class CircularLayoutParserTest : public ::testing::Test { +public: + explicit CircularLayoutParserTest(const std::string &kernel = "") + : kernel(kernel) {} + + void SetUp() override { + if (!kernel.empty()) { + output = PROTON_TEST_UTIL_PATH; + output += "/" + kernel + ".bin"; + } + } + + void TearDown() override {} + + ByteSpan getBuffer(std::string binPath) { + std::ifstream file(binPath, std::ios::binary); + + if (!file) { + std::cerr << "Cannot open file!" << std::endl; + return ByteSpan(nullptr, 0); + } + + // Get file size + file.seekg(0, std::ios::end); + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + + testData.resize(size); + + // Read the data + if (!file.read(reinterpret_cast(testData.data()), size)) { + std::cerr << "Error reading file!" << std::endl; + return ByteSpan(nullptr, 0); + } + return ByteSpan(testData.data(), size); + } + +protected: + CircularLayoutParserConfig config; + std::vector testData; + std::string kernel; + std::string output; +}; + +TEST_F(CircularLayoutParserTest, WrongPreamble) { + config.numBlocks = 1; + config.uidVec = {0}; + testData = {0x78, 0x56, 0x34, 0x12, 0x01, 0x00, + 0x00, 0x80, 0xFF, 0xFF, 0xFF, 0xFF}; + auto buffer = ByteSpan(testData.data(), testData.size()); + auto parser = CircularLayoutParser(buffer, config); + EXPECT_THROW(parser.parse(), ParserException); +} + +TEST_F(CircularLayoutParserTest, SingleEvent) { + testData = { + // header + 0xef, 0xbe, 0xad, 0xde, // preamble + 0x01, 0x00, 0x00, 0x00, // program id + 0x03, 0x00, 0x00, 0x00, // hw id + 0x10, 0x00, 0x00, 0x00, // buf size + 0xef, 0xcd, 0xab, 0x89, // initial time + 0x67, 0x45, 0x23, 0x01, // + 0x10, 0x32, 0x54, 0x76, // pre-final time + 0x98, 0xba, 0xdc, 0xfe, // + 0x08, 0x07, 0x06, 0x05, // post-final time + 0x04, 0x03, 0x02, 0x01, // + // num events + 0xff, 0x00, 0x00, 0x00, + // profiled data + 0x00, 0x00, 0x00, 0x02, // start + 0x00, 0x10, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x82, // end + 0x00, 0x20, 0x00, 0x00, // + }; + config.numBlocks = 1; + config.totalUnits = 1; + config.scratchMemSize = testData.size(); + config.uidVec = {0}; + auto buffer = ByteSpan(testData.data(), testData.size()); + auto parser = CircularLayoutParser(buffer, config); + parser.parse(); + auto result = parser.getResult(); + EXPECT_EQ(result->blockTraces.size(), 1); + EXPECT_EQ(result->blockTraces[0].blockId, 1); + EXPECT_EQ(result->blockTraces[0].procId, 3); + EXPECT_EQ(result->blockTraces[0].initTime, 0x0123456789abcdef); + EXPECT_EQ(result->blockTraces[0].preFinalTime, 0xfedcba9876543210); + EXPECT_EQ(result->blockTraces[0].postFinalTime, 0x0102030405060708); + EXPECT_EQ(result->blockTraces[0].traces[0].count, 255); + EXPECT_EQ(result->blockTraces[0].traces[0].uid, 0); + EXPECT_EQ(result->blockTraces[0].traces[0].profileEvents.size(), 1); + auto &event = result->blockTraces[0].traces[0].profileEvents[0]; + EXPECT_EQ(event.first->scopeId, 4); + EXPECT_EQ(event.second->scopeId, 4); + EXPECT_EQ(event.first->isStart, true); + EXPECT_EQ(event.second->isStart, false); + EXPECT_EQ(event.first->cycle, 4096); + EXPECT_EQ(event.second->cycle, 8192); +} + +TEST_F(CircularLayoutParserTest, StartAfterStart) { + testData = { + // header + 0xef, 0xbe, 0xad, 0xde, // preamble + 0x01, 0x00, 0x00, 0x00, // program id + 0x03, 0x00, 0x00, 0x00, // hw id + 0x10, 0x00, 0x00, 0x00, // buf size + 0xef, 0xcd, 0xab, 0x89, // initial time + 0x67, 0x45, 0x23, 0x01, // + 0x10, 0x32, 0x54, 0x76, // pre-final time + 0x98, 0xba, 0xdc, 0xfe, // + 0x08, 0x07, 0x06, 0x05, // post-final time + 0x04, 0x03, 0x02, 0x01, // + // num events + 0xff, 0x00, 0x00, 0x00, + // profiled data + 0x04, 0x00, 0x00, 0x00, // start + 0x00, 0x10, 0x00, 0x00, // + 0x04, 0x00, 0x00, 0x00, // start + 0x00, 0x20, 0x00, 0x00, // + }; + config.numBlocks = 1; + config.totalUnits = 1; + config.scratchMemSize = testData.size(); + config.uidVec = {0}; + auto buffer = ByteSpan(testData.data(), testData.size()); + auto parser = CircularLayoutParser(buffer, config); + parser.parse(); + auto result = parser.getResult(); + EXPECT_EQ(result->blockTraces[0].traces[0].profileEvents.size(), 0); +} + +TEST_F(CircularLayoutParserTest, MultipleSegment) { + testData = { + // header + 0xef, 0xbe, 0xad, 0xde, // preamble + 0x01, 0x00, 0x00, 0x00, // program id + 0x03, 0x00, 0x00, 0x00, // hw id + 0x30, 0x00, 0x00, 0x00, // buf size + 0xef, 0xcd, 0xab, 0x89, // initial time + 0x67, 0x45, 0x23, 0x01, // + 0x10, 0x32, 0x54, 0x76, // pre-final time + 0x98, 0xba, 0xdc, 0xfe, // + 0x08, 0x07, 0x06, 0x05, // post-final time + 0x04, 0x03, 0x02, 0x01, // + // num events + 0xff, 0x00, 0x00, 0x00, // segment 0 + 0xff, 0x00, 0x00, 0x00, // segment 1 + 0xff, 0x00, 0x00, 0x00, // segment 2 + // segment 0 + 0x00, 0x00, 0x00, 0x00, // start + 0x00, 0x10, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x80, // end + 0x00, 0x20, 0x00, 0x00, // + // segment 1 + 0x00, 0x00, 0x00, 0x00, // start + 0x00, 0x10, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x80, // end + 0x00, 0x20, 0x00, 0x00, // + // segment 2 + 0x00, 0x00, 0x00, 0x00, // start + 0x00, 0x10, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x80, // end + 0x00, 0x20, 0x00, 0x00, // + // extra + 0xff, 0xff, 0xff, 0xff, // + 0xff, 0xff, 0xff, 0xff, // + }; + config.numBlocks = 1; + config.totalUnits = 3; + config.scratchMemSize = testData.size(); + config.uidVec = {0, 1, 2}; + auto buffer = ByteSpan(testData.data(), testData.size()); + auto parser = CircularLayoutParser(buffer, config); + parser.parse(); + auto result = parser.getResult(); + EXPECT_EQ(result->blockTraces[0].traces.size(), 3); + for (int i = 0; i < 3; i++) { + EXPECT_EQ(result->blockTraces[0].traces[i].profileEvents.size(), 1); + EXPECT_EQ(result->blockTraces[0].traces[i].profileEvents[0].first->cycle, + 4096); + EXPECT_EQ(result->blockTraces[0].traces[i].profileEvents[0].second->cycle, + 8192); + } +} + +class CLParserSeqTraceTest : public CircularLayoutParserTest { +public: + CLParserSeqTraceTest() : CircularLayoutParserTest("seq") {} +}; + +TEST_F(CLParserSeqTraceTest, Trace) { + auto buffer = getBuffer(output); + auto result = proton::readCircularLayoutTrace(buffer); + EXPECT_EQ(result->blockTraces.size(), 2); + EXPECT_EQ(result->blockTraces[1].blockId, 1); + EXPECT_EQ(result->blockTraces[0].traces.size(), 4); + EXPECT_EQ(result->blockTraces[0].traces[0].count, 12); + EXPECT_EQ(result->blockTraces[0].traces[3].profileEvents.size(), 3); +} + +class CLParserLoopTraceTest : public CircularLayoutParserTest { +public: + CLParserLoopTraceTest() : CircularLayoutParserTest("loop") {} +}; + +TEST_F(CLParserLoopTraceTest, Trace) { + auto buffer = getBuffer(output); + auto result = proton::readCircularLayoutTrace(buffer); + EXPECT_EQ(result->blockTraces.size(), 1); + EXPECT_EQ(result->blockTraces[0].traces.size(), 4); + EXPECT_EQ(result->blockTraces[0].traces[0].count, 80); + EXPECT_EQ(result->blockTraces[0].traces[3].profileEvents.size(), 4); +} + +TEST_F(CircularLayoutParserTest, TimeShift) { + testData = { + // header + 0xef, 0xbe, 0xad, 0xde, // preamble + 0x01, 0x00, 0x00, 0x00, // program id + 0x03, 0x00, 0x00, 0x00, // hw id + 0x20, 0x00, 0x00, 0x00, // buf size + 0xef, 0xcd, 0xab, 0x89, // initial time + 0x67, 0x45, 0x23, 0x01, // + 0x10, 0x32, 0x54, 0x76, // pre-final time + 0x98, 0xba, 0xdc, 0xfe, // + 0x08, 0x07, 0x06, 0x05, // post-final time + 0x04, 0x03, 0x02, 0x01, // + // num events + 0xff, 0x00, 0x00, 0x00, + // profiled data + 0x00, 0x00, 0x00, 0x00, // event 0 start + 0x21, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x01, // event 0 end + 0x36, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x80, // event 1 start + 0x46, 0x00, 0x00, 0x00, // + 0x00, 0x00, 0x00, 0x81, // event 1 end + 0x64, 0x00, 0x00, 0x00, // + }; + config.numBlocks = 1; + config.totalUnits = 1; + config.scratchMemSize = testData.size(); + config.uidVec = {0}; + config.device.type = DeviceType::CUDA; + auto buffer = ByteSpan(testData.data(), testData.size()); + auto parser = CircularLayoutParser(buffer, config); + parser.parse(); + auto result = parser.getResult(); + auto &event0 = result->blockTraces[0].traces[0].profileEvents[0]; + auto &event1 = result->blockTraces[0].traces[0].profileEvents[1]; + EXPECT_EQ(event0.first->cycle, 33); + EXPECT_EQ(event0.second->cycle, 70); + EXPECT_EQ(event1.first->cycle, 54); + EXPECT_EQ(event1.second->cycle, 100); + + const uint64_t cost = getTimeShiftCost(config); + timeShift(cost, result); + + EXPECT_EQ(event0.first->cycle, 26); + EXPECT_EQ(event0.second->cycle, 49); + EXPECT_EQ(event1.first->cycle, 40); + EXPECT_EQ(event1.second->cycle, 72); +} diff --git a/third_party/mthreads/proton/test/unittest/TraceDataIO/DecoderTest.cpp b/third_party/mthreads/proton/test/unittest/TraceDataIO/DecoderTest.cpp new file mode 100644 index 0000000000..689778bdd3 --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/TraceDataIO/DecoderTest.cpp @@ -0,0 +1,20 @@ +#include "TraceDataIO/EntryDecoder.h" +#include +#include +#include + +using namespace proton; + +TEST(DecoderTest, Decode) { + std::vector testData = {0x78, 0x56, 0x34, 0x12, 0x01, 0x00, + 0x00, 0x80, 0xFF, 0xFF, 0xFF, 0xFF}; + + auto buf = ByteSpan(testData.data(), testData.size()); + auto decoder = EntryDecoder(buf); + auto entry1 = decoder.decode(); + EXPECT_EQ(entry1->value, 0x12345678); + auto entry2 = decoder.decode(); + EXPECT_EQ(entry2->isStart, false); + EXPECT_EQ(entry2->scopeId, 0); + EXPECT_EQ(entry2->cycle, 8589934591); +} diff --git a/third_party/mthreads/proton/test/unittest/util/loop.bin b/third_party/mthreads/proton/test/unittest/util/loop.bin new file mode 100644 index 0000000000..2bcb38c109 Binary files /dev/null and b/third_party/mthreads/proton/test/unittest/util/loop.bin differ diff --git a/third_party/mthreads/proton/test/unittest/util/seq.bin b/third_party/mthreads/proton/test/unittest/util/seq.bin new file mode 100644 index 0000000000..4979a9e4dc Binary files /dev/null and b/third_party/mthreads/proton/test/unittest/util/seq.bin differ diff --git a/third_party/mthreads/proton/test/unittest/util/trace_gen.py b/third_party/mthreads/proton/test/unittest/util/trace_gen.py new file mode 100644 index 0000000000..e39cac03a9 --- /dev/null +++ b/third_party/mthreads/proton/test/unittest/util/trace_gen.py @@ -0,0 +1,74 @@ +import triton +import argparse +import ctypes +import triton.profiler as proton +import triton.profiler.language as pl +from triton.profiler.hooks import InstrumentationHook + +pl.enable_semantic("triton") + + +def write_tensor_to_file(tensor, filename): + data_ptr = tensor.data_ptr() + size = tensor.numel() + dtype_size = tensor.element_size() + total_bytes = size * dtype_size + + with open(filename, 'wb') as f: + data_arr = ctypes.cast(data_ptr, ctypes.POINTER(ctypes.c_ubyte * total_bytes)) + f.write(bytes(data_arr.contents)) + + +@triton.jit +def seq_kernel(): + pl.enter_scope("r0") + pl.enter_scope("r1") + pl.enter_scope("r2") + pl.exit_scope("r1") + pl.exit_scope("r0") + pl.exit_scope("r2") + + +def seq(args): + grid_size = 2 + grid = (grid_size, ) + proton.start("", backend="instrumentation", mode=proton.mode.Default(buffer_size=256)) + InstrumentationHook.enable_host_buffer = True + InstrumentationHook.profile_buffer_size = 512 + seq_kernel[grid]() + write_tensor_to_file(InstrumentationHook.host_buffer, args.trace_file) + proton.finalize() + + +@triton.jit +def loop_kernel(): + for k in range(0, 20): + pl.enter_scope("r0") + pl.exit_scope("r0") + + +def loop(args): + grid_size = 1 + grid = (grid_size, ) + proton.start("", backend="instrumentation", mode=proton.mode.Default(buffer_size=256)) + InstrumentationHook.enable_host_buffer = True + InstrumentationHook.profile_buffer_size = 512 + loop_kernel[grid]() + write_tensor_to_file(InstrumentationHook.host_buffer, args.trace_file) + proton.finalize() + + +def main(): + parser = argparse.ArgumentParser(description='Proton intra kernel profiler trace generator') + parser.add_argument('trace_file', type=str, help='Trace file path') + parser.add_argument('--kernel', '-k', type=str, help='Kernel name') + args = parser.parse_args() + + if args.kernel == "seq": + seq(args) + if args.kernel == "loop": + loop(args) + + +if __name__ == '__main__': + main() diff --git a/third_party/mthreads/proton/tutorials/dynamic-net.py b/third_party/mthreads/proton/tutorials/dynamic-net.py new file mode 100644 index 0000000000..8a933d200f --- /dev/null +++ b/third_party/mthreads/proton/tutorials/dynamic-net.py @@ -0,0 +1,103 @@ +import random +import torch +import math + +import triton.profiler as proton +import argparse + +engine = "torch" + + +class DynamicNet(torch.nn.Module): + # https://pytorch.org/tutorials/beginner/examples_nn/dynamic_net.html + def __init__(self): + """ + In the constructor we instantiate five parameters and assign them as members. + """ + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + self.e = torch.nn.Parameter(torch.randn(())) + + def forward(self, x): + """ + For the forward pass of the model, we randomly choose either 4, 5 + and reuse the e parameter to compute the contribution of these orders. + + Since each forward pass builds a dynamic computation graph, we can use normal + Python control-flow operators like loops or conditional statements when + defining the forward pass of the model. + + Here we also see that it is perfectly safe to reuse the same parameter many + times when defining a computational graph. + """ + y = self.a + self.b * x + self.c * x**2 + self.d * x**3 + for exp in range(4, random.randint(4, 6)): + y = y + self.e * x**exp + return y + + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f"y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?" + + +def run(): + # Create Tensors to hold input and outputs. + with proton.scope("init"): + x = torch.linspace(-math.pi, math.pi, 2000, device="cuda") + y = torch.sin(x) + + # Construct our model by instantiating the class defined above + model = DynamicNet().to("cuda") + if engine == "torchinductor": + model = torch.compile(model) + + # Construct our loss function and an Optimizer. Training this strange model with + # vanilla stochastic gradient descent is tough, so we use momentum + criterion = torch.nn.MSELoss(reduction="sum") + optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) + for t in range(1000): + # Forward pass: Compute predicted y by passing x to the model + with proton.scope("forward"): + y_pred = model(x) + + # Compute and print loss + with proton.scope("loss"): + loss = criterion(y_pred, y) + if t % 200 == 199: + print(t, loss.item()) + + # Zero gradients, perform a backward pass, and update the weights. + with proton.scope("backward"): + optimizer.zero_grad() + loss.backward() + with proton.scope("optimizer"): + optimizer.step() + + print(f"Result: {model.string()}") + + +argparser = argparse.ArgumentParser() +argparser.add_argument("--profile", action="store_true") +argparser.add_argument("--engine", default="torch", choices=["torch", "torchinductor"]) +argparser.add_argument("--context", default="shadow", choices=["shadow", "python"]) +argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer"]) +argparser.add_argument("--mode", default=None) + +args = argparser.parse_args() + +engine = args.engine + +if args.profile: + func = proton.profile(run, name="dynamic_net", context=args.context, backend=args.backend, mode=args.mode) +else: + func = run + +func() +# Write out the profile +# Visualize using `proton-viewer -m time/s ./dynamic_net.hatchet` +proton.finalize() diff --git a/third_party/mthreads/proton/tutorials/intra_kernel/README.md b/third_party/mthreads/proton/tutorials/intra_kernel/README.md new file mode 100644 index 0000000000..5316135cc2 --- /dev/null +++ b/third_party/mthreads/proton/tutorials/intra_kernel/README.md @@ -0,0 +1,123 @@ +# Proton Intra-Kernel Profiler Tutorial + +A comprehensive tutorial demonstrating how to use the Proton intra-kernel profiler for detailed performance analysis of GPU kernels written in Triton DSL and Gluon DSL. + +## Overview + +The Proton intra-kernel profiler captures fine-grained timing information within GPU kernels, enabling performance bottleneck identification and optimization opportunities. This tutorial provides two distinct profiling approaches: + +- **TTGIR Override Approach** - For profiling existing Triton DSL kernels by injecting instrumentation +- **Proton DSL Approach** - For native integration with Triton and Gluon DSL kernels using embedded profiling scopes + +## Examples + +### 1. TTGIR Override Approach (`example_override.py`) + +**Use Case**: Profile existing Triton DSL kernels without modifying source code + +**Example**: Vector addition kernel with external instrumentation injection + +**Workflow**: +1. **Generate TTGIR dump files**: + ```bash + ../../scripts/dump_ttgir.sh python3 example_override.py --increase-accuracy + ``` + Creates original TTGIR files in `ttgir_dump/` directory + +2. **Insert profiling instrumentation**: + ```bash + ./insert_proton_records + ``` + Modifies TTGIR files by adding `proton.record` operators at profiling points + +3. **Execute with TTGIR override**: + ```bash + TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_OVERRIDE=1 TRITON_OVERRIDE_DIR=ttgir_dump python3 example_override.py --increase-accuracy + ``` + - `TRITON_ALWAYS_COMPILE=1`: Forces recompilation on each run + - `TRITON_KERNEL_OVERRIDE=1`: Enables TTGIR override mechanism + - `TRITON_OVERRIDE_DIR=ttgir_dump`: Specifies directory with modified TTGIR files + +### 2. Proton DSL Approach (`example_dsl.py`) + +**Use Case**: Native profiling DSL integration for Triton and Gluon DSL kernels + +**Example**: Triton vector-add and Gluon matrix multiplication using NVIDIA Hopper architecture features (WGMMA, TMA) + + +**Command Line Options**: +```bash +# Timeline trace mode (default) +python3 example_dsl.py + +# Operation measurement mode +python3 example_dsl.py --op-measure + +# Enable warp sampling with specific warp IDs +python3 example_dsl.py --warp-sampling --warp-ids "0,1,2,3" --gmem_buffer + +# High accuracy profiling +python3 example_dsl.py --increase-accuracy +``` + +## Understanding Timeline Traces + +### Time Representation + +- **Scope Duration**: Displayed in cycles for precise measurement +- **Threadblock Start Times**: Measured in nanoseconds using global timing +- **Chrome Trace Format**: Assumes 1GHz GPU frequency for consistent time units (ns) + +### Circular Buffer System + +- **Backend Storage**: Uses circular buffer for runtime profiling on each CTA +- **Buffer Overflow**: When full, earlier events are dropped with warnings in trace generation +- **Event Window**: Displays sliding window (the latest window) of recorded events in timeline + +### Finalize Time Measurement + +- **Definition**: Captures `Finalize Time` when kernel execution completes +- **Meaning**: Shows overhead of dumping profiling data from buffer to global memory (appears as a field in Chrome trace viewer tab) + +## Configuration Options + +### Profiling Accuracy + +| Option | Description | Use Case | +|--------|-------------|----------| +| `clock32` | Records events in 32-bit clock format for lower overhead | normal kernels (<4 seconds @ 1GHz) | +| `time_shift` | Deducts constant profiling overhead from timeline trace | Mitigate Proton runtime overhead for cleaner traces | +| `sched_stores` | Provides more cycle-accurate operation latency measurement | Accurate single operation latency measure | +| `sched_barriers` | Constrains AMD instruction scheduling within proton scopes | AMD GPU profiling | + +### Buffer Configuration + +| Buffer Type | Options | Default | Description | +|-------------|---------|---------|-------------| +| `buffer_type` | `shared`, `global` | `shared` | Determines whether profiling data is stored in shared or global memory | +| `buffer_size` | Integer | `shared`: Maximum size without reducing occupancy; `global`: 16KB × number of profiled units (e.g., warp) | Controls per-block profiling buffer size in bytes | + +### Sampling Configuration + +| Parameter | Options | Description | +|-----------|---------|-------------| +| `sampling_strategy` | `selective`, `none` | Sampling approach for profiling data collection | +| `sampling_options` | Comma-separated warp IDs | Specific warps to profile (e.g., "0,1,2,3") | + +**Sampling Benefits**: Warp sampling captures more events within the same buffer size constraint by focusing on specific warps of interest. + +## Output Formats + +### Timeline Traces + +- **Format**: Chrome trace format (`.chrome_trace` files) +- **Viewer**: Chrome browser at `chrome://tracing` or [`Perfetto`](https://ui.perfetto.dev/) +- **Content**: Detailed timeline with scope durations + +### Operation Measurements + +- **Format**: Hatchet format (`.hatchet` files) +- **Viewer**: `proton-viewer -m normalized_cycles .hatchet` +(with `-m cycles` showing sum of all cycles across the GPU, `normalized_cycles` for per-warp averaged cycles) +- **Content**: Scope-level performance metrics and statistics +- **Note**: Cycle counts are averaged across warps/CTAs diff --git a/third_party/mthreads/proton/tutorials/intra_kernel/example_dsl.py b/third_party/mthreads/proton/tutorials/intra_kernel/example_dsl.py new file mode 100644 index 0000000000..46d5853492 --- /dev/null +++ b/third_party/mthreads/proton/tutorials/intra_kernel/example_dsl.py @@ -0,0 +1,317 @@ +""" +Intra-Kernel Profiling Examples using Proton DSL for Triton and Gluon Kernels +""" + +import argparse + +import torch +import triton +import triton.language as tl +import triton.profiler as proton +import triton.profiler.language as pl +from triton.experimental import gluon +from triton.experimental.gluon import language as gl +from triton.experimental.gluon.language.nvidia.hopper import ( + fence_async_shared, + mbarrier, + tma, + warpgroup_mma, + warpgroup_mma_init, + warpgroup_mma_wait, +) + +from triton.experimental.gluon.nvidia.hopper import TensorDescriptor + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +NUM_WARPS = 8 + + +def is_hopper(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 9 + + +def config_helper(description: str): + # Configure command line arguments for profiling options + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "--op-measure", + action="store_true", + default=False, + help="Enable operation measurement. Otherwise, we profile timeline trace. (default: False)", + ) + parser.add_argument( + "--warp-sampling", + action="store_true", + default=False, + help="Enable warp sampling during profiling (default: False)", + ) + parser.add_argument( + "--increase-accuracy", + action="store_true", + default=False, + help="Enable increased-accuracy during profiling (default: False).", + ) + parser.add_argument( + "--warp-ids", + type=str, + default="0, 2", + help="Comma-separated list of warp IDs for warp sampling (default: '0, 2')", + ) + parser.add_argument( + "--gmem_buffer", + action="store_true", + default=False, + help="Use global memory as the internal buffer during profiling (default: False).", + ) + + args = parser.parse_args() + + # Configure profiling options based on accuracy requirements + # Default uses clock_64 for long-running kernels with higher overhead + opts = "" + # `clock_32` provides lower overhead per record, `time_shift`` post-processes to reduce noise + if args.increase_accuracy: + opts = "clock32,time_shift" + + if args.gmem_buffer: + buf = "global" + else: + buf = "shared" + + # Set up profiling mode based on warp sampling preferences + if args.warp_sampling: + # Selective warp sampling allows capturing more events within buffer constraints + # by only profiling specified warps (e.g. "0,1,2,3") + mode = proton.mode.Default( + optimizations=opts, + sampling_strategy="selective", + sampling_options=args.warp_ids, + buffer_type=buf, + ) + else: + # Profile all warps - provides complete picture but uses more buffer space + mode = proton.mode.Default(optimizations=opts, buffer_type=buf) + + return args.op_measure, mode + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pl.enter_scope("kernel") + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + with pl.scope("load_and_add"): + with pl.scope("load_x_issue"): + x = tl.load(x_ptr + offsets, mask=mask) + with pl.scope("load_y_issue"): + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + pl.exit_scope("kernel") + + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=NUM_WARPS) + return output + + +if __name__ == "__main__": + description = "Triton Vector Add with Proton Intra-Kernel Profiling" + print(description) + + # Explicit Proton DSL enablement for Triton kernels. + # Be careful NOT to insert proton ops in loops (use the ttgir override approach instead). + pl.enable_semantic("triton") + + op_measure, mode = config_helper(description) + + # Start profiling with appropriate backend and output format + if op_measure: + # Operation measurement mode generates scope-level metrics + # View results with: proton-viewer -m normalized_cycles vector-add.hatchet + # Note: cycles are averaged across all warps/CTAs - adjust for warp specialization + proton.start("vector-add", backend="instrumentation", mode=mode) + else: + # Timeline trace mode generates Chrome trace format for visualization + # Output file: vector-add.chrome_trace + proton.start("vector-add", data="trace", backend="instrumentation", mode=mode) + + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device=DEVICE) + y = torch.rand(size, device=DEVICE) + output_torch = x + y + output_triton = add(x, y) + torch.testing.assert_close(output_torch, output_triton, rtol=1e-3, atol=1e-1) + proton.finalize() + + +# This decorator allows us to invoke the function from a Gluon constexpr. +@gluon.constexpr_function +def get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps): + warps_per_cta = [4, 1] + m = 16 + # Tile the atom until we have enough warps. + while warps_per_cta[0] * warps_per_cta[1] != num_warps: + # Tile along M only if it would not cause broadcasting. + if BLOCK_M > m * warps_per_cta[0]: + warps_per_cta[0] *= 2 + else: + warps_per_cta[1] *= 2 + return warps_per_cta + + +@gluon.constexpr_function +def get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps): + m = 16 + mReps = triton.cdiv(BLOCK_M, m) + nReps = triton.cdiv(num_warps, mReps) + maxN = max(BLOCK_N // nReps, 8) + n = 256 + while n > maxN or BLOCK_N % n != 0: + n -= 8 + assert n >= 8, "expected to find a valid n" + return n + + +@gluon.constexpr_function +def pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps): + m = 16 + k = 256 // dtype.primitive_bitwidth + n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps) + warps_per_cta = get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps) + return gl.NVMMADistributedLayout( + version=[3, 0], + warps_per_cta=warps_per_cta, + instr_shape=[m, n, k], + ) + + +@gluon.jit +def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.constexpr): + BLOCK_M: gl.constexpr = c_desc.block_type.shape[0] + BLOCK_N: gl.constexpr = c_desc.block_type.shape[1] + BLOCK_K: gl.constexpr = a_desc.block_type.shape[1] + dtype: gl.constexpr = a_desc.dtype + K = a_desc.shape[1] + + pl.enter_scope("blocked_matmul_pipelined_kernel") + + # Allocate 2 buffers for each A and B. + a_smem = gl.allocate_shared_memory(dtype, [2] + a_desc.block_type.shape, a_desc.layout) + b_smem = gl.allocate_shared_memory(dtype, [2] + b_desc.block_type.shape, b_desc.layout) + index = 0 + + pid_m = gl.program_id(axis=0) + pid_n = gl.program_id(axis=1) + off_m = pid_m * BLOCK_M + off_n = pid_n * BLOCK_N + + mma_layout: gl.constexpr = pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps) + acc = warpgroup_mma_init(gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout)) + + bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + phase = 0 + + for k in range(0, K, BLOCK_K): + a = a_smem.index(index) + b = b_smem.index(index) + + mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) + + with pl.scope("tma_loads_issue"): + tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a) + tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b) + + with pl.scope("tma_loads_wait"): + mbarrier.wait(bar, phase=phase) + phase ^= 1 + + # Since `warpgroup_mma_wait` is a no-op when there are no WGMMAs in + # flight, we can overlap the WGMMA by waiting first, then issuing the + # async WGMMA. + with pl.scope("wgmma_wait"): + acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, )) + + with pl.scope("wgmma_issue"): + acc = warpgroup_mma(a, b, acc, is_async=True) + + # Move to the next buffer. The TMA load will start while the WGMMA is + # still running. + index ^= 1 + + # Wait for the last WGMMA to complete. + with pl.scope("wgmma_last_wait"): + acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, )) + + mbarrier.invalidate(bar) + + c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout) + c_smem.store(acc.to(dtype)) + fence_async_shared() + tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem) + tma.store_wait(pendings=0) + + pl.exit_scope("blocked_matmul_pipelined_kernel") + + +def blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps): + M, N = C.shape + + a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16) + b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16) + c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16) + a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout) + b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout) + c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout) + + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + blocked_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, num_warps=num_warps) + + +if __name__ == "__main__": + if not is_hopper(): + raise RuntimeError("This tutorial requires a Hopper NVIDIA GPU") + + description = "Gluon Matrix Multiplication with Proton Intra-Kernel Profiling" + print(description) + + M, N, K = 512, 512, 1024 + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty(M, N, device="cuda", dtype=torch.float16) + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + + op_measure, mode = config_helper(description) + + # Start profiling with appropriate backend and output format + if op_measure: + # Operation measurement mode generates scope-level metrics + # View results with: proton-viewer -m normalized_cycles gemm.hatchet + # Note: cycles are averaged across all warps/CTAs - adjust for warp specialization + proton.start("gemm", backend="instrumentation", mode=mode) + else: + # Timeline trace mode generates Chrome trace format for visualization + # Output file: gemm.chrome_trace + proton.start("gemm", data="trace", backend="instrumentation", mode=mode) + + blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS) + torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1) + + # Complete profiling and write output files + proton.finalize() diff --git a/third_party/mthreads/proton/tutorials/intra_kernel/example_override.py b/third_party/mthreads/proton/tutorials/intra_kernel/example_override.py new file mode 100644 index 0000000000..a87e8e8868 --- /dev/null +++ b/third_party/mthreads/proton/tutorials/intra_kernel/example_override.py @@ -0,0 +1,98 @@ +""" +Vector Addition with Triton Intra-Kernel Profiling using TTGIR Override + +This tutorial demonstrates how to use Triton's TTGIR override mechanism +to enable intra-kernel profiling with Proton. The workflow involves generating, +modifying, and overriding the kernel's intermediate representation to insert +profiling hooks. + +Workflow: +1. Generate TTGIR dump files: + + This creates the original TTGIR files in the `ttgir_dump/` directory: + + ../../scripts/dump_ttgir.sh python3 example_override.py --increase-accuracy + +2. Insert profiling instrumentation: + + Modify the generated TTGIR files by adding proton.record operators at desired + profiling points. Example script that adds proton ops in the above ttgir: + + ./insert_proton_records + +3. Execute with TTGIR override: + + TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_OVERRIDE=1 TRITON_OVERRIDE_DIR=ttgir_dump python3 example_override.py --increase-accuracy + + - TRITON_ALWAYS_COMPILE=1: Forces recompilation on each run + - TRITON_KERNEL_OVERRIDE=1: Enables TTGIR override mechanism + - TRITON_OVERRIDE_DIR=ttgir_dump: Specifies directory containing modified TTGIR files +""" + +import argparse + +import torch +import triton +import triton.language as tl +import triton.profiler as proton +from triton.profiler.mode import Default + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + parser = argparse.ArgumentParser(description="TTGIR override example with Triton intra kernel profiling") + parser.add_argument( + "--increase-accuracy", + action="store_true", + default=False, + help="Enable increased-accuracy during profiling (default: False)", + ) + args = parser.parse_args() + + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + + if args.increase_accuracy: + proton.start( + "add", + data="trace", + backend="instrumentation", + mode=Default(optimizations="clock32,time_shift"), + ) + else: + proton.start("add", data="trace", backend="instrumentation") + + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + + proton.finalize() + return output + + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) +output_torch = x + y +output_triton = add(x, y) +torch.testing.assert_close(output_torch, output_triton, rtol=1e-3, atol=1e-1) diff --git a/third_party/mthreads/proton/tutorials/intra_kernel/insert_proton_records b/third_party/mthreads/proton/tutorials/intra_kernel/insert_proton_records new file mode 100755 index 0000000000..e98435e06e --- /dev/null +++ b/third_party/mthreads/proton/tutorials/intra_kernel/insert_proton_records @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Script to automatically add proton.record statements to the examplar vector-add ttgir. +""" + +import glob +import os +import re +import sys + + +def add_proton_records(input_file): + """Add proton.record statements to a ttgir file.""" + + with open(input_file, "r") as f: + content = f.read() + lines = f.readlines() + + # Assert no proton.record already exists + if "proton.record" in content: + raise AssertionError("File already contains `proton.record` statements! Please clean-up.") + + # Reset file pointer and read lines again + with open(input_file, "r") as f: + lines = f.readlines() + + result_lines = [] + load_and_add_started = False + + for i, line in enumerate(lines): + # Add kernel record start after function declaration + if "tt.func public @" in line and "{" in line: + result_lines.append(line) + result_lines.append(' proton.record start "kernel"\n') + continue + + # Add load_and_add record start before first load + if "tt.load" in line and not load_and_add_started: + result_lines.append(' proton.record start "load_and_add"\n') + load_and_add_started = True + + # Add individual load records + if "tt.load" in line: + # Extract variable name (x, y, etc.) - just the letters before '_' + match = re.search(r"%(\w+)_\d+\s*=\s*tt\.load", line) + if match: + var_name = match.group(1) + result_lines.append(f' proton.record start "load_{var_name}_issue"\n') + result_lines.append(line) + result_lines.append(f' proton.record end "load_{var_name}_issue"\n') + continue + + # Add load_and_add record end after arithmetic operation + if "arith.addf" in line and load_and_add_started: + result_lines.append(line) + result_lines.append(' proton.record end "load_and_add"\n') + load_and_add_started = False + continue + + # Add kernel record end before return + if "tt.return" in line: + result_lines.append(' proton.record end "kernel"\n') + result_lines.append(line) + continue + + # Default: just add the line + result_lines.append(line) + + # Write output in-place + with open(input_file, "w") as f: + f.writelines(result_lines) + + print(f"Added proton records to {input_file}") + + +def find_and_process_ttgir(): + """Find all ttgir files in ttgir_dump directory and process them.""" + + # Find ttgir_dump directory + ttgir_dump_path = None + for root, dirs, files in os.walk("."): + if "ttgir_dump" in dirs: + ttgir_dump_path = os.path.join(root, "ttgir_dump") + break + + if not ttgir_dump_path: + print("Error: ttgir_dump directory not found!") + sys.exit(1) + + # Process the ttgir file + ttgir_files = glob.glob(os.path.join(ttgir_dump_path, "**", "*.ttgir"), recursive=True) + + if not ttgir_files: + print(f"No ttgir files found in {ttgir_dump_path}") + return + + if len(ttgir_files) > 1: + print(f"Warning: Found {len(ttgir_files)} ttgir files, expected at most 1") + + ttgir_file = ttgir_files[0] # Take the first (and expected only) file + try: + print(f"Processing {ttgir_file}...") + add_proton_records(ttgir_file) + print("Successfully processed ttgir file") + except AssertionError as e: + print(f"Skipping {ttgir_file}: {e}") + except Exception as e: + print(f"Error processing {ttgir_file}: {e}") + + +if __name__ == "__main__": + find_and_process_ttgir() diff --git a/third_party/mthreads/proton/tutorials/matmul.py b/third_party/mthreads/proton/tutorials/matmul.py new file mode 100644 index 0000000000..e5e09d1dda --- /dev/null +++ b/third_party/mthreads/proton/tutorials/matmul.py @@ -0,0 +1,318 @@ +import torch + +import triton +import triton.language as tl +import triton.profiler as proton +from typing import NamedTuple +import argparse + + +def unpack_grid(grid): + if len(grid) == 1: + return grid[0], 1, 1 + if len(grid) == 2: + return grid[0], grid[1], 1 + if len(grid) == 3: + return grid[0], grid[1], grid[2] + + +def metadata_fn( + grid: tuple, + metadata: NamedTuple, + args: dict, +): + grid_x, grid_y, grid_z = unpack_grid(grid) + num_warps = metadata.num_warps + num_stages = metadata.num_stages + cluster_x, cluster_y, cluster_z = unpack_grid((metadata.num_ctas, )) + shared_memory = metadata.shared + M, K = args["a_ptr"].shape + K, N = args["b_ptr"].shape + return { + "name": + f"matmul_____", + "flops": 2 * M * N * K, + "bytes": (M * N + N * K + K * M) * args["a_ptr"].element_size(), + } + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=["M", "N", "K"], +) +@triton.jit(launch_metadata=metadata_fn) +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + # 1D launch kernel where each block gets its own program. + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +argparser = argparse.ArgumentParser() +argparser.add_argument("--profile", action="store_true") +argparser.add_argument("--pcsampling", action="store_true", default=False) +argparser.add_argument("--cudagraph", action="store_true", default=False) +args = argparser.parse_args() + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 10)], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=["cublas", "triton"], + # Label name for the lines + line_names=["cuBLAS", "Triton"], + # Line styles + styles=[("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. + args={}, + )) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + with proton.scope(f"matmul_{M}_{N}_{K}"): + if provider == "cublas": + + @proton.scope( + "cublas", + metrics={ + "flops": 2 * M * N * K, + "bytes": (M * N + N * K + K * M) * a.element_size(), + }, + ) + def cublas_matmul(a, b): + torch.matmul(a, b) + + if args.cudagraph: + ms = triton.testing.do_bench_cudagraph(lambda: cublas_matmul(a, b)) + min_ms = max_ms = ms + else: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: cublas_matmul(a, b), quantiles=quantiles) + if provider == "triton": + + def enter_autotune(args, reset_only=False): + if reset_only: + return + proton.enter_scope("") + + def exit_autotune(args, exception): + proton.exit_scope() + + matmul_kernel.pre_hook = enter_autotune + matmul_kernel.post_hook = exit_autotune + with proton.scope("triton"): + if args.cudagraph: + ms = triton.testing.do_bench_cudagraph(lambda: matmul(a, b)) + min_ms = max_ms = ms + else: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + + def perf(ms): + return 2 * M * N * K * 1e-12 / (ms * 1e-3) + + return perf(ms), perf(max_ms), perf(min_ms) + + +if args.profile: + if args.pcsampling: + # proton-viewer -m num_samples/%,time/s ./matmul.hatchet + proton.start("matmul", hook="triton", backend="cupti", mode="pcsampling") + else: + # proton-viewer -m tflop/s,time/s ./matmul.hatchet + proton.start("matmul", hook="triton") + benchmark.run(show_plots=True, print_data=True) + proton.finalize() +else: + benchmark.run(show_plots=True, print_data=True) diff --git a/third_party/mthreads/python/src/gluon_ir.cc b/third_party/mthreads/python/src/gluon_ir.cc new file mode 100644 index 0000000000..c5d76bb609 --- /dev/null +++ b/third_party/mthreads/python/src/gluon_ir.cc @@ -0,0 +1,1189 @@ +#include "ir.h" +#include "pybind11/pybind11.h" +#include + +#include +#include + +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#if TRITON_ENABLE_AMD +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#endif +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/GenericSwizzling.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; +namespace py = pybind11; +namespace tt = triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; +namespace gluon = mlir::triton::gluon; +#if TRITON_ENABLE_AMD +namespace ttag = mlir::triton::amdgpu; +#endif + +static ttg::CGAEncodingAttr +buildCgaLayoutAttr(MLIRContext *ctx, + const std::vector> &layout, + unsigned rank) { + auto kBlock = StringAttr::get(ctx, "block"); + tt::LinearLayout::BasesT bases; + bases[kBlock] = layout; + auto outDims = tt::standardOutDimNames(ctx, rank); + tt::LinearLayout ll(std::move(bases), outDims); + return ttg::CGAEncodingAttr::get(ctx, std::move(ll)); +} + +static std::vector> +getCgaLayoutBases(ttg::CGAEncodingAttr layout) { + std::vector> result; + auto ctx = layout.getContext(); + auto block = StringAttr::get(ctx, "block"); + const auto &basesMap = layout.getLinearLayout().getBases(); + auto it = basesMap.find(block); + assert(it != basesMap.end()); + return it->second; +} + +// Helper to check if an MLIR type or attribute has a verifier method. +template +static constexpr auto hasVerifier(AttrOrType t) + -> decltype(t.verifyInvariants, true) { + return true; +} +static constexpr auto hasVerifier(...) { return false; } + +// Print a diagnostic without its location. The frontend will attach the AST +// location to the error message. +static void printDiagStr(llvm::raw_ostream &os, const Diagnostic &diag) { + for (const DiagnosticArgument &arg : diag.getArguments()) + arg.print(os); + os << "\n"; + for (const Diagnostic ¬e : diag.getNotes()) + printDiagStr(os, note); +} + +struct GluonOpBuilder : public TritonOpBuilder { + using TritonOpBuilder::TritonOpBuilder; + // Construct an attribute or type while calling its verifier. Error messages + // are intercepted and sent back to Python via a C++ exception. + template + std::enable_if_t + getChecked(ArgTs &&...args) { + // Set up a scoped handler to intercept errors. + std::string msg; + llvm::raw_string_ostream os(msg); + ScopedDiagnosticHandler handler( + getContext(), [&](Diagnostic &diag) { printDiagStr(os, diag); }); + + auto result = + AttrOrType::getChecked([&] { return mlir::emitError(getLastLoc()); }, + std::forward(args)...); + if (!result) + throw std::runtime_error(os.str()); + return result; + } + + // A variant of the above due to issues with C++ overload resolution and how + // MLIR sets up the default `getChecked` implementation. + template + std::enable_if_t + getChecked(MLIRContext *ctx, ArgTs &&...args) { + // Set up a scoped handler to intercept errors. + std::string msg; + llvm::raw_string_ostream os(msg); + ScopedDiagnosticHandler handler( + getContext(), [&](Diagnostic &diag) { printDiagStr(os, diag); }); + + if (failed(AttrOrType::verifyInvariants( + [&] { return mlir::emitError(getLastLoc()); }, args...))) + throw std::runtime_error(os.str()); + + return AttrOrType::get(ctx, std::forward(args)...); + } + + // Fallback method for types or attributes that do not have a verifier. + template + std::enable_if_t + getChecked(ArgTs &&...args) { + return AttrOrType::get(std::forward(args)...); + } +}; + +struct GluonLayouts { + py::handle AutoLayout; + py::handle CoalescedLayout; + py::handle BlockedLayout; + py::handle SliceLayout; + py::handle DistributedLinearLayout; + py::handle DotOperandLayout; + py::handle NVMMADistributedLayout; + py::handle TensorMemoryScalesLayout; + py::handle TensorMemoryLayout; + py::handle NVMMASharedLayout; + py::handle SwizzledSharedLayout; + py::handle SharedLinearLayout; + py::handle AMDMFMALayout; + py::handle AMDWMMALayout; + py::handle PaddedSharedLayout; + + GluonLayouts() { + auto layouts = + py::module::import("triton.experimental.gluon.language._layouts"); + auto amdLayouts = + py::module::import("triton.experimental.gluon.language.amd._layouts"); + auto blackwellLayouts = py::module::import( + "triton.experimental.gluon.language.nvidia.blackwell"); + AutoLayout = py::object(layouts.attr("AutoLayout")).release(); + CoalescedLayout = py::object(layouts.attr("CoalescedLayout")).release(); + BlockedLayout = py::object(layouts.attr("BlockedLayout")).release(); + SliceLayout = py::object(layouts.attr("SliceLayout")).release(); + DistributedLinearLayout = + py::object(layouts.attr("DistributedLinearLayout")).release(); + DotOperandLayout = py::object(layouts.attr("DotOperandLayout")).release(); + NVMMADistributedLayout = + py::object(layouts.attr("NVMMADistributedLayout")).release(); + TensorMemoryScalesLayout = + py::object(blackwellLayouts.attr("TensorMemoryScalesLayout")).release(); + TensorMemoryLayout = + py::object(blackwellLayouts.attr("TensorMemoryLayout")).release(); + NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release(); + SwizzledSharedLayout = + py::object(layouts.attr("SwizzledSharedLayout")).release(); + SharedLinearLayout = + py::object(layouts.attr("SharedLinearLayout")).release(); + AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release(); + AMDWMMALayout = py::object(amdLayouts.attr("AMDWMMALayout")).release(); + PaddedSharedLayout = + py::object(layouts.attr("PaddedSharedLayout")).release(); + + auto core = py::module::import("triton.language.core"); + } +}; + +static bool isConvertLayoutTrivial(RankedTensorType dstTy, Value value) { + auto srcTy = cast(value.getType()); + if (srcTy.getEncoding() == dstTy.getEncoding()) + return true; + // Fail safe on unresolved layouts. + if (isa(srcTy.getEncoding())) + return false; + if (isa(dstTy.getEncoding())) + return false; + + // Check concrete layouts. + triton::LinearLayout cvt = minimalCvtLayout(srcTy, dstTy); + auto dims = llvm::to_vector(cvt.getInDimNames()); + return dims.empty() || (dims.size() == 1 && dims.front() == "register"); +} + +template +std::vector> toStdVector(R &&range) { + return {range.begin(), range.end()}; +} + +py::object layoutToGluon(Attribute layout) { + static GluonLayouts layouts; + if (auto blocked = dyn_cast(layout)) { + auto cgaBases = getCgaLayoutBases(blocked.getCGALayout()); + return layouts.BlockedLayout(toStdVector(blocked.getSizePerThread()), + toStdVector(blocked.getThreadsPerWarp()), + toStdVector(blocked.getWarpsPerCTA()), + toStdVector(blocked.getOrder()), cgaBases); + } else if (auto sliced = dyn_cast(layout)) { + return layouts.SliceLayout(sliced.getDim(), + layoutToGluon(sliced.getParent())); + } else if (auto linear = dyn_cast(layout)) { + const auto &ll = linear.getLinearLayout(); + auto ctx = layout.getContext(); + auto kReg = mlir::StringAttr::get(ctx, "register"); + auto kLane = mlir::StringAttr::get(ctx, "lane"); + auto kWarp = mlir::StringAttr::get(ctx, "warp"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + return layouts.DistributedLinearLayout( + ll.getBases().lookup(kReg), ll.getBases().lookup(kLane), + ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock), + toStdVector(ll.getOutDimSizes())); + } else if (auto dotOp = dyn_cast(layout)) { + return layouts.DotOperandLayout( + dotOp.getOpIdx(), layoutToGluon(dotOp.getParent()), dotOp.getKWidth()); + } else if (auto mma = dyn_cast(layout)) { + auto cgaBases = getCgaLayoutBases(mma.getCGALayout()); + return layouts.NVMMADistributedLayout( + std::vector{mma.getVersionMajor(), mma.getVersionMinor()}, + toStdVector(mma.getWarpsPerCTA()), toStdVector(mma.getInstrShape()), + cgaBases); + } else if (auto nvmma = dyn_cast(layout)) { + auto cgaLayout = nvmma.getCGALayout(); + auto cgaBases = getCgaLayoutBases(cgaLayout); + return layouts.NVMMASharedLayout(nvmma.getSwizzlingByteWidth(), + nvmma.getElementBitWidth(), + cgaLayout.getRank(), nvmma.getTransposed(), + nvmma.getFp4Padded(), cgaBases); + } else if (auto swizzled = + dyn_cast(layout)) { + auto cgaBases = getCgaLayoutBases(swizzled.getCGALayout()); + return layouts.SwizzledSharedLayout( + swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(), + toStdVector(swizzled.getOrder()), cgaBases); + } else if (auto sharedLl = dyn_cast(layout)) { + const auto &ll = sharedLl.getLinearLayout(); + auto ctx = layout.getContext(); + auto kOffset = mlir::StringAttr::get(ctx, "offset"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + return layouts.SharedLinearLayout( + toStdVector(ll.getBases().lookup(kOffset)), + toStdVector(ll.getBases().lookup(kBlock)), sharedLl.getAlignment()); + } else if (auto autoEnc = dyn_cast(layout)) { + return layouts.AutoLayout(); + } else if (auto autoEnc = dyn_cast(layout)) { + return layouts.CoalescedLayout(); + } else if (auto amdMfma = dyn_cast(layout)) { + auto cgaBases = getCgaLayoutBases(amdMfma.getCGALayout()); + return layouts.AMDMFMALayout( + amdMfma.getVersion(), toStdVector(amdMfma.getInstrShape()), + amdMfma.getIsTransposed(), toStdVector(amdMfma.getWarpsPerCTA()), + amdMfma.getElementBitWidth(), toStdVector(amdMfma.getTilesPerWarp()), + cgaBases); + } else if (auto amdWmma = dyn_cast(layout)) { + auto cgaBases = getCgaLayoutBases(amdWmma.getCGALayout()); + const auto &ctaLayout = amdWmma.getCtaLayout(); + auto ctx = layout.getContext(); + auto kReg = mlir::StringAttr::get(ctx, "register"); + auto kWarp = mlir::StringAttr::get(ctx, "warp"); + return layouts.AMDWMMALayout( + amdWmma.getVersion(), amdWmma.getIsTransposed(), + ctaLayout.getBases().lookup(kWarp), ctaLayout.getBases().lookup(kReg), + toStdVector(amdWmma.getInstrShape()), cgaBases, amdWmma.getRank()); + } else if (auto paddedShared = + dyn_cast(layout)) { + auto *ctx = paddedShared.getContext(); + std::vector> intervalPaddingPairs; + for (auto [interval, padding] : + llvm::zip(paddedShared.getIntervals(), paddedShared.getPaddings())) { + intervalPaddingPairs.push_back({interval, padding}); + } + auto kOffset = mlir::StringAttr::get(ctx, "offset"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + const auto &ll = paddedShared.getLinearComponent(); + auto shape = toStdVector(ll.getOutDimSizes()); + return layouts.PaddedSharedLayout(intervalPaddingPairs, + ll.getBases().lookup(kOffset), + ll.getBases().lookup(kBlock), shape); + } else if (auto tmemScales = + dyn_cast(layout)) { + return layouts.TensorMemoryScalesLayout(std::vector{ + tmemScales.getCTASplitM(), tmemScales.getCTASplitN()}); + } else if (auto tmem = dyn_cast(layout)) { + return layouts.TensorMemoryLayout( + std::vector{tmem.getBlockM(), tmem.getBlockN()}, + tmem.getColStride(), + std::vector{tmem.getCTASplitM(), tmem.getCTASplitN()}); + } + + throw py::value_error("Unhandled encoding encountered"); +} + +template static void check(CondT &&cond, const char *msg) { + if (!std::forward(cond)) + throw py::value_error(msg); +} + +void init_gluon_ir(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "TMEM_LOAD_REDUCE_MODIFIER", + py::module_local()) + .value("MIN", ttng::TMEMLoadReduceModifier::MIN) + .value("MAX", ttng::TMEMLoadReduceModifier::MAX) + .export_values(); + + py::class_( + m, "GluonOpBuilder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + .def("get_op_builder", &GluonOpBuilder::getBuilder, ret::reference) + .def("get_distributed_ty", + [](GluonOpBuilder &self, Type &elementType, + std::vector &shape, Attribute layout) -> Type { + return self.getChecked(shape, elementType, + layout); + }) + .def("get_shared_mem_desc_ty", + [](GluonOpBuilder &self, Type &elementType, + std::vector &shape, Attribute layout, + std::vector &allocShape) -> Type { + auto ctx = self.getContext(); + return self.getChecked( + shape, elementType, layout, + ttg::SharedMemorySpaceAttr::get(ctx), + /*mutableMemory=*/true, + /*allocShape=*/allocShape); + }) + .def("get_tensor_mem_desc_ty", + [](GluonOpBuilder &self, Type &elementType, + std::vector &shape, Attribute layout, + std::vector &allocShape) -> Type { + auto ctx = self.getContext(); + return self.getChecked( + shape, elementType, layout, + ttng::TensorMemorySpaceAttr::get(ctx), + /*mutableMemory=*/true, + /*allocShape=*/allocShape); + }) + .def("get_blocked_layout", + [](GluonOpBuilder &self, std::vector &sizePerThread, + std::vector &threadsPerWarp, + std::vector &warpsPerCta, std::vector &order, + std::vector> &cgaBases) -> Attribute { + auto ctx = self.getContext(); + unsigned rank = order.size(); + auto cgaLayout = buildCgaLayoutAttr(ctx, cgaBases, rank); + return self.getChecked( + ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, + cgaLayout); + }) + .def("get_slice_layout", + [](GluonOpBuilder &self, unsigned dim, + Attribute parent) -> Attribute { + auto ctx = self.getContext(); + auto dist = cast(parent); + return self.getChecked(ctx, dim, dist); + }) + .def("get_distributed_linear_layout", + [](GluonOpBuilder &self, std::vector> regBases, + std::vector> laneBases, + std::vector> warpBases, + std::vector> blockBases, + std::vector shape) -> Attribute { + auto ctx = self.getContext(); + auto kReg = mlir::StringAttr::get(ctx, "register"); + auto kLane = mlir::StringAttr::get(ctx, "lane"); + auto kWarp = mlir::StringAttr::get(ctx, "warp"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + auto outDims = tt::standardOutDimPairs(ctx, shape); + auto ll = tt::LinearLayout({{kReg, regBases}, + {kLane, laneBases}, + {kWarp, warpBases}, + {kBlock, blockBases}}, + outDims, + /*requiresSurjective=*/true); + return ttg::LinearEncodingAttr::get(ctx, std::move(ll)); + }) + .def("to_linear_layout", + [](GluonOpBuilder &self, Attribute layout, + std::vector &shape) -> py::object { + auto ctx = self.getContext(); + auto linearLayout = ttg::toLinearLayout(shape, layout); + + if (isa(layout)) { + auto attr = + ttg::LinearEncodingAttr::get(ctx, std::move(linearLayout)); + return layoutToGluon(attr); + } + if (isa(layout)) { + auto alignment = + cast(layout).getAlignment(); + auto attr = ttg::SharedLinearEncodingAttr::get( + ctx, std::move(linearLayout), alignment); + return layoutToGluon(attr); + } + + // TensorMemory encodings: keep the LinearLayout but wrap as + // print-only Python object carrying row/col bases -> dim0/dim1. + auto inNamesRange = linearLayout.getInDimNames(); + auto inNames = llvm::to_vector(inNamesRange); + bool isTmemLayout = + (inNames.size() == 2 && inNames[0].str() == "row" && + inNames[1].str() == "col"); + if (!isTmemLayout) + throw std::invalid_argument( + "Unsupported layout in to_linear_layout"); + + // Build Py _TensorMemoryLinearLayout(row_bases, col_bases, shape, + // repr) + py::object tmemCls = + py::module::import( + "triton.experimental.gluon.language.nvidia.blackwell") + .attr("_TensorMemoryLinearLayout"); + auto bases = linearLayout.getBases(); + auto rowBases = bases[mlir::StringAttr::get(ctx, "row")]; + auto colBases = bases[mlir::StringAttr::get(ctx, "col")]; + auto outDims = linearLayout.getOutDims(); + std::vector shapeVec; + for (auto &od : outDims) + shapeVec.push_back(od.second); + + py::object pyObj = tmemCls(py::cast(rowBases), py::cast(colBases), + py::cast(shapeVec)); + return pyObj; + }) + .def("get_dot_operand_layout", + [](GluonOpBuilder &self, unsigned opIdx, Attribute parent, + unsigned kWidth) -> Attribute { + return self.getChecked( + self.getContext(), opIdx, parent, kWidth); + }) + .def("get_mma_layout", + [](GluonOpBuilder &self, std::vector &version, + std::vector &warpsPerCta, + std::vector> &cgaBases, + std::vector &instrShape) -> Attribute { + auto ctx = self.getContext(); + unsigned rank = warpsPerCta.size(); + auto cgaLayout = buildCgaLayoutAttr(ctx, cgaBases, rank); + return self.getChecked( + ctx, version[0], version[1], warpsPerCta, cgaLayout, + instrShape); + }) + .def("get_amd_mfma_layout", + [](GluonOpBuilder &self, unsigned version, + std::vector &warpsPerCta, + std::vector &instrShape, bool transposed, + std::vector> &cgaBases, + std::vector &tilesPerWarp, + unsigned elementBitWidth) -> Attribute { + auto ctx = self.getContext(); + unsigned rank = warpsPerCta.size(); + auto cgaLayout = buildCgaLayoutAttr(ctx, cgaBases, rank); + return ttg::AMDMfmaEncodingAttr::get( + ctx, version, warpsPerCta, instrShape, transposed, cgaLayout, + tilesPerWarp, elementBitWidth); + }) + .def("get_amd_wmma_layout", + [](GluonOpBuilder &self, unsigned version, bool transposed, + std::vector> &warpBases, + std::vector> ®Bases, + std::vector> &cgaBases, + std::vector &instrShape, unsigned rank) -> Attribute { + auto ctx = self.getContext(); + auto kReg = mlir::StringAttr::get(ctx, "register"); + auto kWarp = mlir::StringAttr::get(ctx, "warp"); + auto ctaLayout = + tt::LinearLayout({{kReg, regBases}, {kWarp, warpBases}}, + tt::standardOutDimNames(ctx, rank)); + auto cgaLayout = buildCgaLayoutAttr(ctx, cgaBases, rank); + return ttg::AMDWmmaEncodingAttr::get( + ctx, version, ctaLayout, transposed, cgaLayout, instrShape); + }) + .def("get_padded_shared_layout", + [](GluonOpBuilder &self, std::vector &intervals, + std::vector &paddings, + std::vector> &offsetBases, + std::vector> &blockBases, + std::vector &shape) -> Attribute { + auto ctx = self.getContext(); + auto rank = shape.size(); + auto kOffset = mlir::StringAttr::get(ctx, "offset"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + auto ll = tt::LinearLayout( + {{kOffset, offsetBases}, {kBlock, blockBases}}, + tt::standardOutDimNames(ctx, rank)); + return ttg::PaddedSharedEncodingAttr::get(ctx, intervals, paddings, + std::move(ll)); + }) + .def("get_shared_linear_layout", + [](GluonOpBuilder &self, std::vector> &offsetBases, + std::vector> &blockBases, + unsigned alignment) -> Attribute { + auto ctx = self.getContext(); + auto kOffset = mlir::StringAttr::get(ctx, "offset"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + auto outDims = tt::standardOutDimNames(ctx, offsetBases[0].size()); + auto ll = tt::LinearLayout( + {{kOffset, offsetBases}, {kBlock, blockBases}}, outDims); + return self.getChecked( + ctx, std::move(ll), alignment); + }) + .def("get_nvmma_shared_layout", + [](GluonOpBuilder &self, unsigned swizzleByteWidth, + unsigned elementBitwidth, bool transposed, bool fp4Padded, + std::vector> &cgaBases, + unsigned rank) -> Attribute { + auto ctx = self.getContext(); + auto cgaLayout = buildCgaLayoutAttr(ctx, cgaBases, rank); + return self.getChecked( + ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded, + cgaLayout); + }) + .def("get_auto_layout", + [](GluonOpBuilder &self) -> Attribute { + return self.getChecked(self.getContext()); + }) + .def("get_coalesced_layout", + [](GluonOpBuilder &self) -> Attribute { + return self.getChecked( + self.getContext()); + }) + .def("get_swizzled_shared_layout", + [](GluonOpBuilder &self, int vec, int perPhase, int maxPhase, + std::vector &order, + std::vector> &cgaBases) -> Attribute { + auto ctx = self.getContext(); + unsigned rank = order.size(); + auto cgaLayout = buildCgaLayoutAttr(ctx, cgaBases, rank); + return self.getChecked( + ctx, vec, perPhase, maxPhase, order, cgaLayout); + }) + .def("get_tensor_memory_layout", + [](GluonOpBuilder &self, std::vector &block, + unsigned colStride, std::vector &ctaSplitNum, + bool twoCTAs) -> Attribute { + auto ctx = self.getContext(); + check(block.size() == 2, "expected a 2D block"); + check(ctaSplitNum.size() == 2, "expected 2D CTA dimensions"); + return self.getChecked( + ctx, block[0], block[1], colStride, ctaSplitNum[0], + ctaSplitNum[1], twoCTAs); + }) + .def("get_tensor_memory_scales_layout", + [](GluonOpBuilder &self, + std::vector &ctaSplitNum) -> Attribute { + auto ctx = self.getContext(); + check(ctaSplitNum.size() == 2, "expected 2D CTA dimensions"); + return self.getChecked( + ctx, ctaSplitNum[0], ctaSplitNum[1]); + }) + .def("get_shape_from_tensor", + [](GluonOpBuilder &self, Value tensor) -> std::vector { + auto ty = dyn_cast(tensor.getType()); + return ty.getShape(); + }) + .def("get_gluon_layout_from_tensor", + [](GluonOpBuilder &self, Value tensor) -> py::object { + auto ty = dyn_cast(tensor.getType()); + check(ty.getEncoding(), "expected a tensor with an encoding"); + return layoutToGluon(ty.getEncoding()); + }) + .def("get_gluon_layout_from_memdesc", + [](GluonOpBuilder &self, Value memdesc) -> py::object { + auto ty = dyn_cast(memdesc.getType()); + check(ty.getEncoding(), "expected a memdesc with an encoding"); + return layoutToGluon(ty.getEncoding()); + }) + .def("get_tensor_descriptor_layout_type", + [](GluonOpBuilder &self, Type blockType, bool isSigned, + Attribute layout) -> Type { + auto ctx = self.getContext(); + auto blockTy = cast(blockType); + auto blockTyLayout = blockTy.cloneWithEncoding(layout); + return triton::TensorDescType::get(ctx, blockTyLayout, isSigned); + }) + .def("is_convert_layout_trivial", + [](GluonOpBuilder &self, Type resultTy, Value value) -> bool { + auto dstTy = cast(resultTy); + return isConvertLayoutTrivial(dstTy, value); + }) + .def("create_histogram", + [](GluonOpBuilder &self, Value operand, int numBins, + std::optional mask, Attribute layout) -> Value { + auto *ctx = self.getContext(); + auto resultTy = + RankedTensorType::get({static_cast(numBins)}, + IntegerType::get(ctx, 32), layout); + if (!mask) { + return self.create(resultTy, operand); + } else { + return self.create(resultTy, operand, + *mask); + } + }) + .def("create_cat", + [](GluonOpBuilder &self, Value &lhs, Value &rhs, + Type retType) -> Value { + return self.create(retType, lhs, rhs); + }) + .def("create_fp4_to_fp", + [](GluonOpBuilder &self, Value src, Type elemType, + int axis) -> Value { + return self.create( + cast>(src), elemType, axis); + }) + .def("create_async_copy_global_to_local", + [](GluonOpBuilder &self, Value smem, Value pointer, Value mask, + Value other, tt::CacheModifier cacheModifier, + tt::EvictionPolicy evictionPolicy, bool isVolatile) { + self.create( + pointer, smem, mask, other, cacheModifier, evictionPolicy, + isVolatile); + }) +#if TRITON_ENABLE_AMD + .def("create_async_copy_local_to_global", + [](GluonOpBuilder &self, Value smem, Value pointer, Value mask, + tt::CacheModifier cacheModifier, + tt::EvictionPolicy evictionPolicy) { + self.create( + smem, pointer, mask, cacheModifier, evictionPolicy); + }) +#endif + .def("create_async_copy_mbarrier_arrive", + [](GluonOpBuilder &self, Value mbarrier, bool incrementCount) { + self.create(mbarrier, + !incrementCount); + }) + .def("create_async_commit_group", + [](GluonOpBuilder &self) { + ValueRange tokens; + self.create(tokens); + }) + .def("create_async_wait_group", + [](GluonOpBuilder &self, int num) { + ValueRange tokens; + self.create(tokens, num); + }) + .def("create_convert_layout", + [](GluonOpBuilder &self, Type resultTy, Value value) -> Value { + return self.create(resultTy, value); + }) + .def("create_local_alloc", + [](GluonOpBuilder &self, Type resultTy) -> Value { + return self.create(resultTy); + }) + .def("create_local_alloc", + [](GluonOpBuilder &self, Type resultTy, Value value) -> Value { + return self.create(resultTy, value); + }) + .def("create_local_store", + [](GluonOpBuilder &self, Value memDesc, Value value) { + self.create(value, memDesc); + }) + .def("create_local_load", + [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value { + return self.create(resultTy, memDesc); + }) + .def("create_local_gather", + [](GluonOpBuilder &self, Type resultTy, Value memDesc, Value indices, + int32_t axis) -> Value { + auto ctx = self.getContext(); + auto i32Ty = IntegerType::get(ctx, 32); + auto axisAttr = IntegerAttr::get(i32Ty, axis); + return self.create(resultTy, memDesc, indices, + axisAttr); + }) + .def("create_local_scatter", + [](GluonOpBuilder &self, Value memDesc, Value values, Value indices, + int32_t axis) { + auto ctx = self.getContext(); + auto i32Ty = IntegerType::get(ctx, 32); + auto axisAttr = IntegerAttr::get(i32Ty, axis); + self.create(memDesc, values, indices, + axisAttr); + }) + .def("get_shared_bank_conflicts", + [](GluonOpBuilder &self, Attribute regLayoutAttr, + Attribute sharedLayoutAttr, std::vector &shape, + int bitwidth) -> int { + auto regLayout = ttg::toLinearLayout(shape, regLayoutAttr); + auto smemLayout = ttg::toLinearLayout(shape, sharedLayoutAttr); + return ttg::bankConflictsMemDesc(regLayout, smemLayout, bitwidth); + }) + .def("create_local_dealloc", + [](GluonOpBuilder &self, Value memDesc) -> Operation * { + return self.create(memDesc); + }) + + .def("create_memdesc_index", + [](GluonOpBuilder &self, Type resultType, Value src, + Value index) -> Value { + return self.create(resultType, src, index); + }) + .def("create_memdesc_subslice", + [](GluonOpBuilder &self, Type resultType, Value src, + std::vector &offsets) -> Value { + return self.create(resultType, src, + offsets); + }) + .def("create_memdesc_trans", + [](GluonOpBuilder &self, Value src, + std::vector &order) -> Value { + return self.create(src, order); + }) + .def("create_memdesc_reshape", + [](GluonOpBuilder &self, Value src, + std::vector &shape) -> Value { + return self.create(src, shape); + }) + .def("create_memdesc_reinterpret", + [](GluonOpBuilder &self, Type resultType, Value src) -> Value { + return self.create(resultType, src); + }) + .def("create_set_auto_layout", + [](GluonOpBuilder &self, Attribute layout, Value value) -> Value { + return self.create(layout, value); + }) + .def("create_split", + [](GluonOpBuilder &self, Value &a) -> py::tuple { + auto argTy = cast(a.getType()); + auto ctx = argTy.getContext(); + auto enc = ttg::SliceEncodingAttr::get( + ctx, argTy.getRank() - 1, + cast(argTy.getEncoding())); + auto resTy = + RankedTensorType::get(ArrayRef(argTy.getShape()).drop_back(), + argTy.getElementType(), enc); + auto op = self.create(TypeRange{resTy, resTy}, a); + return py::make_tuple(op->getResult(0), op->getResult(1)); + }) + .def("create_warpgroup_mma", + [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc, + triton::InputPrecision precision = triton::InputPrecision::IEEE, + int maxNumImpreciseAcc = 0, bool isAsync = false) -> Value { + return self.create( + a, b, acc, useAcc, precision, maxNumImpreciseAcc, isAsync); + }) + .def("create_warpgroup_mma_wait", + [](GluonOpBuilder &self, std::vector &deps, int pendings) { + std::vector results; + auto wait = self.create(deps, pendings); + llvm::append_range(results, wait.getResults()); + return results; + }) + .def("create_tmem_alloc", + [](GluonOpBuilder &self, Type resultTy, Value value) -> Value { + return self.create(resultTy, value); + }) + .def("create_tmem_alloc", + [](GluonOpBuilder &self, Type resultTy, py::none value) -> Value { + return self.create(resultTy, Value{}); + }) + .def("create_tmem_store", + [](GluonOpBuilder &self, Value memDesc, Value value, Value pred) { + self.create(memDesc, value, pred); + }) + .def( + "create_tmem_load", + [](GluonOpBuilder &self, Type resultTy, Value memDesc, + std::optional redOp, bool useAbs, + tt::PropagateNan propagateNan) -> py::object { + ttng::TMEMLoadReduceModifierAttr redOpAttr = nullptr; + BoolAttr absAttr = nullptr; + BoolAttr nanAttr = nullptr; + + if (redOp) { + redOpAttr = ttng::TMEMLoadReduceModifierAttr::get( + self.getContext(), redOp.value()); + if (useAbs) + absAttr = self.getBuilder().getBoolAttr(true); + if (propagateNan != tt::PropagateNan::NONE) + nanAttr = self.getBuilder().getBoolAttr(true); + } + + auto op = self.create( + resultTy, /*token=*/Type(), memDesc, /*dep=*/Value(), redOpAttr, + absAttr, nanAttr); + + if (redOp) { + Value result = op.getResult(); + Value red = op.getRed(); + auto redTy = cast(red.getType()); + py::object redLayout = layoutToGluon(redTy.getEncoding()); + return py::make_tuple(result, red, redLayout); + } + Value result = op.getResult(); + return py::cast(result); + }, + py::arg("resultTy"), py::arg("memDesc"), + py::arg("redOp") = py::none(), py::arg("useAbs") = false, + py::arg("propagateNan") = tt::PropagateNan::NONE) + .def("create_tmem_copy", + [](GluonOpBuilder &self, Value src, Value dst) { + self.create(src, dst, /*barrier=*/Value()); + }) + .def("create_tmem_subslice", + [](GluonOpBuilder &self, Type resultTy, Value memDesc, + int N) -> Value { + return self.create(resultTy, memDesc, N); + }) + .def("create_mbarrier_init", + [](GluonOpBuilder &self, Value memDesc, int count) { + self.create(memDesc, count); + }) + .def("create_mbarrier_inval", + [](GluonOpBuilder &self, Value memDesc) { + self.create(memDesc); + }) + .def("create_mbarrier_expect", + [](GluonOpBuilder &self, Value memDesc, int bytes, Value pred) { + self.create(memDesc, bytes, pred); + }) + .def("create_mbarrier_wait", + [](GluonOpBuilder &self, Value memDesc, Value phase, Value pred, + std::vector &deps) { + self.create(memDesc, phase, pred, deps); + }) + .def("create_mbarrier_arrive", + [](GluonOpBuilder &self, Value memDesc, int count, Value pred) { + self.create(memDesc, count, pred); + }) + .def("create_fence_mbarrier_init_release_cluster", + [](GluonOpBuilder &self) { + self.create(); + }) + .def("create_cluster_arrive", + [](GluonOpBuilder &self, bool relaxed) { + self.create(relaxed); + }) + .def("create_cluster_wait", + [](GluonOpBuilder &self) { self.create(); }) + .def("create_tcgen05_mma", + [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc, + Value pred, std::vector &mbarriers, + std::vector &mbarrier_preds, bool two_ctas, + bool multicast) { + Value accDep; + auto tokType = self.getBuilder().getType(); + self.create(tokType, a, b, acc, accDep, useAcc, + pred, two_ctas, multicast, + mbarriers, mbarrier_preds); + }) + .def("create_tcgen05_mma_scaled", + [](GluonOpBuilder &self, Value a, Value b, Value acc, Value aScale, + Value bScale, tt::ScaleDotElemType aType, + tt::ScaleDotElemType bType, Value useAcc, Value pred, + std::vector &mbarriers, + std::vector &mbarrier_preds) { + Value accDep; + auto tokType = self.getBuilder().getType(); + self.create( + tokType, a, b, acc, accDep, aScale, bScale, aType, bType, + useAcc, pred, mbarriers, mbarrier_preds); + }) + .def("create_tcgen05_commit", + [](GluonOpBuilder &self, Value &barrier, Value &pred, + std::vector &descs) { + self.create(barrier, pred, descs); + }) + + .def("create_async_tma_copy_global_to_local", + [](GluonOpBuilder &self, Value descPtr, std::vector &coord, + Value barrier, Value result, Value pred, bool multicast) { + self.create( + descPtr, coord, barrier, result, pred, multicast); + }) + .def("create_async_tma_copy_local_to_global", + [](GluonOpBuilder &self, Value descPtr, std::vector &coord, + Value src) { + self.create(descPtr, coord, + src); + }) + .def("create_async_tma_reduce", + [](GluonOpBuilder &self, triton::DescriptorReduceKind kind, + Value descPtr, std::vector &coord, Value src) { + self.create(kind, descPtr, coord, src); + }) + .def("create_async_tma_store_wait", + [](GluonOpBuilder &self, int pendings) { + self.create(pendings); + }) + .def("create_async_tma_gather", + [](GluonOpBuilder &self, Value descPtr, Value xOffsets, + Value yOffset, Value barrier, Value result, Value pred) { + self.create(descPtr, xOffsets, yOffset, + barrier, result, pred); + }) + .def("create_async_tma_scatter", + [](GluonOpBuilder &self, Value descPtr, Value xOffsets, + Value yOffset, Value src) { + self.create(descPtr, xOffsets, yOffset, + src); + }) + .def("create_fence_async_shared", + [](GluonOpBuilder &self, bool bCluster) -> OpState { + return self.create(bCluster); + }) + .def("create_cluster_sync", + [](GluonOpBuilder &self) { + self.create(/*relaxed=*/false); + self.create(); + }) + + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, Type retTy) -> Value { + return self.create(retTy, arg); + }) + .def("create_warp_return", + [](GluonOpBuilder &self) -> Operation * { + return self.create(); + }) + .def("create_warp_yield", + [](GluonOpBuilder &self, std::vector &values) -> Operation * { + return self.create(values); + }) + .def("create_warp_specialize_partitions", + [](GluonOpBuilder &self, std::vector &explicitCaptures, + int numPartitions) -> Operation * { + return self.create( + explicitCaptures, numPartitions); + }) + .def("create_warp_specialize", + [](GluonOpBuilder &self, std::vector &resultTypes, + std::vector &partitionNumWarps) { + return self.create(resultTypes, + partitionNumWarps); + }) +#if TRITON_ENABLE_AMD + .def("create_buffer_load", + [](GluonOpBuilder &self, Type resultType, Value ptr, Value offsets, + Value mask, Value other, tt::CacheModifier cache) -> Value { + return self.create(resultType, ptr, offsets, + Value() /*stride*/, cache, + mask, other); + }) + .def("create_buffer_store", + [](GluonOpBuilder &self, Value storedValue, Value ptr, Value offsets, + Value mask, tt::CacheModifier cache) { + self.create(storedValue, ptr, offsets, + Value() /*stride*/, cache, mask); + }) + .def("create_buffer_atomic_rmw", + [](GluonOpBuilder &self, tt::RMWOp op, Value ptr, Value offsets, + Value value, tt::MemSemantic sem, tt::MemSyncScope scope, + Value mask) -> Value { + return self.create( + value.getType(), op, ptr, offsets, value, Value() /*stride*/, + sem, scope, mask); + }) + .def("create_buffer_load_to_local", + [](GluonOpBuilder &self, Value dest, Value ptr, Value offsets, + Value mask, Value other, Value stride, + tt::CacheModifier cacheModifier) { + self.create( + dest, ptr, offsets, mask, other, stride, cacheModifier); + }) + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Type resultTy, Value &base, + std::vector &shape, std::vector &strides, + tt::PaddingOption paddingOption) -> Value { + return self.create(resultTy, base, shape, + strides, paddingOption); + }) + .def("create_async_tdm_copy_global_to_local", + [](GluonOpBuilder &self, Value descPtr, std::vector &indices, + Value result, Value pred, Value barrier) { + self.create( + descPtr, indices, result, pred, barrier); + }) + .def("create_async_tdm_copy_local_to_global", + [](GluonOpBuilder &self, Value descPtr, std::vector &indices, + Value src, Value barrier) { + self.create(descPtr, indices, + src, barrier); + }) + .def("create_async_tdm_scatter", + [](GluonOpBuilder &self, Value descPtr, Value dstRowIndices, + Value dstColOffset, Value src, Value barrier) { + self.create(descPtr, dstRowIndices, + dstColOffset, src, barrier); + }) + .def("create_tdm_prefetch", + [](GluonOpBuilder &self, Value descPtr, std::vector &indices, + Value pred, bool speculative, bool returnOffsets) -> Value { + auto op = self.create( + descPtr, indices, pred, speculative, + returnOffsets ? UnitAttr::get(self.getContext()) : nullptr); + return returnOffsets ? op->getResult(0) : nullptr; + }) + .def("create_async_tdm_wait", + [](GluonOpBuilder &self, int num) { + ValueRange tokens; + self.create(tokens, num); + }) + .def("create_async_copy_lds_barrier_arrive", + [](GluonOpBuilder &self, Value mbarrier) { + self.create(mbarrier); + }) + .def("create_lds_barrier_init", + [](GluonOpBuilder &self, Value memDesc, int count) { + self.create(memDesc, count); + }) + .def("create_lds_barrier_wait", + [](GluonOpBuilder &self, Value memDesc, Value phase) { + self.create(memDesc, phase); + }) + .def("create_lds_barrier_arrive", + [](GluonOpBuilder &self, Value memDesc, int count) -> Value { + return self.create(memDesc, count); + }) + .def("create_amd_cluster_arrive", + [](GluonOpBuilder &self) { + self.create(); + }) + .def("create_amd_cluster_wait", + [](GluonOpBuilder &self) { + self.create(); + }) +#endif + .def("create_warp_pipeline_border", + [](GluonOpBuilder &self, const std::string &marker) { + auto border = self.create(0); + auto ctx = self.getContext(); + border->setAttr("triton.warp_pipeline.border", + StringAttr::get(ctx, marker)); + }); + + m.def( + "compute_tmem_reg_layout", + [](py::object elementTyObj, std::vector shape, + py::object layoutObj, unsigned numWarps, const std::string &atomName, + std::vector> cgaBases) -> py::object { + DialectRegistry registry; + registry.insert(); + MLIRContext context(MLIRContext::Threading::DISABLED); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + + GluonOpBuilder builder(&context); + auto builderObj = + py::cast(&builder, py::return_value_policy::reference); + + auto elementType = elementTyObj.attr("to_ir")(builderObj).cast(); + auto layoutAttr = + layoutObj.attr("_to_ir")(builderObj).cast(); + auto allocShape = shape; + + auto ctx = builder.getContext(); + unsigned rank = shape.size(); + auto memDescTy = builder.getChecked( + shape, elementType, layoutAttr, + ttng::TensorMemorySpaceAttr::get(ctx), + /*mutableMemory=*/true, allocShape); + auto ctaLayoutAttr = buildCgaLayoutAttr(ctx, cgaBases, rank); + + auto maybeAtom = + llvm::StringSwitch>(atomName) + .Case("32x32b", ttng::TMemAccessAtom::I32x32b) + .Case("16x64b", ttng::TMemAccessAtom::I16x64b) + .Case("16x128b", ttng::TMemAccessAtom::I16x128b) + .Case("16x256b", ttng::TMemAccessAtom::I16x256b) + .Case("16x32bx2", ttng::TMemAccessAtom::I16x32bx2) + .Default(std::nullopt); + if (!maybeAtom) + throw std::invalid_argument("unknown TMEM access atom: " + atomName); + auto atom = *maybeAtom; + if (atom == ttng::TMemAccessAtom::I16x32bx2) + throw std::invalid_argument( + "Atom 16x32bx2 is inferred implicitly and cannot be requested " + "explicitly"); + if (numWarps < 4 || !llvm::isPowerOf2_32(numWarps)) + throw std::invalid_argument( + "numWarps must be a power of two and >= 4"); + + auto layout = ttng::getDistributedLayoutForTmemLdSt( + memDescTy, atom, numWarps, ctaLayoutAttr); + if (!layout) + return py::none(); + + auto attr = ttg::LinearEncodingAttr::get(ctx, std::move(*layout)); + return layoutToGluon(attr); + }); + + m.def( + "make_cga_layout", + [](std::vector ctasPerCga, std::vector ctaSplitNum, + std::vector ctaOrder) -> std::vector> { + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(MLIRContext::Threading::DISABLED); + ctx.appendDialectRegistry(registry); + ctx.loadAllAvailableDialects(); + auto attr = ttg::CGAEncodingAttr::fromSplitParams( + &ctx, ctasPerCga, ctaSplitNum, ctaOrder); + return getCgaLayoutBases(attr); + }); + + m.def("get_amd_mfma_scale_layout", + [](unsigned opIdx, std::vector &shape, unsigned mfmaMDim, + std::vector &tilesPerWarp, + std::vector &warpsPerCTA) -> py::object { + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(MLIRContext::Threading::DISABLED); + ctx.appendDialectRegistry(registry); + ctx.loadAllAvailableDialects(); + + auto ll = ttg::chooseScaledMfmaScaleLayout( + &ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA); + auto attr = ttg::LinearEncodingAttr::get(&ctx, std::move(ll)); + return layoutToGluon(attr); + }); + + m.def("get_amd_wmma_scale_layout", + [](unsigned opIdx, std::vector &shape, unsigned wmmaMDim, + std::vector> ®Bases, + std::vector> &warpBases) -> py::object { + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(MLIRContext::Threading::DISABLED); + ctx.appendDialectRegistry(registry); + ctx.loadAllAvailableDialects(); + + auto rank = shape.size(); + auto kReg = mlir::StringAttr::get(&ctx, "register"); + auto kWarp = mlir::StringAttr::get(&ctx, "warp"); + auto ctaLayout = + tt::LinearLayout({{kReg, regBases}, {kWarp, warpBases}}, + tt::standardOutDimNames(&ctx, rank)); + auto ll = ttg::chooseScaledWmmaScaleLayout(&ctx, opIdx, shape, + wmmaMDim, ctaLayout); + auto attr = ttg::LinearEncodingAttr::get(&ctx, ll); + return layoutToGluon(attr); + }); + + m.def("get_layout_view", + [](py::object layout, std::vector shape, + bool useHwView) -> std::string { + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(MLIRContext::Threading::DISABLED); + ctx.appendDialectRegistry(registry); + ctx.loadAllAvailableDialects(); + + GluonOpBuilder builder(&ctx); + auto builderObj = + py::cast(&builder, py::return_value_policy::reference); + Attribute attr = layout.attr("_to_ir")(builderObj).cast(); + + if (isa(attr)) + throw py::value_error("AutoLayout cannot be visualized"); + if (isa(attr)) + throw py::value_error("CoalescedLayout cannot be visualized"); + if (isa(attr)) + throw py::value_error("PaddedSharedLayout cannot be visualized: " + "toLinearLayout not implemented"); + + auto ll = ttg::toLinearLayout(shape, attr); + if (isa(attr)) { + return ttg::getDistributedLayoutStr(ll, useHwView); + } else { + return ttg::getSharedLayoutStr(ll, useHwView); + } + }); + + py::class_(m, "WarpSpecializeOp", + py::module_local()) + .def("get_default_region", &ttg::WarpSpecializeOp::getDefaultRegion, + ret::reference) + .def("get_partition_op_holder", + &ttg::WarpSpecializeOp::getPartitionOpHolder, ret::reference) + .def("set_requested_registers", [](ttg::WarpSpecializeOp &self, + std::vector &requestedRegisters) { + self.setRequestedRegisters(requestedRegisters); + }); +} diff --git a/third_party/mthreads/python/src/interpreter.cc b/third_party/mthreads/python/src/interpreter.cc new file mode 100644 index 0000000000..747a0cc171 --- /dev/null +++ b/third_party/mthreads/python/src/interpreter.cc @@ -0,0 +1,740 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +struct npy_half { + uint16_t value; +}; + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +std::mutex atomic_op_guard; + +template +constexpr bool is_reinterpret_cast_to_atomic_safe = + std::is_trivially_copyable_v && + std::is_trivially_copyable_v> && + std::is_standard_layout_v && std::is_standard_layout_v> && + sizeof(T) == sizeof(std::atomic) && + alignof(T) == alignof(std::atomic); + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel}, + {MemSemantic::ACQUIRE, std::memory_order_acquire}, + {MemSemantic::RELEASE, std::memory_order_release}, + {MemSemantic::RELAXED, std::memory_order_relaxed}, +}; + +template +T atomic_cmp(T *ptr, T val, std::memory_order order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + + T old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_ptr = reinterpret_cast *>(ptr); + old_val = atomic_ptr->load(order); + while (cmp(old_val, val)) { + if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) { + break; + } + } + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *ptr; + if (cmp(old_val, val)) { + *ptr = val; + } + } + return old_val; +} + +template T atomic_fadd(T *loc, T value, std::memory_order order) { + static_assert(std::is_floating_point::value, + "T must be a floating-point type"); + T old_value; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + T new_value; + std::atomic *atomic_loc = reinterpret_cast *>(loc); + old_value = atomic_loc->load(order); + do { + new_value = old_value + value; + } while ( + !atomic_loc->compare_exchange_weak(old_value, new_value, order, order)); + } else { + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = old_value + value; + } + + return old_value; +} + +/** Create a value of type `To` from the bits of `from`. + * + * similar to `std::bit_cast` but compatible with C++17, + * should perform similar to `*reinterpret_cast(&from)` + * or through punning without expecting any undefined behaviors. + * + * Note: taken from + * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32 + * with simplification. + */ +template +inline To BitCast(const From &from) noexcept { + static_assert(sizeof(To) == sizeof(From), + "both data types must have the same size"); + + static_assert(std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + "both data types must be trivially copyable"); + + To to; + memcpy(&to, &from, sizeof(from)); + return to; +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14 +template +inline uint16_t FromFloatBits(uint32_t f) { + uint32_t f_exp, f_sig; + uint16_t h_sgn, h_exp, h_sig; + + h_sgn = (uint16_t)((f & 0x80000000u) >> 16); + f_exp = (f & 0x7f800000u); + + /* Exponent overflow/NaN converts to signed inf/NaN */ + if (f_exp >= 0x47800000u) { + if (f_exp == 0x7f800000u) { + /* Inf or NaN */ + f_sig = (f & 0x007fffffu); + if (f_sig != 0) { + /* NaN - propagate the flag in the significand... */ + uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13)); + /* ...but make sure it stays a NaN */ + if (ret == 0x7c00u) { + ret++; + } + return h_sgn + ret; + } else { + /* signed inf */ + return (uint16_t)(h_sgn + 0x7c00u); + } + } else { + if constexpr (gen_overflow) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error("overflow to signed inf"); + } + return (uint16_t)(h_sgn + 0x7c00u); + } + } + + /* Exponent underflow converts to a subnormal half or signed zero */ + if (f_exp <= 0x38000000u) { + /* + * Signed zeros, subnormal floats, and floats with small + * exponents all convert to signed zero half-floats. + */ + if (f_exp < 0x33000000u) { + if constexpr (gen_underflow) { + /* If f != 0, it underflowed to 0 */ + if ((f & 0x7fffffff) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + return h_sgn; + } + /* Make the subnormal significand */ + f_exp >>= 23; + f_sig = (0x00800000u + (f & 0x007fffffu)); + if constexpr (gen_underflow) { + /* If it's not exactly represented, it underflowed */ + if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + /* + * Usually the significand is shifted by 13. For subnormals an + * additional shift needs to occur. This shift is one for the largest + * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which + * offsets the new first bit. At most the shift can be 1+10 bits. + */ + f_sig >>= (113 - f_exp); + /* Handle rounding by adding 1 to the bit beyond half precision */ + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. However, the (113 - f_exp) + * shift can lose up to 11 bits, so the || checks them in the original. + * In all other cases, we can just add one. + */ + if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp from zero to one and h_sig will be zero. + * This is the correct result. + */ + return (uint16_t)(h_sgn + h_sig); + } + + /* Regular case with no overflow or underflow */ + h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13); + /* Handle rounding by adding 1 to the bit beyond half precision */ + f_sig = (f & 0x007fffffu); + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. In all other cases, we do. + */ + if ((f_sig & 0x00003fffu) != 0x00001000u) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp by one and h_sig will be zero. This is the + * correct result. h_exp may increment to 15, at greatest, in + * which case the result overflows to a signed inf. + */ + if constexpr (gen_overflow) { + h_sig += h_exp; + if (h_sig == 0x7c00u) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error(""); + } + return h_sgn + h_sig; + } else { + return h_sgn + h_exp + h_sig; + } +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269 +constexpr uint32_t ToFloatBits(uint16_t h) { + uint16_t h_exp = (h & 0x7c00u); + uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: { // 0 or subnormal + uint16_t h_sig = (h & 0x03ffu); + // Signed zero + if (h_sig == 0) { + return f_sgn; + } + // Subnormal + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + } + case 0x7c00u: // inf or NaN + // All-ones exponent and a copy of the significand + return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13); + default: // normalized + // Just need to adjust the exponent and shift + return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13); + } +} + +npy_half npy_float_to_half(float f) { + return {FromFloatBits(BitCast(f))}; +} + +float npy_half_to_float(npy_half h) { + return BitCast(ToFloatBits(h.value)); +} + +template <> +npy_half atomic_fadd(npy_half *loc, npy_half value, + std::memory_order order) { + npy_half old_value; + + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = npy_float_to_half(npy_half_to_float(old_value) + + npy_half_to_float(value)); + + return old_value; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + std::memory_order order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, std::memory_order order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + DType *ptr = static_cast(loc); + *(static_cast(ret) + i) = + applyAtMasked(ptr, *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc + value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc & value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc | value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc ^ value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = atomic_loc->exchange(value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = value; + } + return old_val; + } +}; + +template +void atomic_compare_exchange_strong(void *loc, void *expected, + const void *desired, size_t i, + std::memory_order order) { + T desired_val = *(static_cast(desired) + i); + T *expected_uint = static_cast(expected) + i; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = reinterpret_cast *>(loc); + atomic_loc->compare_exchange_strong(*expected_uint, desired_val, order, + order); + } else { + const std::lock_guard lock(atomic_op_guard); + T *atomic_loc = static_cast(loc); + if (*atomic_loc == *expected_uint) { + *atomic_loc = desired_val; + } else { + *expected_uint = *atomic_loc; + } + } +} + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, std::memory_order order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + atomic_compare_exchange_strong(loc, expected, desired, i, order); + } else if (itemsize == 2) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else if (itemsize == 4) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else if (itemsize == 8) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else { + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + std::memory_order order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template <> template <> void OpCreator::create() { + if (!atomic_op && dtype.char_() == 'e') { // float16 + // workaround until https://github.com/pybind/pybind11/issues/4061 is + // implemented + atomic_op = std::make_unique>( + ptr, val, ret, mask, numel, order); + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, + std::memory_order order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + std::memory_order order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + std::memory_order order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/mthreads/python/src/ir.cc b/third_party/mthreads/python/src/ir.cc new file mode 100644 index 0000000000..bcd6ebc258 --- /dev/null +++ b/third_party/mthreads/python/src/ir.cc @@ -0,0 +1,2094 @@ +#include "ir.h" + +#include +#include +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Gluon/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/SourceMgr.h" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; +namespace tt = triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +llvm::raw_fd_ostream &mlir_dumps() { + std::error_code EC; + static llvm::raw_fd_ostream S(::triton::tools::getStrEnv("MLIR_DUMP_PATH"), + EC, llvm::sys::fs::CD_CreateAlways); + assert(!EC); + return S; +} + +llvm::raw_ostream &mlir_dumps_or_dbgs() { + if (!::triton::tools::getStrEnv("MLIR_DUMP_PATH").empty()) { + return mlir_dumps(); + } else { + return llvm::dbgs(); + } +} + +// Function to parse a comma-separated string into a vector of C-style strings +llvm::SmallVector +parseCommaSeparatedValues(const std::string &input, + llvm::SmallVector &storage) { + llvm::SmallVector split; + llvm::SmallVector result; + StringRef(input.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + return result; +} + +// Run the pass manager under a source manager diagnostic handler, which +// enables emitted MLIR diagnostics to directly reference Python source +// code. This diagnostic handler supports filtering diagnostic info by +// severity levels. +struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler { + TritonSourceMgrDiagnosticHandler(MLIRContext *ctx, + DiagnosticSeverity minSeverity) + : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { + setHandler([this, minSeverity](Diagnostic &diag) { + auto severity = diag.getSeverity(); + switch (severity) { + case DiagnosticSeverity::Error: + break; + case DiagnosticSeverity::Warning: + if (minSeverity == DiagnosticSeverity::Error) + return success(); + break; + case DiagnosticSeverity::Remark: + if (minSeverity == DiagnosticSeverity::Error || + minSeverity == DiagnosticSeverity::Warning) + return success(); + break; + case DiagnosticSeverity::Note: + // notes are handled somewhere else. + return failure(); + default: + llvm_unreachable("Unknown diagnostic severity"); + } + emitDiagnostic(diag); + return success(); + }); + } + + llvm::SourceMgr sourceMgr; +}; + +TritonSourceMgrDiagnosticHandler +setupTritonDiagnosticHandler(MLIRContext *context) { + bool showOperations = false, showStacktraces = false, showRemarks = false, + showWarnings = false; + + if (auto enableDiagnostics = + triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS"); + !enableDiagnostics.empty()) { + llvm::SmallVector storage; + parseCommaSeparatedValues(enableDiagnostics, storage); + for (auto &str : storage) { + if (str == "warnings") { + showWarnings = true; + } else if (str == "remarks") { + showRemarks = true; + } else if (str == "stacktraces") { + showStacktraces = true; + } else if (str == "operations") { + showOperations = true; + } + // we show errors by default, so no need to set it + } + } + + DiagnosticSeverity minSeverity = + showWarnings ? DiagnosticSeverity::Warning : DiagnosticSeverity::Error; + minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity; + + context->printOpOnDiagnostic(showOperations); + context->printStackTraceOnDiagnostic(showStacktraces); + if (showStacktraces) { + context->disableMultithreading(); + } + + return TritonSourceMgrDiagnosticHandler(context, minSeverity); +} + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +// Allow dump a reproducer in the console on crash. +struct ConsoleReproducerStream : public mlir::ReproducerStream { + ~ConsoleReproducerStream() override {} + + StringRef description() override { + return "std::errs, please share the reproducer above with Triton project."; + } + raw_ostream &os() override { return llvm::errs(); } +}; + +ReproducerStreamFactory makeConsoleReproducer() { + return [](std::string &error) -> std::unique_ptr { + return std::make_unique(); + }; +} + +OpPrintingFlags getOpPrintingFlags() { + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + printingFlags.printNameLocAsPrefix(true); + return printingFlags; +} + +py::list getTensorDescMetadata(ModuleOp &mod) { + TritonSourceMgrDiagnosticHandler handler = + setupTritonDiagnosticHandler(mod.getContext()); + constexpr llvm::StringLiteral kHostTensorDescABIArgsAttr = + "musa.host_tensordesc_abi_args"; + + py::list result; + triton::FuncOp kernelFunc; + mod.walk([&](triton::FuncOp func) { + if (triton::isKernel(func)) { + kernelFunc = func; + return WalkResult::interrupt(); + } + return WalkResult::skip(); + }); + assert(kernelFunc); + + for (auto [i, arg] : llvm::enumerate(kernelFunc.getArguments())) { + auto descTy = dyn_cast(arg.getType()); + if (!descTy) + continue; + + auto blockType = descTy.getBlockType(); + auto encoding = blockType.getEncoding(); + + py::dict metadata; + auto rank = std::max(1, blockType.getRank()); + int64_t abiExpandedArgs = 1 + 2 * rank; + if (auto abiAttr = dyn_cast_or_null( + kernelFunc.getArgAttr(i, kHostTensorDescABIArgsAttr))) { + abiExpandedArgs = abiAttr.getInt(); + } + metadata["abi_expanded_args"] = abiExpandedArgs; + if (isa(encoding)) { + auto mmaEncoding = dyn_cast(encoding); + auto swizzle = ttng::getTMASwizzleMode(arg.getLoc(), descTy); + auto elemType = ttng::getTMAElementType(arg.getLoc(), descTy); + if (failed(swizzle) || failed(elemType)) + throw py::type_error("invalid TMA descriptor type"); + auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false); + metadata["swizzle"] = *swizzle; + metadata["elem_size"] = + descTy.getBlockType().getElementTypeBitWidth() / 8; + metadata["elem_type"] = *elemType; + metadata["block_size"] = + std::vector(blockSize.begin(), blockSize.end()); + metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded(); + } else { + auto blockShape = blockType.getShape(); + metadata["block_size"] = + std::vector(blockShape.begin(), blockShape.end()); + metadata["elem_bits"] = blockType.getElementTypeBitWidth(); + + if (auto paddedEnc = dyn_cast(encoding)) { + py::list intervalPaddingPairs; + for (auto [interval, padding] : llvm::zip_equal( + paddedEnc.getIntervals(), paddedEnc.getPaddings())) { + py::list pair; + pair.append(interval); + pair.append(padding); + intervalPaddingPairs.append(pair); + } + metadata["interval_padding_pairs"] = intervalPaddingPairs; + + auto blockShape = blockType.getShape(); + } + } + result.append(std::move(metadata)); + } + return result; +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "DESCRIPTOR_REDUCE_KIND", + py::module_local()) + .value("ADD", DescriptorReduceKind::ADD) + .value("AND", DescriptorReduceKind::AND) + .value("OR", DescriptorReduceKind::OR) + .value("XOR", DescriptorReduceKind::XOR) + .value("MAX", DescriptorReduceKind::MAX) + .value("MIN", DescriptorReduceKind::MIN) + .value("INC", DescriptorReduceKind::INC) + .value("DEC", DescriptorReduceKind::DEC); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .value("BF16x3", InputPrecision::BF16x3) + .value("BF16x6", InputPrecision::BF16x6) + .export_values(); + + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) + .value("FP16", ScaleDotElemType::FP16) + .export_values(); + + py::class_(m, "context", py::module_local()) + .def(py::init<>([]() { + return std::make_unique(MLIRContext::Threading::DISABLED); + })) + .def("printOpOnDiagnostic", + [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) + .def("printStackTraceOnDiagnostic", [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }); + + py::class_(m, "source_mgr_diag", + py::module_local()) + .def(py::init()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + + if (std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + !filename.empty()) { + TritonPlugin TP(filename); + + std::vector dialectNames; + if (auto result = TP.getDialectHandles(dialectNames); !result) + llvm::report_fatal_error(result.takeError()); + + for (unsigned i = 0; i < dialectNames.size(); ++i) { + const char *dialectName = dialectNames.data()[i]; + auto result = TP.getDialectPluginInfo(dialectName); + if (!result) + throw TP.err2exp(result.takeError()); + ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; + dialectPluginInfo.registerDialectRegistryCallbacks(®istry); + } + } + + registry.insert(); + mlir::LLVM::registerInlinerInterface(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + mlir::LLVM::registerInlinerInterface(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__eq__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty != nullptr) && (*other_ty == self); + }) + .def("__ne__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty == nullptr) || (*other_ty != self); + }) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", + [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }) + .def("set_name", [](Location &self, std::string &name) { + mlir::StringAttr nameAttr = + mlir::StringAttr::get(self.getContext(), name); + mlir::NameLoc nameLoc = mlir::NameLoc::get(nameAttr, self); + self = dyn_cast(nameLoc); + }); + + py::class_(m, "value", py::module_local()) + .def(py::init<>()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("get_loc", &Value::getLoc) + .def("set_loc", &Value::setLoc) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", + [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }) + .def("set_loc", + [](Value &self, Location loc) { return self.setLoc(loc); }) + .def("get_loc", [](Value &self) { return self.getLoc(); }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()) + .def("get_loc", &BlockArgument::getLoc) + .def("set_loc", &BlockArgument::setLoc); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }) + .def("push_back", + [](Region &self, Block *block) { self.push_back(block); }) + .def("push_front", + [](Region &self, Block *block) { self.push_front(block); }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("add_argument_at", [](Block &self, Type ty, + Location loc) { self.addArgument(ty, loc); }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "unit_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = getOpPrintingFlags(); + self->print(os, printingFlags); + return str; + }) + .def("str_nodebug", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + self->print(os); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", + [](OpState &self) -> bool { + TritonSourceMgrDiagnosticHandler handler = + setupTritonDiagnosticHandler(self.getContext()); + return succeeded(verify(self.getOperation())); + }) + .def("get_operation", [](OpState &self) { return self.getOperation(); }); + + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_int_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("get_bool_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::bool_(ret.getValue()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = getOpPrintingFlags(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("get_entry_func_name", + [](ModuleOp &self) -> std::string { + for (auto &op : self.getOps()) { + if (auto func = dyn_cast(op)) { + if (triton::isKernel(func)) + return func.getName().str(); + } + } + return ""; + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + /* + * def ty_to_cpp(ty) is the consumer of this function. + * If the type is a ptr it expects ty[0] == '*', else the type itself. + */ + + .def("get_function_signature", + [](ModuleOp &self, FuncOp &func) -> std::vector { + std::vector strVec; + + auto type = func.getFunctionType(); + unsigned numArgs = type.getNumInputs(); + for (unsigned i = 0; i != numArgs; ++i) { + std::string tempType; + llvm::raw_string_ostream os(tempType); + + auto ty = type.getInput(i); + if (auto attributes = func.getCallableArgAttrs()) { + Attribute attr = attributes[i]; + // Check for tt.nv_tma_desc = 1 + if (auto dAttr = dyn_cast(attr)) { + if (dAttr.contains("tt.nv_tma_desc")) { + strVec.push_back("nvTmaDesc"); + continue; + } + } + } + if (auto ptrType = dyn_cast(ty)) { + auto pType = ptrType.getPointeeType(); + os << "*"; + pType.print(os); + } else { + ty.print(os); + } + strVec.push_back(tempType); + } + return strVec; + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("get_tensordesc_metadata", getTensorDescMetadata) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + auto printingFlags = getOpPrintingFlags(); + if (failed(generateLocationsFromIR(fileName, self, printingFlags))) + throw std::runtime_error("Failed to create location snapshot"); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def("get_num_args", &FuncOp::getNumArguments) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + if (arg_no >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "op_builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + .def("get_op_builder", &TritonOpBuilder::getBuilder, ret::reference) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def( + "get_unit_attr", + [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + .def("get_string_attr", + [](TritonOpBuilder &self, std::string value) -> Attribute { + return self.getBuilder().getStringAttr(value); + }) + .def("get_disable_loop_licm_attr", + [](TritonOpBuilder &self) -> Attribute { + auto licmAttr = + LLVM::LoopLICMAttr::get(self.getBuilder().getContext(), + self.getBuilder().getBoolAttr(true), + self.getBuilder().getBoolAttr(true)); + mlir::LLVM::LoopAnnotationAttr la = + mlir::LLVM::LoopAnnotationAttr::get( + self.getBuilder().getContext(), {}, {}, {}, {}, {}, + licmAttr, {}, {}, {}, {}, {}, {}, {}, {}, {}); + return la; + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + self.getBuilder().getI1Type(), v)); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI8Type(), v)); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI16Type(), v)); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI32Type(), v)); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI64Type(), v)); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI8Type(), v)); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI16Type(), v)); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI32Type(), v)); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + self.getBuilder().getI64Type(), v)); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + type, APFloat(type.getFloatSemantics(), std::to_string(v))); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + floatTy, APFloat(floatTy.getFloatSemantics(), 0)); + else if (auto intTy = dyn_cast(type)) + return self.create(intTy, 0); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(intTy, val); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, std::string name) { + auto nameAttr = StringAttr::get(self.getContext(), name); + auto loc = NameLoc::get(nameAttr); + self.setLastLoc(loc); + }) + .def("create_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) -> Location { + return mlir::FileLineColLoc::get(self.getContext(), fileName, line, + column); + }) + .def( + "create_name_loc", + [](TritonOpBuilder &self, std::string name, + std::optional childLoc) -> Location { + auto nameAttr = StringAttr::get(self.getContext(), name); + if (childLoc) + return NameLoc::get(nameAttr, *childLoc); + return NameLoc::get(nameAttr); + }, + py::arg("name"), py::arg("child_loc") = py::none()) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, Type retTy, int start, int end) -> Value { + return self.create(retTy, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_tensor_descriptor_type", + [](TritonOpBuilder &self, Type blockTy, bool isSigned) -> Type { + auto ctx = self.getContext(); + return triton::TensorDescType::get( + ctx, cast(blockTy), isSigned); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value desc, std::vector &indices, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + auto descTy = cast(desc.getType()); + auto resTy = descTy.getSignlessBlockType(); + return self.create( + resTy, desc, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_gather", + [](TritonOpBuilder &self, Value desc, Value x_indices, Value y_index, + Type type) -> Value { + return self.create(type, desc, x_indices, + y_index); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value desc, Value value, + std::vector &indices) -> void { + self.create(desc, value, indices); + }) + .def("create_descriptor_reduce", + [](TritonOpBuilder &self, DescriptorReduceKind kind, Value desc, + Value value, std::vector &indices) -> void { + self.create(kind, desc, value, indices); + }) + .def("create_descriptor_scatter", + [](TritonOpBuilder &self, Value desc, Value value, Value x_indices, + Value y_index) -> void { + self.create(desc, x_indices, y_index, value); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + return self.create(shape, arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + return self.create(arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create(lhsType.clone(shape), lhs, rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. + .def("create_trans", + [](TritonOpBuilder &self, Value &arg, std::vector &order) + -> Value { return self.create(arg, order); }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold(argType.clone(shape), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Type &retTy, Value &arg) -> Value { + return self.createOrFold(retTy, arg); + }) + .def("create_unsplat", + [](TritonOpBuilder &self, Value &arg) -> Value { + return self.createOrFold(arg); + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = srcTensorType.clone(dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = srcTensorType.clone(dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, + std::optional &lhs_scale, + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, bool fast_math, bool lhs_k_pack, + bool rhs_k_pack, mlir::Value &c) -> mlir::Value { + return self.create( + c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()), + rhs_scale.value_or(Value()), lhs_format, rhs_format, fast_math, + lhs_k_pack, rhs_k_pack); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_map_elementwise", + [](TritonOpBuilder &self, std::vector inputs, + std::vector returnTys, int pack) -> OpState { + return self.create(returnTys, inputs, pack); + }) + .def("create_map_elementwise_ret", + [](TritonOpBuilder &self, std::vector returnVals) -> OpState { + return self.create(returnVals); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + self.create(condition, messageAttr); + }) + .def("create_assume", + [](TritonOpBuilder &self, Value &condition) { + self.create(condition); + }) + .def("create_poison", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins, + std::optional mask) -> Value { + if (!mask) { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + } else { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand, *mask); + } + }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { + self.create(triton::gpu::AddrSpace::All); + }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }) + // Make a tensor descriptor + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &tensorShape, + bool isSignedInteger, PaddingOption paddingOption) -> Value { + return self.create(base, shape, strides, + tensorShape, isSignedInteger, + paddingOption); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) -> bool { + auto *context = self.getContext(); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + std::string funcToDump; + if (!haveDump) { + funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); + bool isEnvValueBool = + triton::tools::isEnvValueBool(funcToDump).has_value(); + if (!funcToDump.empty() && !isEnvValueBool) + haveDump = true; + } + if (haveDump) { + context->disableMultithreading(); + auto printingFlags = getOpPrintingFlags(); + auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { + if (funcToDump.empty()) + return true; + if (auto mod = dyn_cast(op)) { + return mod.lookupSymbol(funcToDump); + } + if (auto func = dyn_cast(op)) { + return SymbolTable::getSymbolName(func).getValue() == + funcToDump; + } + + return false; + }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, mlir_dumps_or_dbgs(), + printingFlags); + } + return haveDump; + }) + .def("get_pipeline_str", + [](PassManager &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.printAsTextualPipeline(os); + return str; + }) + .def( + "run", + [](PassManager &self, ModuleOp &mod, std::string repro_pipeline_tag) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto *context = mod.getContext(); + if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING")) + context->disableMultithreading(); + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + if (reproducerPath != "-") { + std::string repro_suffix = + "." + repro_pipeline_tag + ".repro.mlir"; + reproducerPath += repro_suffix; + } + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + // Save a reproducer for the current pass manager invocation + // immediately. + makeReproducer(anchorName, passes, op, reproducerPath); + // But if the pass manager crashes, attempt to generate a local + // reproducer instead. + context->disableMultithreading(); + self.enableCrashReproducerGeneration(reproducerPath, + /*genLocalReproducer=*/true); + } else { + self.enableCrashReproducerGeneration(makeConsoleReproducer()); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = + triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector storage; + llvm::SmallVector debugTypes = + parseCommaSeparatedValues(debugOnly, storage); + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + TritonSourceMgrDiagnosticHandler diagHandler = + setupTritonDiagnosticHandler(context); + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }, + py::call_guard()); +} + +bool str_eq_ignore_case(const char *s1, const char *s2, int n) { + for (int i = 0; i < n; ++i) { + if (tolower(s1[i]) != s2[i]) + return false; + } + return true; +} + +int strlen_max(const char *str, int max) { + for (int i = 0; i <= max; ++i) { + if (str[i] == '\0') { + return i; + } + } + return 0; +} + +bool is_truthy(char *str) { + int len = strlen_max(str, 4); + switch (len) { + case 1: + return str[0] == '1' || tolower(str[0]) == 'y'; + case 2: + return str_eq_ignore_case(str, "on", len); + case 3: + return str_eq_ignore_case(str, "yes", len); + case 4: + return str_eq_ignore_case(str, "true", len); + default: + return false; + } +} + +PyObject *py_getenv(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + if (!(nargs == 1 || nargs == 2)) { + PyErr_SetString(PyExc_TypeError, "getenv expected 1 or 2 arguments"); + return NULL; + } + PyObject *name = args[0]; + PyObject *default_val = nargs == 2 ? args[1] : Py_None; + if (!PyUnicode_CheckExact(name)) { + PyErr_SetString(PyExc_TypeError, "name must be a string"); + return NULL; + } + char *env_val = getenv(PyUnicode_AsUTF8(name)); + if (!env_val) { + Py_INCREF(default_val); + return default_val; + } + return PyUnicode_FromString(env_val); +} + +PyObject *py_getenv_bool(PyObject *self, PyObject *const *args, + Py_ssize_t nargs) { + if (nargs != 2) { + PyErr_SetString(PyExc_TypeError, "getenv_bool expected 2 arguments"); + return NULL; + } + PyObject *name = args[0]; + PyObject *default_val = args[1]; + if (!PyUnicode_CheckExact(name)) { + PyErr_SetString(PyExc_TypeError, "name must be a string"); + return NULL; + } + char *env_val = getenv(PyUnicode_AsUTF8(name)); + PyObject *res = default_val; + if (env_val) { + res = is_truthy(env_val) ? Py_True : Py_False; + } + Py_INCREF(res); + return res; +} + +static PyMethodDef ModuleMethods[] = { + {"getenv", (PyCFunction)py_getenv, METH_FASTCALL, NULL}, + {"getenv_bool", (PyCFunction)py_getenv_bool, METH_FASTCALL, NULL}, + {NULL, NULL, 0, NULL} // sentinel +}; + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); + PyModule_AddFunctions(m.ptr(), ModuleMethods); +} diff --git a/third_party/mthreads/python/src/ir.h b/third_party/mthreads/python/src/ir.h new file mode 100644 index 0000000000..499dd9e8a9 --- /dev/null +++ b/third_party/mthreads/python/src/ir.h @@ -0,0 +1,100 @@ +#pragma once +#include "mlir/IR/Builders.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(mlir::MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + mlir::OpBuilder &getBuilder() { return *builder; } + mlir::MLIRContext *getContext() { return builder->getContext(); } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(mlir::Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(mlir::FileLineColLoc::get(context, fileName, line, column)); + } + + mlir::Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(mlir::Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(getLocForBlock(&block)); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(mlir::Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(getLocForBlock(&block)); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(mlir::Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) { + setLastLoc(builder->getUnknownLoc()); + if (pt.isSet()) { + if (pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(getLocForBlock(pt.getBlock())); + } + + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return OpTy::create(*builder, loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), + mlir::Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = + !mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); + + mlir::Location getLocForBlock(mlir::Block *block) { + if (auto parentOp = block->getParentOp()) + return parentOp->getLoc(); + return builder->getUnknownLoc(); + } +}; diff --git a/third_party/mthreads/python/src/linear_layout.cc b/third_party/mthreads/python/src/linear_layout.cc new file mode 100644 index 0000000000..21ad101257 --- /dev/null +++ b/third_party/mthreads/python/src/linear_layout.cc @@ -0,0 +1,223 @@ +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" +#include "llvm/ADT/STLExtras.h" +#include +#include +#include + +namespace py = pybind11; +using LinearLayout = mlir::triton::LinearLayout; + +namespace { + +mlir::MLIRContext *getLinearLayoutContext() { + static PyObject *ctxObject = []() { + py::module irMod = py::module::import("triton._C.libtriton.ir"); + // Keep the Python object alive for the life of the process without running + // its destructor during interpreter shutdown (avoids segfaults). + py::object ctx = irMod.attr("context")(); + return ctx.release().ptr(); + }(); + return py::cast(py::handle(ctxObject)); +} + +} // namespace + +void init_linear_layout(py::module &&m) { + py::class_(m, "LinearLayout", py::module_local(false)) + .def(py::init<>()) + .def_static( + "identity_1d", + [](int32_t size, std::string inDim, std::string outDim) { + auto *ctx = getLinearLayoutContext(); + return LinearLayout::identity1D(size, + mlir::StringAttr::get(ctx, inDim), + mlir::StringAttr::get(ctx, outDim)); + }, + py::arg("size"), py::arg("inDim"), py::arg("outDim")) + .def_static( + "strided_1d", + [](int32_t size, int32_t stride, std::string inDim, + std::string outDim) { + auto *ctx = getLinearLayoutContext(); + return LinearLayout::strided1D(size, stride, + mlir::StringAttr::get(ctx, inDim), + mlir::StringAttr::get(ctx, outDim)); + }, + py::arg("size"), py::arg("stride"), py::arg("inDim"), + py::arg("outDim")) + .def_static( + "zeros_1d", + [](int32_t size, std::string inDim, std::string outDim, + int32_t outDimSize) { + auto *ctx = getLinearLayoutContext(); + return LinearLayout::zeros1D( + size, mlir::StringAttr::get(ctx, inDim), + mlir::StringAttr::get(ctx, outDim), outDimSize); + }, + py::arg("size"), py::arg("inDim"), py::arg("outDim"), + py::arg("outDimSize") = 1) + .def_static( + "from_bases", + [](const std::vector>>> &bases, + const std::vector &outDimNames, + std::optional> outDimSizes, + bool requireSurjective) { + auto *ctx = getLinearLayoutContext(); + + std::vector< + std::pair>>> + convertedBases; + convertedBases.reserve(bases.size()); + for (const auto &entry : bases) { + std::vector> converted; + converted.reserve(entry.second.size()); + for (const auto &vec : entry.second) + converted.emplace_back(vec.begin(), vec.end()); + convertedBases.emplace_back( + mlir::StringAttr::get(ctx, entry.first), + std::move(converted)); + } + + if (outDimSizes) { + if (outDimSizes->size() != outDimNames.size()) + throw std::invalid_argument("out_dim_names and out_dim_sizes " + "must have the same length"); + std::vector> outDims; + outDims.reserve(outDimNames.size()); + for (auto it : llvm::enumerate(outDimNames)) + outDims.emplace_back(mlir::StringAttr::get(ctx, it.value()), + (*outDimSizes)[it.index()]); + return LinearLayout(convertedBases, outDims, requireSurjective); + } + + if (!requireSurjective) + throw std::invalid_argument("out_dim_sizes must be provided when " + "require_surjective is false"); + + std::vector convertedNames; + convertedNames.reserve(outDimNames.size()); + for (const auto &name : outDimNames) + convertedNames.push_back(mlir::StringAttr::get(ctx, name)); + return LinearLayout(convertedBases, convertedNames); + }, + py::arg("bases"), py::arg("out_dim_names"), + py::arg("out_dim_sizes") = py::none(), + py::arg("require_surjective") = true) + .def("compose", &LinearLayout::compose) + .def("invert_and_compose", &LinearLayout::invertAndCompose) + .def("invert", &LinearLayout::invert) + .def("pseudoinvert", &LinearLayout::pseudoinvert) + .def("is_surjective", &LinearLayout::isSurjective) + .def("is_injective", &LinearLayout::isInjective) + .def("is_invertible", &LinearLayout::isInvertible) + .def("get_in_dim_names", + [](const LinearLayout &self) { + std::vector dims; + dims.reserve(self.getNumInDims()); + for (mlir::StringAttr dim : self.getInDimNames()) + dims.push_back(dim.str()); + return dims; + }) + .def("get_out_dim_names", + [](const LinearLayout &self) { + std::vector dims; + dims.reserve(self.getNumOutDims()); + for (mlir::StringAttr dim : self.getOutDimNames()) + dims.push_back(dim.str()); + return dims; + }) + .def_property_readonly( + "bases", + [](const LinearLayout &self) { + auto bases = self.getBases(); + pybind11::list result; + for (const auto &it : bases) { + pybind11::list dimBases; + for (const auto &vec : it.second) + dimBases.append(pybind11::cast( + std::vector(vec.begin(), vec.end()))); + result.append(pybind11::make_tuple(it.first.str(), dimBases)); + } + return result; + }) + .def_property_readonly( + "out_dims", + [](const LinearLayout &self) { + pybind11::list result; + for (const auto &it : self.getOutDims()) { + result.append(pybind11::make_tuple(it.first.str(), it.second)); + } + return result; + }) + .def_property_readonly("num_in_dims", &LinearLayout::getNumInDims) + .def_property_readonly("num_out_dims", &LinearLayout::getNumOutDims) + .def("__mul__", [](const LinearLayout &lhs, + const LinearLayout &rhs) { return lhs * rhs; }) + .def( + "__imul__", + [](LinearLayout &lhs, const LinearLayout &rhs) -> LinearLayout & { + lhs *= rhs; + return lhs; + }, + py::return_value_policy::reference_internal) + .def("__eq__", [](const LinearLayout &lhs, + const LinearLayout &rhs) { return lhs == rhs; }) + .def("__ne__", [](const LinearLayout &lhs, + const LinearLayout &rhs) { return lhs != rhs; }) + .def("__repr__", [](const LinearLayout &self) { return self.toString(); }) + .def("__str__", [](const LinearLayout &self) { return self.toString(); }) + .def("get_shared_view", + [](const LinearLayout &self, bool useHWPointOfView) { + return mlir::triton::gpu::getSharedLayoutStr( + const_cast(self), useHWPointOfView); + }) + .def("get_distributed_view", + [](const LinearLayout &self, bool useHWPointOfView) { + return mlir::triton::gpu::getDistributedLayoutStr( + const_cast(self), useHWPointOfView); + }) + .def( + "apply", + [](const LinearLayout &self, py::dict inputsDict) { + std::vector> inputs; + inputs.reserve(inputsDict.size()); + for (auto item : inputsDict) { + inputs.emplace_back(py::cast(item.first), + py::cast(item.second)); + } + auto *ctx = getLinearLayoutContext(); + std::vector> converted; + converted.reserve(inputs.size()); + for (const auto &it : inputs) { + converted.emplace_back(mlir::StringAttr::get(ctx, it.first), + it.second); + } + auto outputs = self.apply(converted); + py::dict result; + for (const auto &out : outputs) { + result[py::str(out.first.str())] = out.second; + } + return result; + }, + py::arg("inputs")) + .def("get_matrix_view", [](const LinearLayout &self) { + std::unique_ptr matrix = mlir::triton::getMatrix(self); + auto nRows = self.getTotalOutDimSizeLog2(); + auto nCols = self.getTotalInDimSizeLog2(); + std::vector> result(nRows, std::vector(nCols)); + for (size_t i = 0; i < nRows; ++i) { + for (size_t j = 0; j < nCols; ++j) { + result[i][j] = (matrix[i] >> j) & 1; + } + } + return result; + }); +} diff --git a/third_party/mthreads/python/src/llvm.cc b/third_party/mthreads/python/src/llvm.cc new file mode 100644 index 0000000000..110a725be7 --- /dev/null +++ b/third_party/mthreads/python/src/llvm.cc @@ -0,0 +1,943 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ScopedNoAliasAA.h" +#include "llvm/CodeGen/MIRParser/MIRParser.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Instrumentation/AddressSanitizer.h" +#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h" +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +namespace { +void initializeTritonLLVMTargets() { + // Keep host target registration for host-side LLVM utilities. + (void)llvm::InitializeNativeTarget(); + (void)llvm::InitializeNativeTargetAsmParser(); + (void)llvm::InitializeNativeTargetAsmPrinter(); + +#if TRITON_ENABLE_NVIDIA + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); +#endif + +#if TRITON_ENABLE_AMD + LLVMInitializeAMDGPUTargetInfo(); + LLVMInitializeAMDGPUTarget(); + LLVMInitializeAMDGPUTargetMC(); + LLVMInitializeAMDGPUAsmParser(); + LLVMInitializeAMDGPUAsmPrinter(); +#endif +} +} // namespace + +// Set an LLVM command-line option using addOccurrence (simulates command-line) +// and return its original value. Using addOccurrence instead of setValue is +// necessary because some LLVM passes (like schedulers) check whether the option +// was explicitly set on the command line. +template T setLLVMOption(const std::string &name, T value); + +template <> bool setLLVMOption(const std::string &name, bool value) { + auto options = llvm::cl::getRegisteredOptions(); + auto it = options.find(name); + if (it == options.end()) + return false; + auto *opt = static_cast *>(it->second); + bool original = opt->getValue(); + // Use addOccurrence to mark the option as explicitly set on command line. + // This is important for options like enable-misched where LLVM checks + // getNumOccurrences() to determine if the option was explicitly set. + // See: llvm/lib/CodeGen/MachineScheduler.cpp - + // enableMachineSchedDefaultSched() checks + // "EnableMachineSched.getNumOccurrences()" to decide behavior. + it->second->addOccurrence(1, name, value ? "true" : "false"); + return original; +} + +template <> +std::string setLLVMOption(const std::string &name, + std::string value) { + auto options = llvm::cl::getRegisteredOptions(); + auto it = options.find(name); + if (it == options.end()) + return ""; + auto *opt = static_cast *>(it->second); + std::string original = opt->getValue(); + it->second->addOccurrence(1, name, value); + return original; +} + +// Restore an LLVM command-line option to a previous value +template void restoreLLVMOption(const std::string &name, T value); + +template <> void restoreLLVMOption(const std::string &name, bool value) { + auto options = llvm::cl::getRegisteredOptions(); + auto it = options.find(name); + if (it != options.end()) { + auto *opt = static_cast *>(it->second); + opt->setValue(value); + } +} + +template <> +void restoreLLVMOption(const std::string &name, + std::string value) { + auto options = llvm::cl::getRegisteredOptions(); + auto it = options.find(name); + if (it != options.end()) { + it->second->addOccurrence(1, name, value); + } +} + +// RAII guard that sets an LLVM option and restores it on destruction +template class ScopedLLVMOption { + std::string name; + T originalValue; + +public: + ScopedLLVMOption(const std::string &n, T newValue) : name(n) { + originalValue = setLLVMOption(name, newValue); + } + ~ScopedLLVMOption() { restoreLLVMOption(name, originalValue); } + + // Non-copyable + ScopedLLVMOption(const ScopedLLVMOption &) = delete; + ScopedLLVMOption &operator=(const ScopedLLVMOption &) = delete; +}; + +std::unique_ptr +createTargetMachine(llvm::Module *module, std::string proc, + bool enable_fp_fusion, const std::string &features) { + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + llvm::TargetOptions opt; + bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + opt.MCOptions.AsmVerbose = true; + opt.MCOptions.PreserveAsmComments = true; + std::unique_ptr machine{target->createTargetMachine( + module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + return machine; +} + +void dumpSchedulingDAG(llvm::Module &module, const std::string &triple, + const std::string &proc, const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, const std::string &dumpFileId) { + using namespace mlir; + + // Check if we should dump sched DAG + std::string dumpMirBase = triton::tools::getStrEnv("TRITON_DUMP_MIR"); + bool dumpMir = !dumpMirBase.empty(); + if (!dumpMir) { + return; + } + + // Apply flags + for (const std::string &flag : flags) { + setLLVMOption(flag, true); + } + + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (const auto &flag : split) { + setLLVMOption(flag.str(), true); + } + } + } + + std::string dumpFilename = dumpMirBase + "/" + dumpFileId + ".txt"; + + // Use RAII to set options and restore them when scope exits + ScopedLLVMOption stopAfterGuard("stop-after", + "machine-scheduler"); + ScopedLLVMOption mischedPrintGuard("misched-print-dags", true); + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + pm.run(module); + + // create machine + module.setTargetTriple(Triple(triple)); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + // set data layout + module.setDataLayout(machine->createDataLayout()); + + // Save original stderr file descriptor + int saved_stderr_fd = dup(fileno(stderr)); + + // Redirect stderr to append to dump file + FILE *redirected = freopen(dumpFilename.c_str(), "a", stderr); + if (!redirected) { + llvm::errs() << "Warning: Failed to redirect stderr to " << dumpFilename + << "\n"; + } + + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager pass; + // emit + machine->addPassesToEmitFile(pass, pstream, nullptr, + llvm::CodeGenFileType::AssemblyFile); + pass.run(module); + } + + // Restore stderr + fflush(stderr); + if (saved_stderr_fd != -1) { + dup2(saved_stderr_fd, fileno(stderr)); + close(saved_stderr_fd); + clearerr(stderr); + } + + llvm::errs() << "DAG dumped to: " << dumpFilename << "\n"; + // LLVM options are automatically restored when scope exits via RAII +} + +std::string +translateLLVMIRToMIR(llvm::Module &module, const std::string &triple, + const std::string &proc, const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, const std::string &dumpFileId) { + using namespace mlir; + + // Check if we should dump MIR + std::string dumpMirBase = triton::tools::getStrEnv("TRITON_DUMP_MIR"); + bool dumpMir = !dumpMirBase.empty(); + if (!dumpMir) { + return ""; + } + + llvm::StripDebugInfo(module); + + // Apply flags + for (const std::string &flag : flags) { + setLLVMOption(flag, true); + } + + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (const auto &flag : split) { + setLLVMOption(flag.str(), true); + } + } + } + + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + setLLVMOption("print-after-all", true); + } + + // Use RAII to set stop-before and restore it when scope exits + ScopedLLVMOption stopBeforeGuard("stop-before", + "machine-scheduler"); + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + pm.run(module); + + // create machine + module.setTargetTriple(Triple(triple)); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + // set data layout + module.setDataLayout(machine->createDataLayout()); + + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager pass; + // emit + machine->addPassesToEmitFile(pass, pstream, nullptr, + llvm::CodeGenFileType::AssemblyFile); + pass.run(module); + } + + std::string dumpFilename = dumpMirBase + "/" + dumpFileId + ".txt"; + { + std::error_code EC; + llvm::raw_fd_ostream outFile(dumpFilename, EC, llvm::sys::fs::OF_None); + if (EC) { + llvm::errs() << "Error opening file " << dumpFilename << ": " + << EC.message() << "\n"; + } else { + outFile << result; + outFile << "---"; + outFile << "\n========== SCHEDULING DAG ==========\n"; + } + llvm::errs() << "MIR dumped to: " << dumpFilename << "\n"; + } + + return result; +} + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + + // Apply flags + for (const std::string &flag : flags) { + setLLVMOption(flag, true); + } + + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + setLLVMOption("print-after-all", true); + } + + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (const auto &flag : split) { + setLLVMOption(flag.str(), true); + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + + // create machine + module.setTargetTriple(Triple(triple)); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +std::string translateMIRToASM(const std::string &mirPath, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + + // We need to start before machine-scheduler and disable it instead of simply + // start after it because machine-scheduler is used as anchor point to insert + // some passes. Starting after machine-scheduler would also not insert these + // passes to the pipeline. + // Use RAII to set options and restore them when scope exits + ScopedLLVMOption startBeforeGuard("start-before", + "machine-scheduler"); + ScopedLLVMOption enableMISchedGuard("enable-misched", false); + ScopedLLVMOption enablePostMISchedGuard("enable-post-misched", false); + + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + setLLVMOption("print-after-all", true); + } + + // Apply other flags + for (const std::string &flag : flags) { + setLLVMOption(flag, true); + } + + // Parse MIR into LLVM Module + llvm::LLVMContext context; + llvm::SMDiagnostic error; + + // Load MIR file into memory + llvm::ErrorOr> buffer = + llvm::MemoryBuffer::getFile(mirPath); + + if (!buffer) { + llvm::report_fatal_error(llvm::Twine("failed to open MIR file: ") + + mirPath + " " + buffer.getError().message()); + } + + std::unique_ptr mirParser = + llvm::createMIRParser(std::move(buffer.get()), context); + + if (!mirParser) { + llvm::report_fatal_error("failed to create MIR parser"); + } + + std::unique_ptr module = mirParser->parseIRModule(); + if (!module) { + llvm::report_fatal_error("failed to parse MIR IR module"); + } + + // Setup target machine + module->setTargetTriple(Triple(triple)); + auto machine = + createTargetMachine(module.get(), proc, enable_fp_fusion, features); + module->setDataLayout(machine->createDataLayout()); + + // Create PassManager + llvm::legacy::PassManager pass; + + // IMPORTANT: Add ScopedNoAliasAAWrapperPass to ensure alias analysis + // understands !alias.scope and !noalias metadata during machine scheduling. + // + // When loading MIR directly (swap path), we skip the normal IR optimization + // passes that would register ScopedNoAliasAA. Without this, the machine + // scheduler cannot prove that async buffer loads (BUFFER_LOAD_DWORDX4_LDS) + // don't alias with LDS reads (DS_READ), resulting in unnecessary memory + // dependencies and ~30% performance regression. + pass.add(llvm::createScopedNoAliasAAWrapperPass()); + + // Emit code from MIR + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + + // Create MachineModuleInfoWrapperPass FIRST + llvm::MachineModuleInfoWrapperPass *MMIWP = + new llvm::MachineModuleInfoWrapperPass(machine.get()); + + // This will run the remaining machine passes and emit assembly/object + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType, + /*NoVerify*/ false, MMIWP); + + // Now parse machine functions + if (mirParser->parseMachineFunctions(*module, MMIWP->getMMI())) { + llvm::report_fatal_error("Failed to parse machine functions from MIR"); + } + + pass.run(*module); + } + + // LLVM options are automatically restored when scope exits via RAII + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + py::class_(m, "source_mgr", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. + py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); + os << *self; + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + .def("remove_fn_attr", [](llvm::Function *fn, + std::string &name) { fn->removeFnAttr(name); }) + .def("add_fn_asan_attr", + [](llvm::Function *fn) { + fn->addFnAttr(llvm::Attribute::SanitizeAddress); + }) + .def("add_fn_target_feature", + [](llvm::Function *fn, std::string &val) { + fn->addFnAttr("target-features", val); + }) + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + std::unique_ptr llvmMod = + mlir::translateModuleToLLVMIR(mod, ctx); + if (!llvmMod) { + throw std::runtime_error("failed to translate module to LLVM IR"); + } + return llvmMod; + }, + py::keep_alive<0, 2>(), py::call_guard()); + + m.def("attach_datalayout", [](llvm::Module *mod, const std::string triple, + const std::string proc, + const std::string features) { + std::string error; + llvm::Triple targetTriple(triple); + auto target = llvm::TargetRegistry::lookupTarget(targetTriple, error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } + llvm::TargetOptions opt; + // Target machine is only used to create the data layout. + std::unique_ptr machine{target->createTargetMachine( + targetTriple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt, + llvm::CodeGenOptLevel::None)}; + // set data layout + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + std::string arch, std::string features, std::vector flags, + bool enable_fp_fusion) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. + auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (const auto &flag : split) { + setLLVMOption(flag.str(), true); + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + if (arch.empty()) { + llvm::TargetLibraryInfoImpl TLII(mod->getTargetTriple()); + TLII.disableAllFunctions(); + fam.registerPass([TLII = std::move(TLII)] { + return llvm::TargetLibraryAnalysis(TLII); + }); + } + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + setLLVMOption("print-after-all", true); + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + tuningOptions.SLPVectorization = true; + + std::string pluginFile = + mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH"); + + // We don't pass the targetMachine to the LLVM-IR pass builder, unless + // `arch` is specified. + // + // Don't set target machine in LLVM pass builder when using LLVM IR + // level plugins. LLVM IR level plugin passes typically want to insert + // calls to externally generated code (i.e. precompile a Cuda/Hip kernel + // with Clang and then insert a call to it within an instrumentation + // pass) setting the targetMachine value here can can cause a mismatch + // in the target machine between the MLIR and Clang generated kernels + // and break the lowering of some target specific intrinsics. + std::unique_ptr targetMachine = nullptr; + if (!arch.empty() && pluginFile.empty()) + targetMachine = + createTargetMachine(mod, arch, enable_fp_fusion, features); + PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions, + std::nullopt, instrCbPtr); + + if (!pluginFile.empty()) { + // TODO: Add some logging here that we inserted a pass into the LLVM + // pass pipeline + auto passPlugin = llvm::PassPlugin::Load(pluginFile); + if (!passPlugin) { + llvm::Error Err = passPlugin.takeError(); + std::string ErrMsg = + "Pass Plugin Error: " + llvm::toString(std::move(Err)); + throw std::runtime_error(ErrMsg); + } + passPlugin->registerPassBuilderCallbacks(pb); + } + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + bool enableAddressSanitizer = + mlir::triton::tools::getBoolEnv("TRITON_ENABLE_ASAN"); + if (enableAddressSanitizer) { + AddressSanitizerOptions Opts; + mpm.addPass(AddressSanitizerPass(Opts)); + } + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + // Mandatory parameters + py::arg("mod"), py::arg("opt"), + // If we want to specify the target machine, we require additional + // (optional) parameters + py::arg("arch") = "", py::arg("features") = "", + py::arg("flags") = std::vector{}, + py::arg("enable_fp_fusion") = false, + py::call_guard()); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("dump_sched_dag", [](std::string llvmIR, std::string triple, + std::string proc, std::string features, + std::vector flags, + bool enable_fp_fusion, std::string dumpFileId) { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error("failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + dumpSchedulingDAG(*module, triple, proc, features, flags, enable_fp_fusion, + dumpFileId); + }); + + m.def( + "translate_to_mir", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, std::string dumpFileId) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToMIR(*module, triple, proc, features, flags, + enable_fp_fusion, dumpFileId); + } + return py::str(obj); + }, + ret::take_ownership); + + m.def( + "translate_mir_to_asm", + [](std::string mirPath, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string result; + { + py::gil_scoped_release allow_threads; + result = translateMIRToASM(mirPath, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(result); + else + return py::str(result); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { initializeTritonLLVMTargets(); }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(Triple(dstMod->getTargetTriple())); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + if (fn.hasFnAttribute(llvm::Attribute::NoInline)) { + fn.removeFnAttr(llvm::Attribute::NoInline); + fn.removeFnAttr(llvm::Attribute::OptimizeNone); + fn.addFnAttr(llvm::Attribute::AlwaysInline); + } + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} + +void triton_stacktrace_signal_handler(void *) { + llvm::sys::PrintStackTrace(llvm::errs()); + raise(SIGABRT); +} + +void init_triton_stacktrace_hook(pybind11::module &m) { + if (mlir::triton::tools::getBoolEnv("TRITON_ENABLE_PYTHON_STACKTRACE")) { + llvm::sys::AddSignalHandler(triton_stacktrace_signal_handler, nullptr); + } +} diff --git a/third_party/mthreads/python/src/main.cc b/third_party/mthreads/python/src/main.cc new file mode 100644 index 0000000000..40b19f77ea --- /dev/null +++ b/third_party/mthreads/python/src/main.cc @@ -0,0 +1,62 @@ +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Signals.h" +#include + +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) +#define FOR_EACH_5(MACRO, X, ...) MACRO(X) FOR_EACH_4(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, _5, N, ...) N +#define FOR_EACH_RSEQ_N() 5, 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +void init_triton_stacktrace_hook(pybind11::module &m); +void init_gluon_ir(pybind11::module &&m); +void init_linear_layout(pybind11::module &&m); +void init_native_specialize(pybind11::module &m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_stacktrace_hook(m); + init_triton_env_vars(m); + init_native_specialize(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + init_linear_layout(m.def_submodule("linear_layout")); + init_gluon_ir(m.def_submodule("gluon_ir")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/mthreads/python/src/passes.cc b/third_party/mthreads/python/src/passes.cc new file mode 100644 index 0000000000..8977b59913 --- /dev/null +++ b/third_party/mthreads/python/src/passes.cc @@ -0,0 +1,161 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Gluon/Transforms/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonInstrument/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); + ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createTritonCombineOps); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createTritonReorderBroadcast); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createTritonRewriteTensorPointer); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_descriptor_to_pointer", + createTritonRewriteTensorDescriptorToPointer); + ADD_PASS_WRAPPER_0("add_loop_unroll", createTritonLoopUnroll); + ADD_PASS_WRAPPER_0("add_triton_licm", createTritonLoopInvariantCodeMotion); + ADD_PASS_WRAPPER_0("add_loop_aware_cse", createTritonLoopAwareCSE); + ADD_PASS_OPTION_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPU, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir; + using namespace mlir::triton::gpu; + using namespace mlir::triton::instrument; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_OPTION_WRAPPER_1("add_hoist_tmem_alloc", + createTritonGPUHoistTMEMAlloc, bool); + ADD_PASS_OPTION_WRAPPER_1("add_assign_latencies", + createTritonGPUAssignLatencies, int); + ADD_PASS_WRAPPER_0("add_schedule_loops", createTritonGPUScheduleLoops); + ADD_PASS_OPTION_WRAPPER_2("add_pipeline", createTritonGPUPipeline, int, bool); + ADD_PASS_OPTION_WRAPPER_1("add_warp_specialize", + createTritonGPUAutomaticWarpSpecialization, int); + ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_OPTION_WRAPPER_1("add_f32_dot_tc", createTritonGPUF32DotTC, bool); + ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", + createTritonGPUOptimizeDotOperands, bool); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_warp_groups", + createTritonGPUAllocateWarpGroups); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", createAllocateSharedMemory); + ADD_PASS_WRAPPER_0("add_allocate_global_scratch_memory", + createTritonGPUGlobalScratchAllocationPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops); + ADD_PASS_WRAPPER_0("add_coalesce_async_copy", + createTritonGPUCoalesceAsyncCopy); + ADD_PASS_WRAPPER_0("add_concurrency_sanitizer", + createTritonInstrumentConcurrencySanitizer); + ADD_PASS_WRAPPER_0("add_optimize_partition_warps", + createTritonGPUOptimizePartitionWarps); +} + +void init_plugin_passes(py::module &&m) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + if (filename.empty()) + return; + + TritonPlugin TP(filename); + std::vector passNames; + if (auto result = TP.getPassHandles(passNames); !result) + throw TP.err2exp(result.takeError()); + + for (unsigned i = 0; i < passNames.size(); ++i) { + const char *passName = passNames.data()[i]; + + m.def(passName, [passName](mlir ::PassManager &pm) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + TritonPlugin TP(filename); + if (auto result = TP.addPass(&pm, passName); !result) + throw TP.err2exp(result.takeError()); + }); + } +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); + ADD_PASS_WRAPPER_0("add_nvvm_to_llvm", createConvertNVVMToLLVMPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", mlir::createLLVMDIScope); + ADD_PASS_WRAPPER_0("add_di_local_variable", mlir::createLLVMDILocalVariable); +} + +void init_gluon_passes(py::module &&m) { + using namespace mlir; + namespace gluon = mlir::triton::gluon; + ADD_PASS_WRAPPER_0("add_resolve_auto_encodings", + gluon::createGluonResolveAutoEncodingsPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", gluon::createGluonCanonicalize); + ADD_PASS_WRAPPER_0("add_inliner", gluon::createGluonInline); + ADD_PASS_WRAPPER_0("add_infer_coalesced_encodings", + gluon::createGluonInferCoalescedEncodingsPass); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); + init_gluon_passes(m.def_submodule("gluon")); + init_plugin_passes(m.def_submodule("plugin")); +} diff --git a/third_party/mthreads/python/src/passes.h b/third_party/mthreads/python/src/passes.h new file mode 100644 index 0000000000..62f5986a07 --- /dev/null +++ b/third_party/mthreads/python/src/passes.h @@ -0,0 +1,43 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder({val0, val1, val2, val3})); }) + +#define ADD_PASS_OPTION_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \ + ty4 val4) { pm.addPass(builder({val0, val1, val2, val3, val4})); }) diff --git a/third_party/mthreads/python/src/specialize.cc b/third_party/mthreads/python/src/specialize.cc new file mode 100644 index 0000000000..3449e3c900 --- /dev/null +++ b/third_party/mthreads/python/src/specialize.cc @@ -0,0 +1,584 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +namespace py = pybind11; + +using DTypePtrKey = std::pair; +using DTypeKey = Py_hash_t; + +struct DTypePtrKeyHash { + std::size_t operator()(const DTypePtrKey &k) const { + return std::hash()(k.first) ^ (std::hash()(k.second) << 1); + } +}; + +using DtypePtr2Str = + std::unordered_map; +using Dtype2Str = std::unordered_map; + +using TypeHandler = std::pair (*)(PyObject *, + PyObject *, bool, + bool, bool); +using TypeHandlerCache = std::unordered_map; + +static std::pair +specialize_arg(PyObject *backend, PyObject *arg, bool is_const, + bool specialize_value, bool align); + +static bool init_called = false; + +static PyObject *constexpr_cls = nullptr; +static PyObject *jit_callable_cls = nullptr; +static PyObject *tensor_descriptor_cls = nullptr; +static PyObject *nvidia_tensor_descriptor_cls = nullptr; +static PyObject *amd_tensor_descriptor_cls = nullptr; +static PyObject *canonicalize_dtype_fn = nullptr; +static PyObject *canonicalize_ptr_dtype_fn = nullptr; +static PyObject *torch_tensor_cls = nullptr; + +static PyObject *i32_str = nullptr; +static PyObject *i64_str = nullptr; +static PyObject *u64_str = nullptr; +static PyObject *fp32_str = nullptr; +static PyObject *u1_str = nullptr; +static PyObject *D_str = nullptr; +static PyObject *constexpr_str = nullptr; +static PyObject *empty_str = nullptr; +static PyObject *nvTmaDesc_str = nullptr; + +static PyObject *base_attr = nullptr; +static PyObject *data_ptr_attr = nullptr; +static PyObject *dtype_attr = nullptr; +static PyObject *cache_key_attr = nullptr; +static PyObject *_fields_attr = nullptr; +static PyObject *block_shape_attr = nullptr; +static PyObject *layout_attr = nullptr; +static PyObject *has_native_tensor_spec_attr = nullptr; +static PyObject *get_tensor_spec_attr = nullptr; +static PyObject *align_kwarg = nullptr; + +static DtypePtr2Str dtype_ptr2str; +static Dtype2Str dtype2str; +static TypeHandlerCache type_handler_cache; + +// Wrappers to make steal and borrow slightly simpler. We use raw CPython API +// with py::object to handle decref, as using the pybind11 APIs adds exception +// handling overhead which is quite significant here. +py::object from_new_ref(py::handle val) { + return py::reinterpret_steal(val); +} +py::object from_borrowed_ref(py::handle val) { + return py::reinterpret_borrow(val); +} + +PyObject *intern_from_string(const char *str) { + PyObject *obj = PyUnicode_InternFromString(str); + if (!obj) + throw py::error_already_set(); + return obj; +} + +PyObject *import_from(const char *module_name, const char *var_name) { + py::object var = py::module_::import(module_name).attr(var_name); + return var.release().ptr(); +} + +void init_interned_strings() { + i32_str = intern_from_string("i32"); + i64_str = intern_from_string("i64"); + u64_str = intern_from_string("u64"); + fp32_str = intern_from_string("fp32"); + u1_str = intern_from_string("u1"); + D_str = intern_from_string("D"); + constexpr_str = intern_from_string("constexpr"); + empty_str = intern_from_string(""); + nvTmaDesc_str = intern_from_string("nvTmaDesc"); + + base_attr = intern_from_string("base"); + data_ptr_attr = intern_from_string("data_ptr"); + dtype_attr = intern_from_string("dtype"); + cache_key_attr = intern_from_string("cache_key"); + _fields_attr = intern_from_string("_fields"); + block_shape_attr = intern_from_string("block_shape"); + layout_attr = intern_from_string("layout"); + has_native_tensor_spec_attr = + intern_from_string("supports_native_tensor_specialization"); + get_tensor_spec_attr = intern_from_string("get_tensor_specialization"); + + align_kwarg = py::make_tuple("align").release().ptr(); +} + +void init_type_handler_cache(); + +bool init_globals() noexcept try { + // Import releavant symbols + jit_callable_cls = import_from("triton.runtime.jit", "JITCallable"); + tensor_descriptor_cls = + import_from("triton.tools.tensor_descriptor", "TensorDescriptor"); + nvidia_tensor_descriptor_cls = import_from( + "triton.experimental.gluon.nvidia.hopper", "TensorDescriptor"); + amd_tensor_descriptor_cls = + import_from("triton.experimental.gluon.amd.gfx1250", "TensorDescriptor"); + + auto m_canonicalize = py::module_::import("triton._utils"); + canonicalize_dtype_fn = import_from("triton._utils", "canonicalize_dtype"); + canonicalize_ptr_dtype_fn = + import_from("triton._utils", "canonicalize_ptr_dtype"); + constexpr_cls = import_from("triton.language", "constexpr"); + + try { + torch_tensor_cls = import_from("torch", "Tensor"); + } catch (py::error_already_set &e) { + } + + init_interned_strings(); + init_type_handler_cache(); + + init_called = true; + return true; +} catch (py::error_already_set &e) { + e.restore(); + return false; +} + +std::pair specialize_tensordesc(PyObject *arg, + bool has_layout) { + auto base = from_new_ref(PyObject_GetAttr(arg, base_attr)); + if (!base) + return {}; + + auto dtype = from_new_ref(PyObject_GetAttr(base.ptr(), dtype_attr)); + if (!dtype) + return {}; + + PyObject *type_str; + Py_hash_t dtype_hash = PyObject_Hash(dtype.ptr()); + if (dtype_hash == -1) + return {}; + DTypeKey dsk{dtype_hash}; + auto it = dtype2str.find(dsk); + if (it != dtype2str.end()) { + type_str = it->second; + } else { + auto res = from_new_ref(PyObject_CallFunctionObjArgs(canonicalize_dtype_fn, + dtype.ptr(), nullptr)); + if (!res) + return {}; + dtype2str[dsk] = res.ptr(); + type_str = res.release().ptr(); + } + + std::string desc_cstr; + desc_cstr.reserve(128); + desc_cstr = "tensordesc<"; + auto dtype_str = from_new_ref(PyObject_Str(type_str)); + if (!dtype_str) + return {}; + + const char *dtype_cstr = PyUnicode_AsUTF8(dtype_str.ptr()); + if (!dtype_cstr) + return {}; + desc_cstr += dtype_cstr; + + auto block_shape_obj = from_new_ref(PyObject_GetAttr(arg, block_shape_attr)); + if (!block_shape_obj) + return {}; + auto block_shape_list = from_new_ref(PySequence_List(block_shape_obj.ptr())); + if (!block_shape_list) + return {}; + auto block_shape_str = from_new_ref(PyObject_Str(block_shape_list.ptr())); + if (!block_shape_str) + return {}; + const char *block_shape_cstr = PyUnicode_AsUTF8(block_shape_str.ptr()); + if (!block_shape_cstr) + return {}; + desc_cstr += block_shape_cstr; + + if (has_layout) { + auto layout_obj = from_new_ref(PyObject_GetAttr(arg, layout_attr)); + if (!layout_obj) + return {}; + auto layout_repr = from_new_ref(PyObject_Repr(layout_obj.ptr())); + if (!layout_repr) + return {}; + desc_cstr += ","; + const char *layout_cstr = PyUnicode_AsUTF8(layout_repr.ptr()); + if (!layout_cstr) + return {}; + desc_cstr += layout_cstr; + } + + desc_cstr += ">"; + auto type_str_result = from_new_ref(PyUnicode_FromString(desc_cstr.c_str())); + if (!type_str_result) + return {}; + + return {std::move(type_str_result), py::none()}; +} + +std::pair handle_long_type(PyObject *backend, + PyObject *arg, bool is_const, + bool specialize_value, + bool align) { + int overflow; + long long val = PyLong_AsLongLongAndOverflow(arg, &overflow); + if (PyErr_Occurred()) { + return {}; + } + + if (specialize_value && (val == 1)) { + return {from_borrowed_ref(constexpr_str), from_borrowed_ref(arg)}; + } + + py::handle type_str; + py::handle key_obj; + if (overflow == 0) { + type_str = (val >= INT32_MIN && val <= INT32_MAX) ? i32_str : i64_str; + if (specialize_value) { + key_obj = (align && ((val & 15) == 0)) ? D_str : empty_str; + } + } else { + unsigned long long val_64 = PyLong_AsUnsignedLongLong(arg); + if (PyErr_Occurred()) { + // this runs into an edge-case where the Python reference + // returns i64 as type and alignment of the value despite + // not being representable as such which at kernel launch later + // will throw an OverflowError nevertheless, here we throw + // OverflowError immediately + PyErr_SetString(PyExc_OverflowError, + "integer to be specialized too large to represent"); + return {}; + } + type_str = u64_str; + if (specialize_value) { + key_obj = (align && ((val_64 & 15) == 0)) ? D_str : empty_str; + } + } + if (!key_obj) { + return {from_borrowed_ref(type_str), py::none()}; + } + return {from_borrowed_ref(type_str), from_borrowed_ref(key_obj)}; +} + +std::pair handle_tensor(PyObject *backend, + PyObject *arg, bool is_const, + bool specialize_value, + bool align) { + // handle type_str specialization of a tensor + auto dtype = from_new_ref(PyObject_GetAttr(arg, dtype_attr)); + if (!dtype) + return {}; + + Py_hash_t dtype_hash = PyObject_Hash(dtype.ptr()); + if (dtype_hash == -1) + return {}; + + DTypePtrKey dsk{dtype_hash, is_const}; + auto it = dtype_ptr2str.find(dsk); + + py::handle type_str; + if (it != dtype_ptr2str.end()) { + type_str = it->second; + } else { + auto canon_res = + PyObject_CallFunctionObjArgs(canonicalize_ptr_dtype_fn, dtype.ptr(), + is_const ? Py_True : Py_False, nullptr); + if (!canon_res) + return {}; + dtype_ptr2str[dsk] = canon_res; + type_str = canon_res; + } + + // handle alignment specialization of a tensor + if (!specialize_value) { + return {from_borrowed_ref(type_str), py::none()}; + } + + bool native_impl_available = false; + auto native_spec_obj = + from_new_ref(PyObject_GetAttr(backend, has_native_tensor_spec_attr)); + if (native_spec_obj) { + native_impl_available = PyObject_IsTrue(native_spec_obj.ptr()); + } else { + PyErr_Clear(); + // on error we fall back to native_impl_available = false gracefully + } + + py::object key; + if (native_impl_available) { + auto data_ptr_result = + from_new_ref(PyObject_CallMethodNoArgs(arg, data_ptr_attr)); + if (!data_ptr_result) + return {}; + + auto data_ptr = PyLong_AsUnsignedLongLong(data_ptr_result.ptr()); + if (PyErr_Occurred()) + return {}; + + auto key_obj = (align && ((data_ptr & 15) == 0)) ? D_str : empty_str; + key = from_borrowed_ref(key_obj); + } else { + PyObject *args[3] = {backend, arg, align ? Py_True : Py_False}; + PyObject *kwnames = align_kwarg; + key = from_new_ref( + PyObject_VectorcallMethod(get_tensor_spec_attr, args, 2, kwnames)); + if (!key) + return {}; + } + + return {from_borrowed_ref(type_str), std::move(key)}; +} + +std::pair handle_bool_type(PyObject *backend, + PyObject *arg, bool is_const, + bool specialize_value, + bool align) { + return {from_borrowed_ref(u1_str), py::none()}; +} + +std::pair +handle_float_type(PyObject *backend, PyObject *arg, bool is_const, + bool specialize_value, bool align) { + return {from_borrowed_ref(fp32_str), py::none()}; +} + +std::pair +handle_tensor_descriptor(PyObject *backend, PyObject *arg, bool is_const, + bool specialize_value, bool align) { + return specialize_tensordesc(arg, false); +} + +std::pair +handle_gluon_tensor_descriptor(PyObject *backend, PyObject *arg, bool is_const, + bool specialize_value, bool align) { + return specialize_tensordesc(arg, true); +} + +std::pair +handle_constexpr_type(PyObject *backend, PyObject *arg, bool is_const, + bool specialize_value, bool align) { + return {from_borrowed_ref(constexpr_str), from_borrowed_ref(arg)}; +} + +std::pair +handle_jit_callable(PyObject *backend, PyObject *arg, bool is_const, + bool specialize_value, bool align) { + auto cache_key = from_new_ref(PyObject_GetAttr(arg, cache_key_attr)); + if (!cache_key) + return {}; + return {from_borrowed_ref(constexpr_str), std::move(cache_key)}; +} + +std::pair handle_tuple(PyObject *backend, PyObject *arg, + bool is_const, + bool specialize_value, + bool align) { + Py_ssize_t size = PyTuple_GET_SIZE(arg); + if (size == 0) { + // return tuple of empty tuples as in python reference + return {from_borrowed_ref(arg), from_borrowed_ref(arg)}; + } + + bool is_namedtuple = PyObject_HasAttr(arg, _fields_attr); + auto tuple_type = Py_TYPE(arg); + + // Create tuples directly instead of lists + auto tys_tuple = from_new_ref(PyTuple_New(size)); + if (!tys_tuple) + return {}; + + auto keys_tuple = from_new_ref(PyTuple_New(size)); + if (!keys_tuple) + return {}; + + for (Py_ssize_t i = 0; i < size; ++i) { + PyObject *item = PyTuple_GET_ITEM(arg, i); // Borrowed reference + // python reference calls specialize recursively with default arguments set + // currently this is is_const=False, specialize_value=True, align=True + auto [type, key] = specialize_arg(backend, item, false, true, true); + if (!type || !key) + return {}; + // Steals reference + PyTuple_SET_ITEM(tys_tuple.ptr(), i, type.release().ptr()); + PyTuple_SET_ITEM(keys_tuple.ptr(), i, key.release().ptr()); + } + + if (is_namedtuple) { + tys_tuple = from_new_ref( + PyObject_CallObject((PyObject *)tuple_type, tys_tuple.ptr())); + if (!tys_tuple) + return {}; + keys_tuple = from_new_ref( + PyObject_CallObject((PyObject *)tuple_type, keys_tuple.ptr())); + if (!keys_tuple) + return {}; + } + + return {std::move(tys_tuple), std::move(keys_tuple)}; +} + +// initialize type handler which returns specialize impelemntations based on +// type(arg) +void init_type_handler_cache() { + // Python Types (int, bool, float, tuple) + type_handler_cache[&PyLong_Type] = handle_long_type; + type_handler_cache[&PyBool_Type] = handle_bool_type; + type_handler_cache[&PyFloat_Type] = handle_float_type; + type_handler_cache[&PyTuple_Type] = handle_tuple; + + // torch.Tensor + if (torch_tensor_cls && PyType_Check(torch_tensor_cls)) { + type_handler_cache[(PyTypeObject *)torch_tensor_cls] = handle_tensor; + } + // TensorDescriptor + if (tensor_descriptor_cls && PyType_Check(tensor_descriptor_cls)) { + type_handler_cache[(PyTypeObject *)tensor_descriptor_cls] = + handle_tensor_descriptor; + } + // GluonTensorDescriptor + if (nvidia_tensor_descriptor_cls && + PyType_Check(nvidia_tensor_descriptor_cls)) { + type_handler_cache[(PyTypeObject *)nvidia_tensor_descriptor_cls] = + handle_gluon_tensor_descriptor; + } + if (amd_tensor_descriptor_cls && PyType_Check(amd_tensor_descriptor_cls)) { + type_handler_cache[(PyTypeObject *)amd_tensor_descriptor_cls] = + handle_gluon_tensor_descriptor; + } + // constexpr + if (constexpr_cls && PyType_Check(constexpr_cls)) { + type_handler_cache[(PyTypeObject *)constexpr_cls] = handle_constexpr_type; + } + // JITCallable + if (jit_callable_cls && PyType_Check(jit_callable_cls)) { + type_handler_cache[(PyTypeObject *)jit_callable_cls] = handle_jit_callable; + } +} + +// specialization logic without passing of objects from Python (to be called in +// specialize_impl only) +std::pair specialize_arg(PyObject *backend, + PyObject *arg, bool is_const, + bool specialize_value, + bool align) { + // fast-path for default types + PyTypeObject *arg_type = Py_TYPE(arg); + auto it = type_handler_cache.find(arg_type); + if (it != type_handler_cache.end()) { + return it->second(backend, arg, is_const, specialize_value, align); + } + + // separate handling of None + if (Py_IsNone(arg)) { + return {from_borrowed_ref(constexpr_str), py::none()}; + } + + // handling of sublcasses of tuples + if (PyTuple_Check(arg)) { + return handle_tuple(backend, arg, is_const, specialize_value, align); + } + + // fallback paths checking full inheritance + if (PyObject_IsInstance(arg, constexpr_cls)) { + return handle_constexpr_type(backend, arg, is_const, specialize_value, + align); + } + + if (PyObject_IsInstance(arg, tensor_descriptor_cls)) { + return handle_tensor_descriptor(backend, arg, is_const, specialize_value, + align); + } + + if (PyObject_IsInstance(arg, nvidia_tensor_descriptor_cls)) { + return handle_gluon_tensor_descriptor(backend, arg, is_const, + specialize_value, align); + } + + if (PyObject_IsInstance(arg, amd_tensor_descriptor_cls)) { + return handle_gluon_tensor_descriptor(backend, arg, is_const, + specialize_value, align); + } + + if (PyObject_IsInstance(arg, jit_callable_cls)) { + return handle_jit_callable(backend, arg, is_const, specialize_value, align); + } + + // fallback paths checking attributes directly + if (PyObject_HasAttr(arg, data_ptr_attr)) { + return handle_tensor(backend, arg, is_const, specialize_value, align); + } + + // fallback for default types + if (PyLong_Check(arg)) { + return handle_long_type(backend, arg, is_const, specialize_value, align); + } + if (PyFloat_Check(arg)) { + return handle_float_type(backend, arg, is_const, specialize_value, align); + } + + return {}; +} + +// main entry-point from Python implementing specialization logic natively +PyObject *specialize_impl(PyObject *self, PyObject *const *args, + Py_ssize_t nargs) { + if (!init_called) { + if (!init_globals()) { + return nullptr; + } + } + + if (nargs != 5) { + PyErr_SetString(PyExc_TypeError, + "native_specialize_impl expected 5 arguments"); + return nullptr; + } + + PyObject *backend = args[0]; + PyObject *arg = args[1]; + int is_const = PyObject_IsTrue(args[2]); + int specialize_value = PyObject_IsTrue(args[3]); + int align = PyObject_IsTrue(args[4]); + + if (is_const == -1 || specialize_value == -1 || align == -1) { + PyErr_SetString(PyExc_TypeError, "native_specialize_impl expected boolean " + "arguments for args2, args3, args4"); + return nullptr; + } + + auto [type, key] = + specialize_arg(backend, arg, is_const, specialize_value, align); + + // check if specialization failed + if (!type || !key) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_TypeError, "failed to specialize argument of type: %s", + Py_TYPE(arg)->tp_name); + } + return nullptr; + } + + return PyTuple_Pack(2, type.ptr(), key.ptr()); +} + +static PyMethodDef module_methods[] = { + {"native_specialize_impl", (PyCFunction)specialize_impl, METH_FASTCALL, + nullptr}, + {nullptr, nullptr, 0, nullptr} // sentinel +}; + +} // anonymous namespace + +void init_native_specialize(pybind11::module &m) { + // add functions to module + PyModule_AddFunctions(m.ptr(), module_methods); +} diff --git a/third_party/mthreads/python/test/conftest.py b/third_party/mthreads/python/test/conftest.py new file mode 100644 index 0000000000..df153baf38 --- /dev/null +++ b/third_party/mthreads/python/test/conftest.py @@ -0,0 +1,66 @@ +import pytest +import tempfile + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cuda") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + from triton import knobs + + with knobs.cache.scope(), knobs.runtime.scope(): + knobs.cache.dir = tmpdir + yield tmpdir + + +@pytest.fixture +def fresh_knobs(): + """ + Resets all knobs except ``build``, ``nvidia``, and ``amd`` (preserves + library paths needed to compile kernels). + """ + from triton._internal_testing import _fresh_knobs_impl + fresh_function, reset_function = _fresh_knobs_impl(skipped_attr={"build", "nvidia", "amd"}) + try: + yield fresh_function() + finally: + reset_function() + + +@pytest.fixture +def fresh_knobs_including_libraries(): + """ + Resets ALL knobs including ``build``, ``nvidia``, and ``amd``. + Use for tests that verify initial values of these knobs. + """ + from triton._internal_testing import _fresh_knobs_impl + fresh_function, reset_function = _fresh_knobs_impl() + try: + yield fresh_function() + finally: + reset_function() + + +@pytest.fixture +def with_allocator(): + import triton + from triton.runtime._allocation import NullAllocator + from triton._internal_testing import default_alloc_fn + + triton.set_allocator(default_alloc_fn) + try: + yield + finally: + triton.set_allocator(NullAllocator()) diff --git a/third_party/mthreads/python/test/unit/language/print_helper.py b/third_party/mthreads/python/test/unit/language/print_helper.py new file mode 100644 index 0000000000..d0c986400e --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/print_helper.py @@ -0,0 +1,170 @@ +import sys +import uuid + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_cast(BLOCK: tl.constexpr): + x = tl.arange(0, BLOCK) + 128 + tl.device_print("x: ", x.to(tl.uint8)) + + +@triton.jit +def kernel_device_print_hex(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x, hex=True) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Triton should add a space after this prefix. + print("x:", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_scalar(SCALAR): + x = tl.load(SCALAR) + # Triton should add a space after this prefix. + print("x:", x) + + +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + +@triton.jit +def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): + tl.device_print("ptr ", X + tl.arange(0, BLOCK)) + + +@triton.jit +def kernel_print_2d_tensor(X, Y, BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr): + off_x = tl.arange(0, BLOCK_SIZE_X) + off_y = tl.arange(0, BLOCK_SIZE_Y) + x = tl.load(X + off_x[:, None] * BLOCK_SIZE_Y + off_y[None, :]) + tl.device_print("", x) + + +def test_print(func: str, data_type: str, device: str): + N = 128 # This value should match with test_print in test_subprocess.py. + # TODO(antiagainst): Currently the warp count is chosen to make sure we don't have multiple + # threads printing duplicated messages due to broadcasting. Improve print op lowering logic + # to filter out duplicated data range. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device=device) + if func == "device_print": + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_scalar": + scalar = torch.tensor(42, dtype=x.dtype, device=device) + kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps) + elif func == "device_print_negative": + x = -x + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_uint": + x = torch.arange((1 << 31), (1 << 31) + N, device=device).to(getattr(torch, data_type)) + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_uint_cast": + kernel_device_print_cast[(1, )](num_warps=num_warps, BLOCK=N) + elif func == "print": + kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, num_warps=num_warps, BLOCK_N=N) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_print": + kernel_static_print[(1, )](x, y, num_warps=num_warps, BLOCK=N, PLACEHOLDER=uuid.uuid4()) + elif func == "no_arg_print": + kernel_no_arg_print[(1, )](num_warps=num_warps) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=num_warps) + elif func == "device_print_hex": + kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_pointer": + kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_2d_tensor": + BLOCK_SIZE_X = num_warps + BLOCK_SIZE_Y = get_current_target_warp_size() + x_2d_tensor = x.reshape((BLOCK_SIZE_X, BLOCK_SIZE_Y)) + kernel_print_2d_tensor[(1, )](x_2d_tensor, y, num_warps=num_warps, BLOCK_SIZE_X=BLOCK_SIZE_X, + BLOCK_SIZE_Y=BLOCK_SIZE_Y) + else: + assert f"Unknown kernel: {func}" + + excluded_funcs = { + "print_no_arg", "no_arg_print", "device_print_large", "print_multiple_args", "device_print_multiple_args", + "device_print_pointer", "device_print_scalar", "device_print_2d_tensor", "device_print_uint_cast" + } + if func not in excluded_funcs: + assert_close(y, x) + + # Wait until driver complete all the jobs for the device_print, especially test_subprocess + # require this which captures stdout when child exits. + getattr(torch, device).synchronize() + + +if __name__ == "__main__": + fn = globals()[sys.argv[1]] + fn(*sys.argv[2:]) diff --git a/third_party/mthreads/python/test/unit/language/test_annotations.py b/third_party/mthreads/python/test/unit/language/test_annotations.py new file mode 100644 index 0000000000..5032665d03 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_annotations.py @@ -0,0 +1,85 @@ +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest +import numpy as np + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X + v, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + if not signed and width < 64: + assert "arith.extui %v" in h.asm["ttir"] + assert f'%v: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass + + +# Test float annotations are properly respected +@pytest.mark.parametrize( + ("dtype", "test_val"), + [(dtype, test_val) + for dtype in [tl.float16, tl.bfloat16, tl.float32, tl.float64] + for test_val in [0.0, 42.0, float("inf"), float("nan")]], +) +def test_float_annotation(device, dtype, test_val): + + @triton.jit + @annotated_function(val=dtype) + def _kernel(ptr, val): + tl.static_assert(val.dtype == dtype) + tl.store(ptr, val) + + ptr = torch.empty(1, device=device, dtype=torch.float32) + h = _kernel[(1, )](ptr, test_val) + np.testing.assert_allclose(ptr.cpu().numpy(), [test_val], atol=1e-6) + + # Check that the type is properly emitted in the IR + if dtype == tl.float16: + assert "%val: f16" in h.asm["ttir"] + assert "arith.extf %val : f16 to f32" in h.asm["ttir"] + elif dtype == tl.bfloat16: + assert "%val: bf16" in h.asm["ttir"] + assert "arith.extf %val : bf16 to f32" in h.asm["ttir"] + elif dtype == tl.float32: + assert "%val: f32" in h.asm["ttir"] + elif dtype == tl.float64: + assert "%val: f64" in h.asm["ttir"] + assert "arith.truncf %val : f64 to f32" in h.asm["ttir"] diff --git a/third_party/mthreads/python/test/unit/language/test_block_pointer.py b/third_party/mthreads/python/test/unit/language/test_block_pointer.py new file mode 100644 index 0000000000..aff7a29d87 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,118 @@ +import pytest +import torch + +import triton +import triton.language as tl +from test_core import check_type_supported + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: tl.constexpr, + TEST_LOWER_BOUND: tl.constexpr, TEST_UPPER_BOUND: tl.constexpr): + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + if TEST_LOWER_BOUND: + offset = -N + elif TEST_UPPER_BOUND: + offset = N + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(offset, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(offset, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + if PADDING_OPTION is None: + a = tl.load(a_block_ptr, boundary_check=(0, )) + else: + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=PADDING_OPTION) + tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtypes_str, n, padding_option, boundary_check", [ # + (dtypes_str, n, padding, boundary_check) # + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), + ("float32", "float32"), ("bfloat16", "bfloat16")) + for n in (64, 128, 256, 512, 1024) + for padding in (None, "zero", "nan") # + for boundary_check in (None, "lower", "upper") +]) +def test_block_copy(dtypes_str, n, padding_option, boundary_check, device): + src_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[1] + src_dtype = getattr(torch, src_dtype_str) + dst_dtype = getattr(torch, dst_dtype_str) + check_type_supported(src_dtype, device) + check_type_supported(dst_dtype, device) + if src_dtype_str in ("bool", "int16", "int32"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) + else: + a = torch.randn((n, ), device=device, dtype=src_dtype) + b = torch.zeros((n, ), device=device, dtype=dst_dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, PADDING_OPTION=padding_option, + TEST_LOWER_BOUND=boundary_check == "lower", TEST_UPPER_BOUND=boundary_check == "upper") + a.to(dst_dtype) + if (boundary_check == "lower") or (boundary_check == "upper"): + assert torch.all(b == 0) + else: + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + elif padding_option == "nan": + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # +): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], + ] for num_warps in [4, 8] +]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/third_party/mthreads/python/test/unit/language/test_compile_errors.py b/third_party/mthreads/python/test/unit/language/test_compile_errors.py new file mode 100644 index 0000000000..f1a3077973 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,567 @@ +import contextlib +import pytest +import os + +import torch +import triton +import triton.language as tl +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna4, is_musa, is_musa_ph1 + + +def format_exception(type, value, tb): + list_msg = traceback.format_exception(type, value, tb, chain=False) + return "\n".join(list_msg) + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "is not defined" in err_msg, "error should mention the undefined variable" + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the 0" + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + print(err_msg) + assert "at 2:4:" in err_msg, "error should point to the static_assert call" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + -(0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + assert e.value.__cause__ is None + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the `not`" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the 1.0" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + inner_exc = e.value.__cause__ + inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__) + assert "at 2:4:" in inner, "error should point to xyz" + assert "" not in inner + assert "code_generator.py" not in inner + + outer = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 3:4" in outer, "error should point to the nested_call" + assert "" not in outer + assert "code_generator.py" not in outer + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. + @triton.jit + def kernel(): + tl.expand_dims(None, -1) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + inner_exc = e.value.__cause__ + inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__) + assert f"{os.sep}core.py" in inner, "error should point inside core.py" + assert "code_generator.py" not in inner + + outer = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in outer, "error should point to expand_dims call" + assert "" not in outer + assert "code_generator.py" not in outer + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@triton.jit +def two_returns(): + return tl.arange(0, 4) + return tl.arange(0, 8) + + +def test_two_returns_no_err(): + # This program is valid; `a` has shape (10,). + @triton.jit + def kernel(): + a = two_returns() + a + tl.arange(0, 4) # only works if we took the first return + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +def test_not_const_annotate_no_err(): + + @triton.jit + def kernel(N: int = 1): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + + +@triton.jit +def returns_branched_on_constexpr(N: tl.constexpr): + if N == 0: + return tl.arange(0, 4) + # Ideally this would work even without the `else`, but we're not that smart + # yet. + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_constexpr(): + + @triton.jit + def kernel1(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 4) + + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) + + @triton.jit + def kernel2(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 8) + + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) + + +@triton.jit +def returns_branched_on_non_constexpr(N: int): + if N == 0: + return tl.arange(0, 4) + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_non_constexpr(): + + @triton.jit + def kernel(N: int): + returns_branched_on_non_constexpr(N) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + + assert "at 2:4:" in str(e.value), "error should point to the function call" + assert "at 1:0:" in str(e.value.__cause__), "error should point to function definition" + + +def test_power_of_two_shapes(): + + @triton.jit + def kernel(): + tl.arange(2, 7) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert str(e.value.__cause__) == "arange's range must be a power of 2" + + +def test_power_of_two_shapes_2(): + + @triton.jit + def kernel(): + tl.full((33, ), 0, dtype=tl.int64) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" + + +GLOBAL = 42 + + +def test_global_var_access(): + + @triton.jit + def kernel(): + a = GLOBAL # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert "global variable" in str(e.value) + + +CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42 + + +def test_constexpr_annotated_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert False, "Using a constexpr annotated global variable should not be allowed" + except CompilationError as e: + assert "Cannot access global variable" in str(e) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) + + +def test_defaults_assign_no_err(): + + @triton.jit + def kernel(a=1, B: tl.constexpr = ""): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) + + +def test_where_warning(fresh_triton_cache): + + @triton.jit + def kernel(): + a = tl.full((64, ), 0, tl.uint32) + b = tl.full((64, ), 1, tl.float32) + c = tl.full((64, ), 2, tl.float32) + tl.where(a, b, c) + + with pytest.warns(UserWarning): + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) +def test_fp8_support(fresh_triton_cache, dtype): + warning_dtypes = [] + supported_dtypes = [tl.float8e5] + if is_cuda(): + cc = torch.cuda.get_device_capability(0) + supported_dtypes.append(tl.float8e4b15) + if cc >= (9, 0): + warning_dtypes.append(tl.float8e4b15) + if cc >= (8, 9): + supported_dtypes.append(tl.float8e4nv) + elif is_hip(): + supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16] + if is_hip_cdna4(): + warning_dtypes += [tl.float8e4b8, tl.float8e5b16] + elif is_musa(): + if is_musa_ph1(): + supported_dtypes += [tl.float8e4nv] + + @triton.jit + def dtype_kernel(dtype: tl.constexpr): + a = tl.full((64, 64), 0.0, dtype) + tl.dot(a, a) + + if dtype in warning_dtypes: + if is_cuda(): + ctx = pytest.warns(UserWarning, + match=r"the use of fp8e4b15 is deprecated on Hopper and later architectures") + elif is_hip_cdna4(): + ctx = pytest.warns(UserWarning, match=r"AMD gfx942 specific and not supported on gfx950") + elif dtype in supported_dtypes: + ctx = contextlib.nullcontext() + else: + ctx = pytest.raises(CompilationError, match="") + + with ctx as e: + triton.compile( + triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + + if dtype not in supported_dtypes: + try: + assert ("not supported in this architecture" in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.int8, tl.float16]) +def test_min_dot_size(dtype): + error_msg = "Input shapes should have " + if is_cuda(): + if dtype.primitive_bitwidth == 8: + error_msg += "M >= 1, N >= 1 and K >= 32" + else: + error_msg = "M >= 1, N >= 1 and K >= 16" + elif is_hip(): + # hip supports arbitrary sizes + error_msg = None + else: + pytest.skip("Test only supported on CUDA and HIP") + + @triton.jit + def dot_kernel(dtype: tl.constexpr): + SIZE: tl.constexpr = 8 + a = tl.full((SIZE, SIZE), 0.0, dtype) + b = tl.full((SIZE, SIZE), 0.0, dtype) + tl.dot(a, b) + + if error_msg is None: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + else: + with pytest.raises(CompilationError) as e: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + try: + assert (error_msg in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_max_num_imprecise_acc_limit(): + + @triton.jit + def dot_kernel(): + SIZE: tl.constexpr = 64 + a = tl.full((SIZE, SIZE), 0.0, tl.float8e5) + b = tl.full((SIZE, SIZE), 0.0, tl.float8e5) + tl.dot(a, b, max_num_imprecise_acc=128) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) + try: + assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") + except AssertionError as assertion_err: + raise assertion_err from e.value + + +extra_words = "These are extra words in the error message." + + +@triton.must_use_result(extra_words) +@triton.jit +def cube(x): + return x * x * x + + +def test_unused_result(): + + @triton.jit + def evil_cube_kernel(): + a = tl.full((64, 64), 0.0, tl.float32) + cube(a) + + @triton.jit + def good_cube_kernel(): + a = tl.full((64, 64), 0.0, tl.float32) + a = cube(a) + + triton.compile(triton.compiler.ASTSource(fn=good_cube_kernel, signature={}, constexprs={})) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=evil_cube_kernel, signature={}, constexprs={})) + + expected_err_msg = "The result of cube is not being used. " + extra_words + obtained_err_msg = str(e.value).split('\n')[-1] + + assert expected_err_msg == obtained_err_msg + + +@tl.core._aggregate +class Square: + x: tl.tensor + + @triton.constexpr_function + def __init__(self, x): + self.x = x + + @triton.must_use_result + @triton.constexpr_function + def power(self): + return 2 + + @triton.must_use_result + @triton.jit + def compute(self): + return self.x * self.x + + +def test_bound_unused_result(): + + @triton.jit + def evil_square_kernel(): + a = Square(tl.full((64, 64), 0.0, tl.float32)) + a.compute() + + @triton.jit + def good_square_kernel(): + a = Square(tl.full((64, 64), 0.0, tl.float32)) + a = a.compute() + + triton.compile(triton.compiler.ASTSource(fn=good_square_kernel, signature={}, constexprs={})) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=evil_square_kernel, signature={}, constexprs={})) + + assert "The result of a.compute is not being used" in str(e.value) + + @triton.jit + def evil_power_kernel(): + a = Square(tl.full((64, 64), 0.0, tl.float32)) + a.power() + + @triton.jit + def good_power_kernel(): + a = Square(tl.full((64, 64), 0.0, tl.float32)) + a = a.power() + + triton.compile(triton.compiler.ASTSource(fn=good_power_kernel, signature={}, constexprs={})) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=evil_power_kernel, signature={}, constexprs={})) + + assert "The result of a.power is not being used" in str(e.value) + + +def test_err_constexpr_and_do_not_specialize(): + + @triton.jit(do_not_specialize=["N"]) + def kernel(N: tl.constexpr): + pass + + with pytest.raises(CompilationError, match="N marked as constexpr and listed in do_not_specialize"): + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={"N": 5})) + + with pytest.raises(CompilationError, match="N marked as constexpr and listed in do_not_specialize"): + kernel[(1, )](5) + + +def test_dot_scaled_shape_verification(fresh_triton_cache): + + @triton.jit + def kernel(): + M: tl.constexpr = 32 + K: tl.constexpr = 64 + N: tl.constexpr = 32 + a = tl.full((M, K), 0, tl.uint8) + b = tl.full((K, N), 0, tl.uint8) + lhs_scale_wrong = tl.full((M, 4), 0, tl.uint8) + rhs_scale = tl.full((N, 2), 0, tl.uint8) + acc = tl.full((M, N), 0.0, tl.float32) + tl.dot_scaled(a, lhs_scale_wrong, "e5m2", b, rhs_scale, "e5m2", acc, False, True, True, tl.float32) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + assert str(e.value.__cause__) == "lhs_scale must be a tensor of shape [..., 32, 2]. Got ['32', '4']" diff --git a/third_party/mthreads/python/test/unit/language/test_compile_only.py b/third_party/mthreads/python/test/unit/language/test_compile_only.py new file mode 100644 index 0000000000..7ec025072e --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_compile_only.py @@ -0,0 +1,201 @@ +import pytest +import triton +import triton.language as tl +from triton.backends import backends +from triton.backends.compiler import GPUTarget +import re +from triton.compiler import ASTSource + + +def _has_cuda_sm100_compile_only_backend() -> bool: + target = GPUTarget("cuda", 100, 32) + return sum(1 for backend in backends.values() if backend.compiler.supports_target(target)) == 1 + + +requires_cuda_sm100_backend = pytest.mark.skipif( + not _has_cuda_sm100_compile_only_backend(), + reason="Requires NVIDIA CUDA backend with SM100 compile-only support", +) + + +@requires_cuda_sm100_backend +def test_compile_only_sm100() -> None: + + @triton.jit + def kernel_add(a, b, c): + idx = tl.arange(0, 32) + tl.store(c + idx, tl.load(a + idx) + tl.load(b + idx)) + + k = triton.compile( + triton.compiler.ASTSource(fn=kernel_add, signature={"a": "*fp32", "b": "*fp32", "c": "*fp32"}, constexprs={}), + target=GPUTarget("cuda", 100, 32)) + ptx = k.asm["ptx"] + assert ".target sm_100a" in ptx + assert ".address_size 64" in ptx + assert k.asm["cubin"] != b"" + + +@requires_cuda_sm100_backend +def test_compile_only_dot() -> None: + + @triton.jit + def simple_dot(a_base, b_base, out): + SIZE: tl.constexpr = 64 + a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + b_ptr = b_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + a = tl.load(a_ptr) + b = tl.load(b_ptr) + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + tl.store(out_ptr, c) + + k = triton.compile( + triton.compiler.ASTSource(fn=simple_dot, signature={"a_base": "*fp16", "b_base": "*fp16", "out": "*fp16"}, + constexprs={}), target=GPUTarget("cuda", 100, 32)) + ttgir = k.asm["ttgir"] + pattern = (r"%(?P
\w+) = tt\.load" + r"(.|\n)*?" + r"%(?P\w+) = ttg\.local_alloc %(?P=A)" + r"(.|\n)*?" + r"%(?P\w+) = tt\.load" + r"(.|\n)*?" + r"%(?P\w+) = ttg\.local_alloc %(?P=B)" + r"(.|\n)*?" + r"%(?P\w+) = ttng\.tmem_alloc" + r"(.|\n)*?" + r"ttng\.tc_gen5_mma %(?P=A_SHMEM), %(?P=B_SHMEM), %(?P=TMEM_BASE)" + r"(.|\n)*?" + r"ttng\.tmem_load %(?P=TMEM_BASE)") + + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + + ptx = k.asm["ptx"] + pattern = (r"mov\.b32 %r(?P\d+), global_smem;" + r"(.|\n)*" + r"tcgen05\.alloc\.cta_group::1\.sync\.aligned\.shared::cta\.b32 \[%r(?P=G)], 64" + r"(.|\n)*" + r"tcgen05\.relinquish_alloc_permit\.cta_group::1\.sync\.aligned" + r"(.|\n)*" + r"tcgen05\.st\.sync\.aligned\.16x32bx2.x32.b32" + r"(.|\n)*" + r"tcgen05\.mma\.cta_group::1.kind::f16" + r"(.|\n)*" + r"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64" + r"(.|\n)*" + r"mbarrier.try_wait.parity.shared::cta.b64" + r"(.|\n)*" + r"tcgen05.ld.sync.aligned.16x32bx2.x32.b32" + r"(.|\n)*" + r"tcgen05.wait::ld.sync.aligned") + assert re.search(pattern, str(ptx)), "The PTX does not match the expected pattern." + assert k.asm["cubin"] != b"" + + +@requires_cuda_sm100_backend +def test_compile_only_k_loop() -> None: + + @triton.jit + def k_loop(a_base, b_base, out, k_tiles): + SIZE: tl.constexpr = 128 + offs_k = tl.arange(0, SIZE) + c = tl.zeros((SIZE, SIZE), dtype=tl.float32) + for k in range(k_tiles): + a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + offs_k[None, :] + b_ptr = b_base + offs_k[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + offs_k = offs_k + SIZE + a = tl.load(a_ptr) + b = tl.load(b_ptr) + c += tl.dot(a, b) + out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + tl.store(out_ptr, c) + + k = triton.compile( + triton.compiler.ASTSource(fn=k_loop, + signature={"a_base": "*fp16", "b_base": "*fp16", "out": "*fp16", "k_tiles": + "i32"}, constexprs={}), target=GPUTarget("cuda", 100, 32)) + ttgir = k.asm["ttgir"] + + pattern = (r"%(?P\w+) = arith.constant dense<0.000000e\+00>" + r"(.|\n)*?" + r"%(?P\w+) = ttng\.tmem_alloc (%(?P=TMEM_BASE))?" + r"(.|\n)*?" + r"scf\.for" + r"(.|\n)*?" + r"%(?P\w+) = tt\.load" + r"(.|\n)*?" + r"%(?P\w+) = ttg\.local_alloc %(?P=A)" + r"(.|\n)*?" + r"%(?P\w+) = tt\.load" + r"(.|\n)*?" + r"%(?P\w+) = ttg\.local_alloc %(?P=B)" + r"(.|\n)*?" + r"ttng\.tc_gen5_mma %(?P=A_SHMEM), %(?P=B_SHMEM), %(?P=TMEM)" + r"(.|\n)*?" + r"scf\.yield") + + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + assert k.asm["cubin"] != b"" + + +@requires_cuda_sm100_backend +def test_compile_only_dot_mxfp() -> None: + + @triton.jit + def simple_dot_mxfp(a_base, b_base, a_scale, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * PACKED_BLOCK_K_A + tl.arange(0, PACKED_BLOCK_K_A)[None, :] + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + a_scale = tl.load(scale_a_ptr) + b_scale = tl.load(scale_b_ptr) + c = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3") + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + tl.store(out_ptr, c) + + k = triton.compile( + triton.compiler.ASTSource( + fn=simple_dot_mxfp, signature={ + "a_base": "*u8", "b_base": "*u8", "a_scale": "*u8", "b_scale": "*u8", "out": "*fp32", "BLOCK_M": + "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr" + }, constexprs={"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}), target=GPUTarget("cuda", 100, 32)) + ttgir = k.asm["ttgir"] + pattern = (r"ttng.tc_gen5_mma_scaled (.*) lhs = e4m3 rhs = e4m3") + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + + ptx = k.asm["ptx"] + pattern = (r"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X") + assert re.search(pattern, str(ptx)), "The PTX does not match the expected pattern." + assert k.asm["cubin"] != b"" + + +def test_signature_ordering(): + """ + Checks that ASTSource always uses the argument order from + fn.arg_names and not the signature. + """ + + @triton.jit + def kernel(a, o, N: tl.constexpr): + tl.store(o + N, tl.load(a + N)) + + # Add the arguments so the order always differs + # from the order in fn.arg_names. + signature = {} + signature["N"] = "constexpr" + signature["a"] = "*fp32" + signature["o"] = "*fp32" + src = ASTSource( + fn=kernel, + constexprs={"N": 32}, + signature=signature, + ) + target = triton.runtime.driver.active.get_current_target() + triton.compile(src=src, target=target) diff --git a/third_party/mthreads/python/test/unit/language/test_conversions.py b/third_party/mthreads/python/test/unit/language/test_conversions.py new file mode 100644 index 0000000000..e64e16ea1a --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_conversions.py @@ -0,0 +1,444 @@ +# fmt: off + + +import numpy as np +import torch +import pytest +import triton +import triton.language as tl + +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4, is_hip_rdna4, is_musa, is_musa_ph1 + + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + + +def sanitize_fnuz_special_value(x): + # FNUZ fp8 uses the raw bit-pattern 0x80, which is stored through int8 + # tensors as the signed value -128. MUSA int8 comparisons against the + # literal 0x80 do not match that storage value, so normalize via the + # signed representation to keep the emulation inputs stable across + # backends. + return torch.where(x == torch.iinfo(x.dtype).min, torch.zeros_like(x), x) + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = sanitize_fnuz_special_value(dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. + mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = sanitize_fnuz_special_value(dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst_to_float32 = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + src_emulated_to_float32 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(src_emulated_to_float32, dst_to_float32)) + + +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + ('float8e5', 'float16'), + ('float8e5', 'bfloat16'), + ('float8e5', 'float32'), + + ('float8e4b15', 'float16'), + # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + ('float8e4b15', 'float32'), + + ('float8e4nv', 'float16'), + ('float8e4nv', 'bfloat16'), + ('float8e4nv', 'float32'), + + ('float8e4b8', 'float32'), + ('float8e4b8', 'bfloat16'), + ('float8e4b8', 'float16'), + + ('float8e5b16', 'float32'), + ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + # On HIP, fp8e4nv upcasting to fp32 is only supported on CDNA4, and + # fp8e4nv upcasting to bf16 and fp16 is only supported on CDNA3 and CDNA4. + if is_cuda(): + if ((src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (8, 9)) + or src_dtype in ('float8e4b8', 'float8e5b16')): + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return + elif is_hip(): + if (src_dtype == 'float8e4nv' and not (is_hip_cdna3() or is_hip_cdna4())): + pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture") + if src_dtype == 'float8e4b15': + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return + if src_dtype in ('float8e4b8', 'float8e5b16') and (is_hip_cdna2() or is_hip_rdna4()): + pytest.skip(f"{src_dtype} is not supported on AMDGPU CDNA2 and RDNA4") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ + ('float32', 'float16', 'rtne', 0x477fe000), + ('float32', 'float16', 'rtz', 0x477fe000), + ('float32', 'bfloat16', 'rtne', 0x7f7f0000), + ('float32', 'bfloat16', 'rtz', 0x7f7f0000), + ('float32', 'float8e5', 'rtne', 0x47600000), + ('float32', 'float8e5', 'rtz', 0x47600000), + ('float32', 'float8e4nv', 'rtne', 0x43e00000), + ('float32', 'float8e4b8', 'rtne', 0x43700000), + ('float32', 'float8e5b16', 'rtne', 0x47600000), + # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15 + + ('bfloat16', 'float8e5', 'rtne', 0x4760), + ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + + ('float16', 'float8e5', 'rtne', 0x7b00), + ('float16', 'float8e4nv', 'rtne', 0x5f00), + + ('bfloat16', 'float8e5b16', 'rtne', 0x4760), + ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + + ('float16', 'float8e5b16', 'rtne', 0x7b00), + ('float16', 'float8e4b8', 'rtne', 0x5b80), +]) +def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + + if is_cuda(): + if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne': + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3") + + if is_hip(): + if dst_dtype in ('float8e4b8', 'float8e5b16') and (is_hip_cdna2() or is_hip_rdna4()): + pytest.skip(f"{dst_dtype} is not supported on AMDGPU CDNA2 and RDNA4") + + if is_musa(): + if is_musa_ph1() and dst_dtype in ('float8e4b8', 'float8e5b16'): + pytest.skip(f"{dst_dtype} is not supported on MUSA PH1") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias) + stuff = { + 'float16': (5, 10, 15), + 'bfloat16': (8, 7, 127), + 'float8e5': (5, 2, 15), + 'float8e4b15': (4, 3, 15), + 'float8e4nv': (4, 3, 7), + 'float8e4b8': (4, 3, 8), + 'float8e5b16': (5, 2, 16), + }[dst_dtype] + + for i in range(256): + downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) + +@pytest.mark.parametrize("mode", [ + 'max', 'min', 'inf', '-inf', 'nan', +]) +@pytest.mark.parametrize("dst_dtype", ["float8e4nv", "float8e5"]) +@pytest.mark.parametrize("src_dtype", ["float32", "float16", "bfloat16"]) +def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, device, rounding="rtne"): + if is_cuda(): + if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if mode in ('inf', '-inf') and is_hip_rdna4(): + pytest.skip(f"clamping from `{mode}` is not supported on AMDGPU GFX12") + + converter = { + tl.float8e4nv: torch.float8_e4m3fn, + tl.float8e5: torch.float8_e5m2, + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32 + } + + tl_src_dtype = getattr(tl, src_dtype) + tl_dst_dtype = getattr(tl, dst_dtype) + + torch_src_dtype = converter[tl_src_dtype] + torch_dst_dtype = converter[tl_dst_dtype] + + if mode in ('max', 'min'): + # Added to input to exceed the representation range to produce NaN + exceed_value = 100.0 + test_value = torch.finfo(torch_dst_dtype).max + exceed_value + expected_result = torch.finfo(torch_dst_dtype).max + elif mode in ('inf', '-inf'): + test_value = torch.inf + expected_result = torch.finfo(torch_dst_dtype).max + else: + assert mode == 'nan' + test_value = torch.nan + expected_result = torch.nan + + if mode in ('min', '-inf'): + test_value *= -1.0 + expected_result *= -1.0 + + BLOCK_SIZE = 1024 + shape = (BLOCK_SIZE * 2,) + src = torch.full(shape, test_value, dtype=torch_src_dtype, device=device) + dst = torch.empty(shape, dtype=torch_dst_dtype, device=device) + + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, torch_src_dtype), + triton.reinterpret(dst, torch_dst_dtype), + rounding, + BLOCK_SIZE + ) + + if mode == 'nan': + assert(torch.all(torch.isnan(dst))) + else: + # MUSA runtime currently cannot materialize float8 tensors via Fill, + # so build the float8 reference on CPU and compare there instead. + if is_musa() and torch_dst_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + expected = torch.full(dst.shape, expected_result, dtype=torch_dst_dtype) + torch.testing.assert_close(dst.cpu(), expected) + else: + torch.testing.assert_close(dst, torch.full_like(dst, expected_result)) diff --git a/third_party/mthreads/python/test/unit/language/test_core.py b/third_party/mthreads/python/test/unit/language/test_core.py new file mode 100644 index 0000000000..cace210187 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_core.py @@ -0,0 +1,6325 @@ +# ruff: noqa: F821,F841 +import contextlib +import itertools +import re +from typing import Optional +import math +import textwrap + +import numpy as np +import pytest +import torch +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +from triton._internal_testing import ( + integral_dtypes, + int_dtypes, + str_to_triton_dtype, + uint_dtypes, + float_dtypes, + float_dtypes_with_bfloat16, + dtypes, + dtypes_with_bfloat16, + is_cuda, + is_interpreter, + is_hopper, + is_hip, + is_hip_cdna, + is_hip_cdna2, + is_hip_cdna3, + is_hip_cdna4, + is_hip_rdna3, + is_hip_rdna4, + is_hip_gfx1250, + is_musa, + is_musa_ph1, + is_xpu, + get_arch, + torch_float8_dtypes, + torch_dtypes, + numpy_random, + to_triton, + torch_dtype_name, + to_numpy, +) +from triton.runtime.errors import InterpreterError + + +@contextlib.contextmanager +def promotion_numpy_2_0(): + state = np._get_promotion_state() + np._set_promotion_state("weak") + try: + yield + finally: + np._set_promotion_state(state) + + +# No need to emulate NumPy 2.0 if the user has NumPy 2.0 +if np.__version__[0] != "1": + promotion_numpy_2_0 = contextlib.nullcontext + +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +mma_nonk_sizes = [] + +GPU_DIALECT = "ttg" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size + # for CDNA multiple variants of mma instructions are supported: + # mfma 16x16/mfma 32x32 + # 0 is a special value for automatic heuristic + if is_hip_cdna() or is_hip_gfx1250(): + mma_nonk_sizes = [0, 16, 32] + elif is_hip_rdna3() or is_hip_rdna4(): + mma_nonk_sizes = [16] +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def _dtype(dtype: str) -> str: + # ex.: "int64" -> "int" + return re.match(r'([a-zA-Z]+)', dtype).group(0) + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + src = kernel.src + for key, value in to_replace.items(): + src = src.replace(key, value) + kernel._unsafe_update_src(src) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Only for cuda or HIP") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + + +def get_src_element_ty_size(dtype_str): + if dtype_str in ["int8", "uint8", "float8e4b15"]: + return 1 + if dtype_str == "float16": + return 2 + if dtype_str == "float32" or dtype_str == "tensorfloat32": + return 4 + if dtype_str == "float64": + return 8 + raise ValueError(f"Unknown dtype {dtype_str}") + + +@pytest.mark.interpreter +def test_scalar_overflow(device): + + @triton.jit + def kernel(): + huge_int: tl.constexpr = 0xFFFFFFFFFFFFFF + x = tl.full((), 32, dtype=tl.int32) + y = x + huge_int + + with pytest.raises(triton.TritonError, match="out of range"): + kernel[(1, )]() + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + # avoid log/sqrt of negative numbers + if 'log' in expr or 'sqrt' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. + """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True, + test_scalar=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + kernel_scalar_rhs = patch_kernel(kernel_scalar_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if filter_y: + y[filter_y(y)] = 1 + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + x_is_scalar = isinstance(x, (bool, int, float)) + y_is_scalar = isinstance(y, (bool, int, float)) + scalar_test = x_is_scalar or y_is_scalar + + # For scalars, we follow the NumPy 2.0 (and JAX/PyTorch pretty much) casting rules. + if scalar_test: + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) + with promotion_numpy_2_0(): + z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if not scalar_test and dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + + # triton result + x_tri = x if x_is_scalar else to_triton(x, device=device, dst_type=dtype_x) + y_tri = y if y_is_scalar else to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01) + + def get_scalar(x, dtype, low, high, filter): + # If dtype is int, don't choose a huge number for the scalar + # as it'll overflow easily when converted to the other dtype + if dtype in integral_dtypes: + # Choose in range [-7, 7] ([0, 7] for uints) + low_x = 0 if dtype in uint_dtypes else -7 + if low is not None: + low_x = max(low_x, low) + high_x = 7 + if high is not None: + high_x = min(high_x, high) + scalar = numpy_random((), dtype_str=dtype, rs=rs, low=low_x, high=high_x).item() + if filter and filter(scalar): + # https://xkcd.com/221/ + scalar = 4 + else: + scalar = x.flat[0].item() + return scalar + + do_test(x, y, kernel) + if mode_y != 'nan' and test_scalar: + if dtype_x in uint_dtypes: + low = 0 if y_low is None else max(y_low, 0) + else: + low = y_low + y_scalar = get_scalar(y, dtype_y, low, y_high, filter_y) + do_test(x, y_scalar, kernel_scalar_rhs) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _min_max_integral_mod_value(dtype_x, dtype_y) -> tuple[int, int]: + """ + Limit min/max values for integral types for mod values. Leads to + overflow/underflow when casting large integral types to floats. + """ + x_bitwidth = _bitwidth(dtype_x) + y_bitwidth = _bitwidth(dtype_y) + + # hard cap max value bit-width to 32 if 64 bit-width types + min_bitwidth = min(x_bitwidth, y_bitwidth, 32) + + # Limit max value bit-width to be one integral type less than the min bit-width + # For example: + # int64, float32 -> int16 + # uint16, float16 -> uint8 + x_dtype = _dtype(dtype_x) + max_bitwidth = max(min_bitwidth >> 1, 8) + dtype_max = x_dtype + str(max_bitwidth) + + max_info = np.iinfo(getattr(np, dtype_max)) + + # Still need to limit values here for uints + if max_bitwidth >= 16 and dtype_max in uint_dtypes: + return max_info.min, max_info.max // 4 + else: + return max_info.min, max_info.max + + +def test_dtype_codegen(): + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') + + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + def promote_to_fp32(dtype_x, dtype_y): + return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') + + if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): + numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') + elif op == '%': + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = np_expr_gen('x', 'y') + else: + numpy_expr = None + + if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + # skip when bfloat16, as NumPy's ref performs the computation in float32 + # while Triton performs it in bfloat16 + skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) + or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))) + # can't divide by zero + not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes + # can't represent -int(max) + not_minus_one = op in ('*', '/') and dtype_x in int_dtypes and dtype_y in int_dtypes + if not_zero or not_minus_one: + filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) + else: + filter_y = None + + if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16: + x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y) + else: + x_low, x_high = None, None + + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + # can't represent -int(max) + not_minus_one = dtype_x in int_dtypes and dtype_y in int_dtypes + if not_minus_one: + filter_y = lambda y: y == -1 + else: + filter_y = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. + with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, None:] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :None] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + t = data[None, None:None, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + t = data[None, None:None:None, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + t = data[None, ::None, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + t = data[None, None::None, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = False + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'rsqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + np_expr = f"1.0 / np.sqrt({x})" if expr == "rsqrt" else f"np.{expr}({x})" + _test_unary(dtype_x, f'tl.{expr}({x})', np_expr, device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_erf_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device='cpu') + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x).to(device) + kernel[(1, )](z_tri, x.to(device), SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref.to(device)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", + [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), + ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + elif is_cuda(): + cc = torch.cuda.get_device_capability() + if in_dtype == tl.float8e4b15 and cc >= (9, 0): + pytest.skip("float8e4b15 not supported on CUDA >= 9.0") + if in_dtype == tl.float8e4nv and cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") + elif is_musa(): + if in_dtype == tl.float8e4b15 and is_musa_ph1(): + pytest.skip("float8e4nv not supported on PH arch") + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).trans() + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.reshape(tl.arange(0, 64), 2, 4, 8, can_reorder=True) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + (tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = ttg.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y - 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'bfloat16', mode, sem), + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + check_type_supported(dtype_x_str, device) + if is_interpreter(): + if dtype_x_str == 'float16' or dtype_x_str == 'bfloat16': + pytest.skip("Only test atomic bfloat16/float16 ops on GPU") + if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]: + pytest.skip("uint cannot be negative") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + dst_type = 'bfloat16' if (dtype_x_str == 'bfloat16') else None + dtype_x_str = 'float32' if (dtype_x_str == 'bfloat16') else dtype_x_str + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device, dst_type=dst_type) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device, dst_type=dst_type) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + if dst_type == 'bfloat16': + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda(): + return + + # atom.add.bf16 is unsupported prior to Hopper so instead we generate an + # atom.cas add loop on Ampere and prior + if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9: + assert f"atom.{sem_str}.gpu.global.cas" in h.asm["ptx"] + return + + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str, check_return_val", + [(shape, axis, num_ctas, dtype_x_str, check_return_val) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64), (128, 128)] + for axis in [0, 1] + for num_ctas in num_ctas_list + for dtype_x_str in ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64'] + for check_return_val in ([True, False] if is_hip() else [True])]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device): + check_type_supported(dtype_x_str, device) + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr, + RETURN_VAL: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + + if DTYPE == tl.float16 or DTYPE == tl.bfloat16: + # sum can have bad numerics when accumulating in float16. + # if we're dealing with float16, do the sum in float32. + x = x.to(tl.float32) + + z = tl.sum(x, axis=AXIS) + + if DTYPE == tl.float16 or DTYPE == tl.bfloat16: + z = z.to(DTYPE) + + if AXIS == 1: + old = tl.atomic_add(Z + off0, z) + if RETURN_VAL: + tl.store(OLD + off0, old) + else: + old = tl.atomic_add(Z + off1, z) + if RETURN_VAL: + tl.store(OLD + off1, old) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) + old = np.zeros(z_shape, dtype=z.dtype) + # reference results + if x.dtype == np.float16: + # do the sum in float32 to reduce numerical variation + z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype) + else: + z_ref = z + np.sum(x, axis=axis, keepdims=False) + old_ref = np.copy(z) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x_str) + z_tri = to_triton(z, device=device, dst_type=dtype_x_str) + old_tri = to_triton(old, device=device, dst_type=dtype_x_str) + + def torch_to_triton_dtype(t): + if t == torch.bfloat16: + return tl.bfloat16 + if t == torch.float16: + return tl.float16 + return None + + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val, + num_ctas=num_ctas) + + if dtype_x_str == 'bfloat16': + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + old_ref = (old_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + # mantissa trunc is not enough, bump up the relative tolerance as well + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.5) + # check return vals, but use assert_allclose for bf16 + if check_return_val: + np.testing.assert_allclose(old_ref, to_numpy(old_tri), rtol=0.5) + return + + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + if check_return_val: + np.testing.assert_equal(old_ref, to_numpy(old_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str) + for size in [2, 4, 8, 32, 64, 128] + for num_ctas in num_ctas_list + for dtype_x_str in ['bfloat16', 'float16', 'float32']]) +def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device): + check_type_supported(dtype_x_str, device) + + @triton.jit + def kernel(X, val, NUM: tl.constexpr): + off = tl.arange(0, NUM) + offset = off[:, None] * NUM + off[None, :] + val = tl.load(val + offset) + tl.atomic_add(X + offset // 2, val) + + shape = (size // 2, size) + dtype = getattr(torch, dtype_x_str) + x = torch.zeros(shape, dtype=dtype, device=device) + val = torch.randn((size**2), dtype=dtype, device=device) + kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas) + ref = val[0::2] + val[1::2] + torch.testing.assert_close(ref, x.reshape(math.prod(shape))) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str) + for size in [2, 4, 8, 32, 64, 128] + for num_ctas in num_ctas_list + for dtype_x_str in ['bfloat16', 'float16', 'float32']]) +def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device): + check_type_supported(dtype_x_str, device) + + @triton.jit + def kernel(X, val, NUM: tl.constexpr): + off_x = tl.arange(0, 2) + off_y = tl.arange(0, NUM) + off_in = off_x[:, None] * NUM + off_y[None, :] + off_out = off_x[:, None] + off_y[None, :] + + val = tl.load(val + off_in) + tl.atomic_add(X + off_out, val) + + s = (2, size) + dtype = getattr(torch, dtype_x_str) + x = torch.zeros(s, dtype=dtype, device=device) + ref = torch.flatten(x) + val = torch.randn(s, dtype=dtype, device=device) + kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas) + val = torch.flatten(val) + ref[0:size] = val[0:size] + ref[1:size + 1] += val[size:2 * size] + torch.testing.assert_close(ref, torch.flatten(x)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str", + [(shape, idx_order, mask_step, num_ctas, dtype_x_str) + for shape in [(2, 2), (4, 4), (5, 5), (6, 6), (8, 8)] + for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random'] + for mask_step in range(1, 5) + for num_ctas in num_ctas_list + for dtype_x_str in ['bfloat16', 'float16', 'float32']]) +def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device): + check_type_supported(dtype_x_str, device) + if is_interpreter(): + pytest.skip("not supported in the interpreter") + + @triton.jit + def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + x_idx = xoffset + tl.arange(0, XBLOCK)[:] + mask = x_idx < shape0 * shape1 + mask = mask & (x_idx % mask_step != 0) + idx_base = shape1 * (x_idx // shape1) + idx_offset = tl.load(idx_ptr + x_idx, mask) + in_elem = tl.load(in_ptr + x_idx, mask) + tl.atomic_add(out_ptr + (idx_offset + idx_base), in_elem, mask, sem='relaxed') + + shape0, shape1 = shape + idx_row = torch.arange(0, shape1, device=device) + if idx_order == 'increase': + idx = torch.stack([idx_row.repeat_interleave(i + 1)[:shape1] for i in range(shape0)]) + if idx_order == 'decrease': + idx = torch.stack([idx_row.flip(0).repeat_interleave(i + 1)[:shape1] for i in range(shape0)]) + if idx_order == 'random_no_duplication': + idx = torch.stack([torch.randperm(shape1, device=device) for _ in idx_row]) + if idx_order == 'random': + idx = torch.randint(0, shape1, size=(shape0, shape1), device=device) + + dtype = getattr(torch, dtype_x_str) + val = torch.randn((shape0, shape1), dtype=dtype, device=device) + dst = torch.randn((shape0, shape1), dtype=dtype, device=device) + + dst_ref = dst.clone() + + cnt = 0 + for i, row in enumerate(idx): + for j, elem in enumerate(row): + if cnt % mask_step != 0: + dst_ref[i][elem] += val[i][j] + cnt += 1 + + kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas) + + if dtype_x_str == 'bfloat16': + torch.testing.assert_close(dst_ref, dst, rtol=0.1, atol=0.1) + return + + np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("dtype_str", ["int32", "int64"]) +def test_atomic_cas(sem, num_ctas, dtype_str, device): + if is_hip_cdna2(): + pytest.skip("Disabled due to being flaky on CDNA2") + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock, triton_dtype: tl.constexpr): + num0 = tl.full((1, ), 0, dtype=triton_dtype).item() + num1 = tl.full((1, ), 1, dtype=triton_dtype).item() + tl.atomic_cas(Lock, num0, num1) + + torch_dtype = getattr(torch, dtype_str) + triton_dtype = getattr(tl, dtype_str) + Lock = torch.zeros((1, ), device=device, dtype=torch_dtype) + change_value[(1, )](Lock, triton_dtype) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, triton_dtype: tl.constexpr, SEM: tl.constexpr): + num0 = tl.full((1, ), 0, dtype=triton_dtype).item() + num1 = tl.full((1, ), 1, dtype=triton_dtype).item() + + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, num0, num1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # insert barrier to set a fence between tl.store and + # tl.atomic_xchg in a block. + tl.debug_barrier() + + # release lock + tl.atomic_xchg(Lock, num0) + + Lock = torch.zeros((1, ), device=device, dtype=torch_dtype) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, triton_dtype=triton_dtype, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, "acquire", "release", "acq_rel", "relaxed"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("size", [4, 128, 512, 1024]) +@pytest.mark.parametrize("dtype_str", ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64']) +def test_tensor_atomic_cas(sem, size, dtype_str, num_ctas, device): + check_type_supported(dtype_str, device) + if is_hip_cdna2(): + pytest.skip("Disabled due to being flaky on CDNA2") + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + torch_dtype = getattr(torch, dtype_str) + X = torch.zeros((size, ), device=device, dtype=torch_dtype) + X[1::2] = 1 + if device == "musa" and torch_dtype is torch.uint64: + # torch_musa uint64 clone/copy_ currently errors; keep the oracle on CPU. + Y = X.to('cpu').clone() + else: + Y = X.clone() + Y[0::2] = 2 + + tl_dtype = getattr(tl, dtype_str) + change_value[(2, )](X, BLOCK_SIZE=size // 2, sem=sem, dtype=tl_dtype) + assert torch.equal(X.to('cpu'), Y.to('cpu')) + + +@pytest.mark.interpreter +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, + reason="Requires compute capability >= 9 for NV") +def test_load_scope_sem_coop_grid_cta_not_one(device): + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True) + kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False) + + +@pytest.mark.interpreter +def test_load_scope_sem_coop_grid_cta_one(device): + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + # Should do nothing different for num_ctas=1 (with coop launch grid) + kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True) + kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False) + + +@pytest.mark.interpreter +def test_atomic_min_max_neg_zero(device): + + @triton.jit + def kernel(inp, out_max, out_min): + idx = tl.program_id(0) + x = tl.load(inp + idx) + tl.atomic_max(out_max + idx, x) + tl.atomic_min(out_min + idx, x) + + N_PROG = 1 + dtype = torch.float32 + out_min = torch.full([N_PROG], torch.finfo(torch.float32).max, device=device, dtype=dtype) + out_max = torch.full([N_PROG], torch.finfo(torch.float32).min, device=device, dtype=dtype) + inp = torch.full([N_PROG], -0.0, device=device, dtype=dtype) + kernel[(N_PROG, )](inp, out_max, out_min) + torch.testing.assert_close(out_min, inp, atol=0, rtol=0) + torch.testing.assert_close(out_max, inp, atol=0, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "int8", "int16", "uint8", "uint16"]) +def test_atomic_unsupported_type(dtype_str, device): + + @triton.jit + def kernel(I, O): + x = tl.load(I) + tl.atomic_add(O, x) + + I = torch.zeros((1, ), device='cpu', dtype=getattr(torch, dtype_str)).to(device) + O = torch.zeros((1, ), device='cpu', dtype=getattr(torch, dtype_str)).to(device) + with pytest.raises(triton.TritonError): + kernel[(1, )](I, O) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "float16"]) +@pytest.mark.parametrize("size", [1, 4, 16]) +@pytest.mark.parametrize("op", ["add", "cas"]) +def test_tensor_atomic_use_result(dtype_str, size, op, device): + if is_hip(): + pytest.skip( + "HIP is broken because (1) it doesn't support thread predicate in atomic cas, and (2) it doesn't support" + " atomic rmw with float16") + + @triton.jit + def kernel(index_ptr, out_ptr, size: tl.constexpr, op: tl.constexpr): + if op == "add": + write_index = tl.atomic_add(index_ptr + tl.arange(0, size)[:, None], val=tl.arange(0, size)[:, None], + sem="relaxed") + elif op == "cas": + write_index = tl.atomic_cas( + index_ptr + tl.arange(0, size)[:, None], + cmp=tl.zeros((size, ), dtype=index_ptr.dtype.element_ty)[:, None], + val=tl.arange(0, size).to(index_ptr.dtype.element_ty)[:, None], + sem="relaxed", + ) + tl.store(out_ptr + write_index.to(tl.uint32) * size + tl.arange(0, size)[None, :], 5) + + index = torch.arange(0, size, device=device).to(dtype=getattr(torch, dtype_str)) + out = torch.zeros((size, size), device=device, dtype=getattr(torch, dtype_str)) + kernel[(1, )](index, out, size, op) + assert (out == 5).all() + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'bool', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip(): + if not is_hip_cdna3() and not is_hip_cdna4() and not is_hip_gfx1250() and (dtype_x == 'float8_e4m3fn' + or dtype_z == 'float8_e4m3fn'): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} only supported on HIP CDNA3/CDNA4 and above.') + if (not (is_hip_cdna4() or is_hip_gfx1250())) and ((dtype_x == 'bfloat16' and dtype_z == "float8_e4m3fn") or + (dtype_x == "float8_e4m3fn" and dtype_z == 'bfloat16')): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} only supported on HIP CDNA4 and above.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 4 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 4 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 4 == 2: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, TO_TYPE, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + + dtype_z_tri = str_to_triton_dtype(dtype_z) + kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, + num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +@pytest.mark.parametrize("can_reorder", [True, False]) +def test_cat(dtype_str, num_warps, can_reorder, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr, CAN_REORDER: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=CAN_REORDER) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], dim=0) + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](x, y, z, N=128, num_warps=num_warps, CAN_REORDER=can_reorder) + if device == "musa": + z_ref = z_ref.to('cpu') + z = z.to('cpu') + assert z.sum() == z_ref.sum() + if not can_reorder: + torch.testing.assert_close(z, z_ref, atol=0, rtol=0) + # check if there's no duplicate value in z + z = z.to('cpu') + assert z.unique().size(0) == z.size(0) + + +CAT_ND_SHAPES = ((128, ), (16, 32), (8, 16, 4), (2, 4, 8, 16)) +CAT_ND_CASES = [] +for shape in CAT_ND_SHAPES: + for dim in range(len(shape)): + CAT_ND_CASES.append(pytest.param(shape, dim, id=f"rank={len(shape)},dim={dim}")) + + +@pytest.mark.parametrize("shape, dim", CAT_ND_CASES) +def test_cat_nd(shape, dim, device): + + @triton.jit + def kernel(x_desc, y_desc, z_desc, dim: tl.constexpr, shape: tl.constexpr): + rank: tl.constexpr = len(shape) + x = x_desc.load([0] * rank) + y = y_desc.load([0] * rank) + z = tl.cat(x, y, dim=dim) + z_desc.store([0] * rank, z) + + x = torch.rand(shape, device=device) + y = torch.rand(shape, device=device) + z_ref = torch.cat([x, y], dim=dim) + z = torch.empty_like(z_ref) + x_desc = TensorDescriptor.from_tensor(x, block_shape=shape) + y_desc = TensorDescriptor.from_tensor(y, block_shape=shape) + z_desc = TensorDescriptor.from_tensor(z, block_shape=z_ref.shape) + kernel[(1, )](x_desc, y_desc, z_desc, dim=dim, shape=shape) + torch.testing.assert_close(z, z_ref, atol=0, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("constant_field", ["value", "mask"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(num_ctas, dtype_str, constant_field, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if CONSTANT_FIELD == "value": + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + mask = offsets < n_elements + elif CONSTANT_FIELD == "mask": + output = offsets < n_elements + mask = False + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) + + if constant_field == "value": + assert torch.all(output == ref) + else: + assert torch.all(output == 0) + + +def test_load_store_same_ptr(device): + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.interpreter +def test_join(device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_with_mma(device): + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_interleave_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.interpreter +def test_split(device): + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.interpreter +def test_split_to_scalar(device): + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + x = torch.arange(0, N, device=device).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + if dtype == tl.float8e4nv and fp.device.type == "musa": + fp_bits = fp.to(torch.int16) & 0xFF + output[fp_bits == 0x7F] = torch.nan + output[fp_bits == 0xFF] = torch.nan + else: + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +@pytest.mark.interpreter +def test_max_min_with_nan(device): + # In triton, we implement a "nan ignore" style, which means if there is NaN + # in the reduce dimesion, we should ignore it and return the max/min number, + # it's different with torch.max/min. + @triton.jit + def max_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offsets) + + max_val = tl.max(x, axis=0) + + if tl.program_id(0) == 0: + tl.store(y_ptr, max_val) + + @triton.jit + def min_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offsets) + + min_val = tl.min(x, axis=0) + + if tl.program_id(0) == 0: + tl.store(y_ptr, min_val) + + BLOCK_SIZE = 64 + x = torch.rand((1, BLOCK_SIZE), dtype=torch.float32, device=device) + # Not the expected output for tl.max + x[0, 0] = float('nan') + # Expected output for tl.min + x[0, 1] = float('-inf') + # Expected output for tl.max + x[0, 2] = float('inf') + + y = torch.ones(1, device=device) + + max_kernel[(1, )](x, y, BLOCK_SIZE=BLOCK_SIZE) + assert y[0] == float('inf') + + min_kernel[(1, )](x, y, BLOCK_SIZE=BLOCK_SIZE) + assert y[0] == float('-inf') + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +def get_reduce_input(dtype_str, shape): + # limit the range of integers so that reduce ops do not overflow + low = 0 if dtype_str in uint_dtypes else -10 if dtype_str in integral_dtypes else None + high = 10 if dtype_str in integral_dtypes else None + return numpy_random(shape, dtype_str=dtype_str, low=low, high=high) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + x = get_reduce_input(dtype_str, (shape, )) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = x[numpy_op(x)] + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str + z_tri_dtype_str = z_dtype_str + if 'tie-break-left' not in op and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if 'tie-break-left' in op: + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] +reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs + reduce_bool) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + if USE_I1: + x = tl.cast(x, tl.int1) + z = GENERATE_TEST_HERE + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + x = get_reduce_input(dtype_str, shape) + x_tri = to_triton(x, device=device) + numpy_op = { + 'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum': + np.bitwise_xor.reduce + }[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + if z_dtype_str == 'bool': + z_dtype_str = 'int8' + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + USE_I1 = dtype_str == 'bool' + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + + +def test_sum_dtype(device): + + @triton.jit + def kernel_dtype(out_ptr, init, in_dtype: tl.constexpr, out_dtype: tl.constexpr): + x = tl.full((32, 32), init, dtype=in_dtype) + x = tl.sum(x, dtype=out_dtype) + tl.store(out_ptr, x.to(tl.int32)) + + @triton.jit + def kernel_default_int(out_ptr): + x = tl.full((32, 32), 1, dtype=tl.int1) + x = tl.sum(x) + tl.store(out_ptr, x) + + @triton.jit + def kernel_default_float(out_ptr): + x = tl.full((32, 32), 1.0, dtype=tl.bfloat16) + x = tl.sum(x) + tl.store(out_ptr, x) + + out = torch.empty(1, dtype=torch.int32, device=device) + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=None) + assert out[0] == 32 * 32 + + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=tl.int1) + assert out[0] == 0 + + kernel_dtype[(1, )](out, init=7, in_dtype=tl.int8, out_dtype=tl.int8) + assert out[0] == (7 * 32 * 32) % 256 + + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int32, out_dtype=None) + assert out[0] == 32 * 32 + + kernel_default_int[(1, )](out) + assert out[0] == 32 * 32 + + out = torch.empty(1, dtype=torch.bfloat16, device=device) + kernel_default_float[(1, )](out) + torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device=device)) + + +# trivial associative but not commutative function +@triton.jit +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'cummax': + pytest.skip("bfloat16 compare not supported before sm90") + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + z = np.empty_like(x, dtype=np.int64) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + bias = tl.full([M, N], 1, dtype=tl.int32) + # check that histogram produces object compatible with broadcasting + biased = z + bias + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device='cpu', dtype=torch.int32).to(device) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.interpreter +def test_histogram_silent_data_corruption(device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr): + offset = tl.arange(0, 1) + x = tl.load(x_ptr + offset) + z = tl.histogram(x, 1) + tl.store(z_ptr + offset, z) + + x = torch.ones(1, device=device, dtype=torch.int32) + z = torch.ones(2, device=device, dtype=torch.int32) + + histogram_kernel[(1, )](x, z) + assert z[1] == 1, f"Second element shouldn't be affected, expected_buffer=[1, 1], actual_buffer={z}" + + +# ------------------------ +# test histogram with mask +# ------------------------ + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram_mask(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, 2 * M) + offset2 = tl.arange(0, N) + mask = offset1 < M + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N, mask) + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x1 = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + x = torch.cat((x1, x1), 0) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x1.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)]) +def test_scan_1d(M, N, device): + + @triton.jit + def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr): + input = tl.load(in_ptr + tl.arange(0, M)) + output = tl.cumsum(input).reshape([1, M]).broadcast_to([N, M]) + tl.store(out_ptr + tl.arange(0, M * N), output.reshape([M * N])) + + x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device=device) + output = torch.empty(M * N, dtype=torch.int32, device=device) + + scan_kernel[(1, )](output, x, M, N) + + ref = torch.cumsum(x, dim=0).reshape([1, M]).broadcast_to([N, M]).reshape([M * N]) + torch.testing.assert_close(ref.to(torch.int32), output, atol=0, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_n = tl.num_programs(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * num_pid_n + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N) + if not is_interpreter(): + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +def test_no_rematerialization_op(device): + + if torch.version.hip: + pytest.skip("test not supported on AMD") + + @triton.jit + def kernel( + input_data, + sum_output, + out_1, + BLOCK_SIZE: tl.constexpr, + DATA_DIM: tl.constexpr, + DATA_LEN: tl.constexpr, + loop_stages: tl.constexpr, + ): + tl.static_assert(DATA_LEN % BLOCK_SIZE == 0) + for curr_block_idx in tl.range(0, DATA_LEN // BLOCK_SIZE, num_stages=loop_stages): + my_idxs = BLOCK_SIZE * curr_block_idx + tl.arange(0, BLOCK_SIZE) + values = tl.load(input_data + DATA_DIM * my_idxs[:, None] + tl.arange(0, DATA_DIM)[None, :]) + accum = tl.sum(values, axis=-1).to(tl.float32) + tl.store(sum_output + my_idxs, accum) + sum_plus_0 = tl.full((1, 2), 0, tl.float32) + accum[:, None] + tl.store(out_1 + my_idxs[:, None] * 2 + tl.arange(0, 2)[None, :], sum_plus_0) + + data_len = 32 + data_dim = 64 + torch.manual_seed(0) + input_data = torch.randn((data_len, data_dim), dtype=torch.float32, device='cpu').to(device) + sum_output = torch.full((data_len, ), -1, dtype=torch.float32, device='cpu').to(device) + out_1 = torch.full((data_len, 2), -1, dtype=torch.float32, device='cpu').to(device) + compiled_kernel = kernel.warmup( + input_data=input_data, + sum_output=sum_output, + out_1=out_1, + DATA_DIM=data_dim, + DATA_LEN=data_len, + BLOCK_SIZE=16, + num_warps=1, + loop_stages=2, + grid=(1, ), + ) + assert compiled_kernel.asm["ttgir"].count('"tt.reduce"') == 1, "we shouldn't rematerialize tt.reduce" + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +@triton.jit +def _sum_combine(a, b): + return a + b + + +@pytest.mark.interpreter +def test_generic_reduction(device): + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, out_sum0, out_sum1, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + # Test return a tuple and a single value + sum0, = tl.reduce((x, ), 0, _sum_combine) + sum1 = tl.reduce(x, 0, _sum_combine) + # Test multiple values in a tuple + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + tl.store(out_sum0, sum0) + tl.store(out_sum1, sum1) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + sum0 = torch.empty((), device=device) + sum1 = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, sum0, sum1, BLOCK=SIZE) + + expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) + sum_ref = torch.sum(x) + torch.testing.assert_close(out_mean, expect_mean) + torch.testing.assert_close(out_var, expect_var) + torch.testing.assert_close(sum0, sum_ref) + torch.testing.assert_close(sum1, sum_ref) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if dtype_str == "float8e4b15" and (is_musa() or is_hip() or + (is_cuda() and torch.cuda.get_device_capability() >= (9, 0))): + pytest.skip("float8e4b15 not supported on ROCm or CUDA >= 9.0") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + if dtype_str == 'float8e4b15': + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + # numpy result + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda(): + return + + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device='cpu').to(device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device='cpu').to(device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + elif dtype_str == 'float8e4b8': + return torch.tensor(x, device=device).view(torch.float8_e4m3fnuz).to(torch.float32) + elif dtype_str == 'float8e5b16': + return torch.tensor(x, device=device).view(torch.float8_e5m2fnuz).to(torch.float32) + raise AssertionError("Unsupported float8 dtype") + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_base_cases(): + return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16x3', 'bf16x6'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', + 'float32'), ('float32', + 'float32'), ('float64', 'float64')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_softmax(): + return [(128, 128, 64, 8, False, False, 'softmax', 'ieee', 'float16', 'float32', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_mixed_sizes_cases(): + available_kpack = [1, 2 if (is_hip() and not is_hip_cdna4()) else 1] + available_precision = ["tf32" if is_cuda() else "ieee"] + return [ + (*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack, None) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in available_precision + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', + 'float32'), ('float32', 'float32')] + for kpack in available_kpack + ] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #2370 +def get_test_dot_transposed_op_base_cases(): + return [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1, None) + for col_a in [True, False] + for col_b in [True, False]] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# Introduced in #2750 +def get_test_dot_h100_shortcut_cases(): + return [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #3908 +def get_test_dot_mfma_edge_cases(): + if not (is_hip_cdna() or is_hip_gfx1250()): + return [] + return [(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1, None), + (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #3370 +def get_test_dot_fp8_output_cases(): + return [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1, None) + for float8_type in ["float8e5", "float8e4nv"]] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #5406 +def get_test_dot_small_k_mfma_cases(): + if not (is_hip_cdna() or is_hip_gfx1250()): + return [] + return [(32, 32, k_size, 4, False, False, 'None', 'ieee', in_dtype, out_dtype, 1, mma_nonk_size) + for k_size in [1, 2, 4, 8] + for in_dtype, out_dtype in [('float16', 'float32'), ('int8', 'int32')] + for mma_nonk_size in mma_nonk_sizes] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #4516 +def get_test_dot_small_mn_mfma_cases(): + if not (is_hip_cdna() or is_hip_gfx1250()): + return [] + return [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1, None) + for shape_nw in [(4, 64, 64, 1), (64, 4, 64, 1)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]] + + +def get_test_dot_double_rate_cases(): + if not (is_hip_cdna() or is_hip_gfx1250()): + return [] + return [(32, 32, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (32, 32, 16, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None), + (16, 16, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)] + + +def get_test_dot_vdot2_cases(): + if not (is_hip_cdna() or is_hip_gfx1250()): + return [] + return [(4, 32, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (4, 32, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)] + + +def get_test_small_dots_cases(): + if not is_cuda(): + return [] + return [(2, 4, 32, 1, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size", + get_test_dot_vdot2_cases() + \ + get_test_dot_double_rate_cases() + \ + get_test_dot_base_cases() + \ + get_test_dot_mixed_sizes_cases() + \ + get_test_dot_transposed_op_base_cases() + \ + get_test_dot_h100_shortcut_cases() + \ + get_test_dot_mfma_edge_cases() + \ + get_test_dot_fp8_output_cases() + \ + get_test_dot_small_k_mfma_cases() + \ + get_test_dot_small_mn_mfma_cases() + \ + get_test_dot_softmax() + \ + get_test_small_dots_cases()) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size, + num_ctas, device): + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + if input_precision == "bf16x3" or input_precision == "bf16x6": + pytest.skip(f"input_precision {input_precision} is not supported in the interpreter") + else: + if not is_hip() and K < 16: + pytest.skip("small dots are supported only on HIP at the moment") + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + if in_dtype == 'float64' and input_precision != 'ieee': + pytest.skip("Only IEEE precision is supported for float64 dot") + + if is_hip(): + if in_dtype in ("float8e5", "float8e4nv") and not (is_hip_gfx1250() or is_hip_cdna4() or is_hip_rdna4()): + pytest.skip(f"{in_dtype} only supported on CDNA4, RDNA4 and above") + if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3(): + pytest.skip(f"{in_dtype} only supported on CDNA3") + if not ((input_precision in ("bf16x3", "bf16x6")) or (input_precision == "ieee") or + (input_precision == "tf32" and is_hip_cdna3())): + pytest.skip(f"{input_precision} not supported on HIP") + if kpack == 2 and in_dtype == 'int8' and K < 64: + pytest.skip("kpack too large for K") + if in_dtype == 'float64': + pytest.skip("float64 not supported on HIP yet") + + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + z_max = tl.max(z, 1) + z = z - z_max[:, None] + num = tl.exp(z.to(tl.float32)).to(z_max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision in ["tf32", "bf16x3", "bf16x6"]: + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + if mma_nonk_size is not None: + kern_kwargs['matrix_instr_nonkdim'] = mma_nonk_size + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + # Reduce z_ref's precision to fp8 to match the kernel behavior + if in_dtype == 'float8e4nv': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn) + elif in_dtype == 'float8e5': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2) + elif in_dtype == 'float8e4b8': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fnuz) + elif in_dtype == 'float8e5b16': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2fnuz) + else: + raise AssertionError("Unsupported float8 dtype") + z_ref = to_numpy(z_fp8.to(torch.float32)) + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + if not (is_cuda() or is_hip_cdna() or is_hip_gfx1250()): + return + + if is_hip_cdna() or is_hip_gfx1250(): + amdgcn = pgm.asm['amdgcn'] + + if is_hip_cdna() and ((M, N) == (4, 64) or (M, N) == (64, 4)): + assert 'v_mfma_f32_4x4' in amdgcn + elif is_hip_cdna() and (M, N) == (4, 32): + if in_dtype == 'float16': + assert 'v_dot2c_f32_f16' in amdgcn + elif (in_dtype == 'bfloat16') and (is_hip_cdna4() or is_hip_gfx1250()): + assert 'v_dot2c_f32_bf16' in amdgcn + return + + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + if 'float64' in in_dtype: + assert 'ld.global.v2.b64' in ptx + else: + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + elif 'float64' in in_dtype: + assert 'st.global.v2.b64' in ptx + else: + assert 'st.global.v4' in ptx + + is_tcgen5 = (capability[0] == 10) and (num_warps % 4) == 0 and (M % 64) == 0 and (N % 8) == 0 + + if in_dtype == 'float32' and input_precision != "ieee": + if is_tcgen5: + if input_precision in ("bf16x3", "bf16x6"): + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + else: + assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) + elif input_precision in ("bf16x3", "bf16x6"): + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + elif capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + elif capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == "float8e5" and out_dtype == tl.float32: + if capability[0] == 9 and M >= 64 and N >= 8: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif capability[0] >= 8 and M < 64: + assert 'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9 and M >= 64 and N >= 8: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + if is_tcgen5 and epilogue == 'softmax' and M >= 128: + # check that there is no shared memory exchange in the softmax + pattern = (r'tcgen05\.ld\.sync\.aligned\.16x32bx2\.x64\.b32' + r'(?:(?!st\.shared).)*' + r'cvt\.rn\.f16x2\.f32') + assert re.search(pattern, ptx, flags=re.DOTALL) + + +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda(): + return + assert "tt.dot" in h.asm['ttir'] + assert re.search(r"ttg.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device='cpu').to(device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static.to('cpu') == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device='cpu').to(device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic.to('cpu') == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched.warmup(out, grid=(1, )) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel.warmup(input, output, output, choose_const, SIZE, SIZE, grid=(1, )) + if constexpr: + error = "Cannot store to a constant pointer" + else: + if mode == "call": + error = "Return type mismatch: " + elif mode == "if": + error = "Mismatched type for final_out" + elif mode == "ternary": + error = "Ternary expression with dynamic condition has inconsistent type" + else: + assert mode == "direct" and choose_const + error = "Cannot store to a constant pointer" + error_msg = exc_info.value.error_message or str(exc_info.value.__cause__) + assert error in error_msg, "Wrong error message!" + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device='cpu').to(device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. + x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with a copy to shared memory. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. + o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) +def test_load_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cg_cache_modifier_str = 'nt' + cv_cache_modifier_str = 'sc0 sc1' + buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] + global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] + load_line = global_load_line[0] if global_load_line else buffer_load_line[0] + if cache == '' or cache == '.ca': + assert cg_cache_modifier_str not in load_line + if cache == '.cg': + assert cg_cache_modifier_str in load_line + if cache == '.cv': + assert cv_cache_modifier_str in load_line + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.randn(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + torch.testing.assert_close(dst[:N], src[:N], atol=1e-6, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +@pytest.mark.interpreter +def test_assume(device): + + @triton.jit + def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): + current_size = N - tl.program_id(0) * BLOCK_N + tl.assume(current_size >= BLOCK_N) + if current_size >= 128: + tl.store(out_ptr + tl.program_id(0), current_size) + else: + tl.store(out_ptr + tl.program_id(0), current_size + 101024) + + output = torch.zeros(1024 // 128, device=device) + pgm = _kernel[(1024 // 128, )](output, N=1024, BLOCK_N=128) + + if is_interpreter(): + return + + assert 'llvm.intr.assume' in pgm.asm['ttgir'] + # tritonamdgpu-fold-true-cmpi on AMD folds true cmpi ops to %true (which llvm itself then DCEs). + if not is_hip(): + assert 'llvm.assume' in pgm.asm['llir'] + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cs_cache_modifier_str = 'nt' + wt_cache_modifier_str = 'sc0 sc1' + buffer_store_line = [line for line in amdgcn.splitlines() if "buffer_store" in line] + global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line] + store_line = global_store_line[0] if global_store_line else buffer_store_line[0] + if cache == '' or cache == '.cg': + assert cs_cache_modifier_str not in store_line + assert wt_cache_modifier_str not in store_line + if cache == '.cs': + assert cs_cache_modifier_str in store_line + assert wt_cache_modifier_str not in store_line + if cache == '.wt': + assert cs_cache_modifier_str not in store_line + assert wt_cache_modifier_str in store_line + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("eviction_policy", ["", "evict_last", "evict_first"]) +def test_store_eviction_policy(eviction_policy, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, POLICY: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, eviction_policy=POLICY) + + pgm = _kernel[(1, )](dst, src, POLICY=eviction_policy) + + if not is_cuda(): + return + ptx = pgm.asm['ptx'] + if eviction_policy == '': + assert 'evict_last' not in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_last': + assert 'evict_last' in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_first': + assert 'evict_last' not in ptx + assert 'evict_first' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + if is_musa() and device == "cuda": + device = "musa" + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + + def repr(specialization): + ty = specialization.signature["value1"] + cst = '_'.join([k for k, v in specialization.constants.items() if isinstance(k, str) and v == 1]) + return f"kernel_{ty}_{cst}" + + @triton.jit(repr=repr) + def kernel(value1, is_one, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel.warmup(value, 1, x, grid=(1, )) + assert "is_one" in h.name + assert value_type in h.name + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel.warmup(grid=(1, )) + + assert "reshape" in str(exc_info.value) + + +@pytest.mark.interpreter +def test_tma_load_block_shape_err(device): + + @triton.jit + def kernel(ptr): + desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 2]) + desc.load([0, 0]) + + input = torch.empty((128, 128), dtype=torch.int32, device=device) + errc = triton.CompilationError if not is_interpreter() else InterpreterError + with pytest.raises(errc) as e: + kernel[(1, )](input) + + assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__) + + +@pytest.mark.interpreter +def test_tma_store_block_shape_err(device): + + @triton.jit + def kernel(ptr): + desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 4]) + desc.store([0, 0], tl.zeros([8, 4], dtype=tl.int16)) + + input = torch.empty((128, 128), dtype=torch.int16, device=device) + errc = triton.CompilationError if not is_interpreter() else InterpreterError + with pytest.raises(errc) as e: + kernel[(1, )](input) + + assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__) + + +def test_trans_reshape(device, with_allocator): + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticValue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticValue != 0 and StaticValue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +def test_num_warps_pow2(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel.warmup(dst=dst, grid=(1, ), num_warps=3) + _kernel.warmup(dst=dst, grid=(1, ), num_warps=1) + _kernel.warmup(dst=dst, grid=(1, ), num_warps=2) + _kernel.warmup(dst=dst, grid=(1, ), num_warps=4) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +def test_inline_asm_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_inline_asm_packed_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test map elementwise +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_map_elementwise(num_ctas, device): + + @triton.jit + def compare(x, y): + if x < y: + return -1 + elif x == y: + return 0 + else: + return 1 + + @triton.jit + def kernel(X, Y, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + z = tl.map_elementwise(compare, x, y) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='int32', rs=rs) + y = numpy_random(shape, dtype_str='int32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK=shape[0], num_ctas=num_ctas) + z_ref = (x > y).astype(int) - (y > x).astype(int) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +def test_map_elementwise_multiple_outputs(device): + + @triton.jit + def divmod(a, b): + return a // b, a % b + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + c, d = tl.map_elementwise(divmod, a, b) + + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A // B + D_ref = A % B + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_map_elementwise_pack(device): + + @triton.jit + def divmod(a0, a1, b0, b1): + return a0 // b0, a1 // b1, a0 % b0, a1 % b1 + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + c, d = tl.map_elementwise(divmod, a, b, pack=2) + + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + h = kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A // B + D_ref = A % B + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.interpreter +def test_if_else(device): + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i + (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.interpreter +def test_nested_while(device): + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +def test_constexpr_if_return(device): + # Reproducer for #4883, return statement in an if with a constexpr causes + # errors when combined with non-trivial control flow graphs + + @triton.jit + def kernel(Semaphore, Out, total: tl.constexpr): + if total == 1: + tl.store(Out, tl.program_id(0)) + return + + prev = tl.atomic_add(Semaphore, 1) + if prev + 1 != total: + return + + tl.store(Out, tl.program_id(0) + prev) + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.empty((), device=device, dtype=torch.int32) + kernel[(1, )](sem, out, 1) + assert out.item() == 0 + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.full((), fill_value=-1, device=device, dtype=torch.int32) + kernel[(4, )](sem, out, 4) + assert out.item() >= 0 + + +def test_constexpr_flattens(): + assert tl.constexpr(tl.constexpr(5)) == tl.constexpr(5) + assert tl.constexpr(tl.constexpr(tl.constexpr(5))) == tl.constexpr(5) + + +@pytest.mark.parametrize("literal, tensor_ty", [(10, tl.int32), (32.1, tl.float32), + ((5, 6, 7), None), # tuples can't be lifted to tensors + ]) +def test_constexpr_assignment(literal, tensor_ty): + from triton.language.core import constexpr_type + + @triton.jit + def kernel(input_literal: tl.constexpr, tensor_type: tl.constexpr): + patched_literal: tl.constexpr = PATCHED + # Sanity checks + tl.static_assert(patched_literal.type == constexpr_type(PATCHED)) + tl.static_assert(input_literal.type == constexpr_type(PATCHED)) + + assigned_literal: tl.constexpr = input_literal + tl.static_assert(assigned_literal.type == constexpr_type(PATCHED)) + tl.static_assert(assigned_literal == patched_literal) + + if tensor_type is not None: + assigned_variable = input_literal + tl.static_assert(assigned_variable.type == tensor_type) + + kernel_patched = patch_kernel(kernel, {'PATCHED': f"{literal}"}) + kernel_patched[(1, )](literal, tensor_ty) + + +def test_constexpr_arg_str_attr(): + + @triton.jit + def cst_str_attr(c_s_arg: tl.constexpr): + pass + + cst_str_attr.warmup('SD', grid=(1, )) + + +@triton.jit +def return_poison(x): + a = False + if a: + return x + + +def test_poison_return(device): + + @triton.jit + def kernel(Out): + zero = 0 + tl.store(Out, return_poison(zero)) + + a = torch.empty((), device=device, dtype=torch.int32) + h = kernel.warmup(a, grid=(1, )) + assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] + # hip/xpu uses llvm.store, which in this case is removed by the optimizer + if not (is_hip() or is_xpu() or is_musa()): + assert "poison" in h.asm["llir"], h.asm["llir"] + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + if is_musa(): + pytest.skip("test_num_threads is not supported in MUSA") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +def test_globaltimer(device): + check_cuda_or_hip(device) + if is_hip(): + pytest.skip("test_globaltimer is flaky on AMD GPUs") + + @triton.jit + def kernel(Out1, Out2, func: tl.constexpr): + start = func() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = func() + tl.store(Out2, start) + tl.store(Out2 + 1, end) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((2, ), dtype=np.int64), device=device) + if is_cuda(): + func = tl.extra.cuda.globaltimer + else: + func = tl.extra.hip.memrealtime + h = kernel[(1, )](out1, out2, func) + assert out2[1] - out2[0] > 0 + if is_cuda(): + assert h.asm["ptx"].count("%globaltimer") == 2 + else: + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in target_arch or "gfx12" in target_arch: + assert h.asm["amdgcn"].count("s_sendmsg_rtn_b64") == 2 + else: + assert h.asm["amdgcn"].count("s_memrealtime") == 2 + + +def test_smid(device): + if is_hip(): + pytest.skip("test_smid is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + if is_musa(): + pytest.skip("test_ptx_cast no longer needs to be tested in MUSA") + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +_DOT_MAX_NUM_IMPRECISE_ACC_IN_TYPES = ["float8e5", "float8e4nv"] +if is_hip(): + _DOT_MAX_NUM_IMPRECISE_ACC_IN_TYPES += ["float8e5b16", "float8e4b8"] +elif is_cuda(): + _DOT_MAX_NUM_IMPRECISE_ACC_IN_TYPES += ["float8e4b15"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) +@pytest.mark.parametrize( + "in_type_str", + _DOT_MAX_NUM_IMPRECISE_ACC_IN_TYPES, +) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): + num_stages = 3 + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + elif is_hip(): + num_stages = 2 + if in_type_str in ("float8e5b16", "float8e4b8") and not is_hip_cdna3(): + pytest.skip(f"{in_type_str} only supported on CDNA3") + if in_type_str in ("float8e5", "float8e4nv") and not (is_hip_cdna4() or is_hip_rdna4() or is_hip_gfx1250()): + pytest.skip(f"{in_type_str} only supported on CDNA4, RDNA4 and above") + + check_type_supported(in_type_str, device) + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 8 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None + h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), + C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, + num_stages=num_stages) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + else: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if is_hopper() and low_precision_acc > 0: + # Hopper-specific workaround lower precision accumulator. + assert h.asm["ptx"].count("add.f32") == (BLOCK_M * BLOCK_N) // (32 * num_warps) * (BLOCK_K // low_precision_acc) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +@pytest.mark.parametrize("default_override", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, default_override, device, fresh_knobs): + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + if default_override: + fresh_knobs.language.default_fp_fusion = enable_fp_fusion + h = mul_add.warmup(data, grid=(1, )) + else: + h = mul_add.warmup(data, grid=(1, ), enable_fp_fusion=enable_fp_fusion) + + if not is_cuda(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test enable_reflect_ftz +# ----------------------- + + +@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") +@pytest.mark.parametrize("enable_reflect_ftz", [False, True]) +def test_enable_reflect_ftz(enable_reflect_ftz, device, fresh_knobs): + + @triton.jit + def exp2(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.math.exp2(tl.load(ptrs))) + + data = torch.full((128, ), -127.0, device=device, dtype=torch.float32) + h = exp2.warmup(data, grid=(1, ), enable_reflect_ftz=enable_reflect_ftz) + + found_ex2_ftz = re.search(r'ex2.approx.ftz.f32', h.asm["ptx"]) is not None + assert found_ex2_ftz == enable_reflect_ftz + + +# ----------------------- +# test override_arch +# ----------------------- + + +@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90", "gfx942", "gfx950", "gfx1200"]) +@pytest.mark.parametrize("env_var_override", [False, True]) +def test_override_arch(arch, env_var_override, device, fresh_knobs): + if arch.startswith("sm") and not is_cuda(): + pytest.skip(f"{arch} arch only for CUDA") + elif arch.startswith("gfx") and not is_hip(): + pytest.skip(f"{arch} arch only for HIP") + + @triton.jit + def simple(data, out): + in_ptrs = data + tl.arange(0, 128) + out_ptrs = out + tl.arange(0, 128) + tl.store(out_ptrs, tl.load(in_ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + out = torch.empty_like(data) + + if is_cuda(): + if env_var_override: + fresh_knobs.runtime.override_arch = str(arch) + h = simple.warmup(data, out, grid=(1, )) + else: + h = simple.warmup(data, out, arch=arch, grid=(1, )) + ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"]) + assert ttgir_cc.group(1) == arch[2:] + elif is_hip(): + # For HIP, the generated kernel is a binary containing the final ISA. So we cannot run + # them like CUDA side if the chip doesn't match. Here we just check generated ISA. + if env_var_override: + fresh_knobs.runtime.override_arch = str(arch) + h = simple.warmup(data, out, grid=(1, )) + else: + h = simple.warmup(data, out, arch=arch, grid=(1, )) + ttgir_gfx = re.search(r'hip:(\w+)', h.asm["ttgir"]) + ttgir_warp = re.search(r'"ttg.threads-per-warp" = (\d+)', h.asm["ttgir"]) + amdgcn_gfx = re.search(r'.amdgcn_target "amdgcn-amd-amdhsa--(\w+)"', h.asm["amdgcn"]) + assert ttgir_gfx.group(1) == arch + assert int(ttgir_warp.group(1)) == (32 if arch == "gfx1200" else 64) + assert amdgcn_gfx.group(1) == arch + + +def test_num_ctas_pre_sm90(device, fresh_knobs): + if not is_cuda() and not is_hip() and not is_musa(): + pytest.skip("Only supported on CUDA, HIP, and MUSA") + + @triton.jit + def _kernel(src): + pass + + src = torch.empty(1, device=device) + if is_cuda(): + arch = "sm80" + msg = r"num_ctas > 1 requires NVIDIA SM90\+ \(Hopper\)" + elif is_hip(): + arch = "gfx942" + msg = r"num_ctas > 1 not supported" + else: + arch = "ph1" + msg = r"num_ctas > 1 requires MUSA cluster launch support" + + fresh_knobs.runtime.override_arch = str(arch) + with pytest.raises(ValueError, match=msg): + _kernel.warmup(src, grid=(1, ), num_ctas=2) + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + _min = tl.load(min_ptr + off, mask=mask) + _max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, _min, _max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, _min), _max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + _min = torch.min(a, b) + _max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, _min, _max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['bfloat16', 'float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.interpreter +def test_tl_range_num_stages(device): + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group \t6' in ptx + + +def test_tl_range_fuse(device): + + @triton.jit + def kernel(ub, out_ptr): + k = 1 + for i in tl.range(0, ub, flatten=True): + for j in tl.range(0, ub): + tl.store(out_ptr + i * 32 + j, k) + k += 1 + + ub = 10 + out = torch.zeros((32, 32), dtype=torch.int32, device=device) + compiled_kernel = kernel[(1, )](ub, out) + assert "tt.flatten" in compiled_kernel.asm["ttir"] + assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 + + ref = torch.zeros((32, 32), dtype=torch.int32, device=device) + k = 1 + for i in range(ub): + for j in range(ub): + ref[i, j] = k + k += 1 + torch.testing.assert_close(out, ref, atol=0, rtol=0) + + +def test_tl_range_fuse_dependent(device): + + @triton.jit + def kernel(ub, out_i_ptr, out_j_ptr): + k = 0 + for i in tl.range(0, ub, flatten=True): + lower_bound = i * 2 + upper_bound = lower_bound + i + 1 + tl.assume(upper_bound > lower_bound) + for j in tl.range(lower_bound, upper_bound): + tl.store(out_i_ptr + k, i) + tl.store(out_j_ptr + k, j) + k += 1 + + ub = 10 + out_i = torch.zeros(1024, dtype=torch.int32, device=device) + out_j = torch.zeros(1024, dtype=torch.int32, device=device) + compiled_kernel = kernel[(1, )](ub, out_i, out_j) + assert "tt.flatten" in compiled_kernel.asm["ttir"] + ttgir = compiled_kernel.asm["ttgir"] + ttgir = ttgir[ttgir.find("scf.for"):] + assert ttgir[:ttgir.find("}")].count("scf.for") == 1 + ttgir = ttgir[ttgir.find("}"):] + assert ttgir.count("scf.for") == 1 + + ref_i = torch.zeros(1024, dtype=torch.int32, device=device) + ref_j = torch.zeros(1024, dtype=torch.int32, device=device) + k = 0 + for i in range(ub): + lower_bound = i * 2 + upper_bound = lower_bound + i + 1 + assert upper_bound > lower_bound + for j in range(lower_bound, upper_bound): + ref_i[k] = i + ref_j[k] = j + k += 1 + torch.testing.assert_close(out_i, ref_i, atol=0, rtol=0) + torch.testing.assert_close(out_j, ref_j, atol=0, rtol=0) + + +def test_tl_range_option_none(): + + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, num_stages=None, loop_unroll_factor=None): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1, )) + assert "num_stages" not in compiled_kernel.asm["ttir"] + assert "loop_unroll_factor" not in compiled_kernel.asm["ttir"] + + +def test_disable_licm(): + + @triton.jit + def while_no_licm(n): + i = 0 + while tl.condition(i < n, disable_licm=True): + i = i + 1 + print("i", i) + + @triton.jit + def while_default(n): + i = 0 + while tl.condition(i < n): + i = i + 1 + print("i", i) + + @triton.jit + def for_no_licm(n): + for i in tl.range(0, n, disable_licm=True): + print("i", i) + + compiled_kernel1 = while_no_licm.warmup(10, grid=(1, )) + assert "llvm.licm.disable" in compiled_kernel1.asm["llir"] + + compiled_kernel2 = while_default.warmup(10, grid=(1, )) + assert "llvm.licm.disable" not in compiled_kernel2.asm["llir"] + + compiled_kernel3 = for_no_licm.warmup(10, grid=(1, )) + assert "llvm.licm.disable" in compiled_kernel3.asm["llir"] + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.interpreter +def test_maxnreg(device): + if not is_cuda(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + if not is_interpreter(): + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # reuse the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() + + +@pytest.mark.interpreter +def test_num_programs(device): + # Assuming that the kernel is launched with a grid of (11, 21, 31) + grid = (11, 21, 31) + input = torch.empty((3, ), dtype=torch.int32, device=device) + + @triton.jit + def kernel(input): + num_programs_0 = tl.num_programs(0) + num_programs_1 = tl.num_programs(1) + num_programs_2 = tl.num_programs(2) + tl.store(input, num_programs_0) + tl.store(input + 1, num_programs_1) + tl.store(input + 2, num_programs_2) + + kernel[grid](input) + assert torch.all(input == torch.tensor(grid, device=device)) + + +# ----------------------- +# test loop unrolling +# ----------------------- + + +def test_unroll_attr(device): + + @triton.jit + def _kernel(dst, unroll_factor: tl.constexpr): + pid = tl.program_id(axis=0) + for i in tl.range(0, 10, loop_unroll_factor=unroll_factor): + tl.atomic_add(dst + pid, i + pid) + + def check_loop_unroll_count(ir, opStr, loop_unroll_factor): + for line in ir.splitlines(): + if opStr in line: + loop_unroll_factor = loop_unroll_factor - 1 + # Sometimes we get a remainder loop + assert loop_unroll_factor <= 0 + + # Try for all different loop unroll factors (compile-only): + tmp = torch.empty(1, device=device) + for unroll_factor in [1, 2, 4, 5, 8]: + h = _kernel.warmup(tmp, unroll_factor, grid=(1, )) + check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor) + + +@triton.jit +def sanitize_add(a, b): + a64 = a.to(tl.int64) + b64 = b.to(tl.int64) + r64 = a64 + b64 + tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1)) + return a + b + + +def test_side_effectful_reduction(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.reduce(vals, 0, sanitize_add) + tl.store(Z, z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros((), device="cuda", dtype=torch.int32) + sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.sum().to(torch.int32)) + + +@pytest.mark.parametrize("reduce_dim", [0, 1]) +def test_side_effectful_reduction_2d(device, reduce_dim): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, + NON_REDUCE_DIM: tl.constexpr): + offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] + vals = tl.load(X + offsets) + z = tl.reduce(vals, reduce_dim, sanitize_add) + tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) + + BLOCK_0 = 16 + BLOCK_1 = 32 + NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) + Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) + sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, + NON_REDUCE_DIM=NON_REDUCE_DIM) + torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) + + +@pytest.mark.interpreter +def test_dtype(device): + + @triton.jit + def kernel(X): + dtype_x: tl.constexpr = X.dtype.element_ty + tl.static_assert(dtype_x == tl.int32) + tl.static_assert(dtype_x == tl.constexpr(tl.int32)) + tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) + + X = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](X) + + +def test_side_effectful_scan(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.associative_scan(vals, 0, sanitize_add) + tl.store(Z + tl.arange(0, BLOCK), z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros_like(X) + sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) + + +# stress test slice layout usages in reductions. +@pytest.mark.parametrize("in_shape, perm, red_dims", [ + ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]), + ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]), +]) +def test_chained_reductions(in_shape, perm, red_dims, device): + + @triton.jit + def kernel(In, Out, # + dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr, + perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr, + perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr): + idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4) + idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4) + vals = tl.load(In + idx) + vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4]) + r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2) + st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape) + tl.store(Out + st_idx, r) + + input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32) + temp = torch.permute(input, perm).contiguous() + ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2]) + result = torch.empty_like(ref) + kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4], + perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) + + assert torch.all(ref == result) + + +@triton.jit +def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + +@triton.jit +def gather_test_kernel_1d(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, idx_dim0: tl.constexpr, + out_dim0: tl.constexpr): + src_offs = tl.arange(0, src_dim0) + src = tl.load(src_ptr + src_offs) + + idx_offs = tl.arange(0, idx_dim0) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = tl.arange(0, out_dim0) + tl.store(out_ptr + out_offs, out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ([32], [64], 0), + ([4, 4], [8, 4], 0), + ([128, 64], [256, 64], 0), + ([128, 64], [128, 128], 1), +]) +def test_gather(src_shape, indices_shape, axis, device): + if (is_hip_cdna2() or is_hip_cdna3() or is_hip_rdna3() + or is_hip_rdna4()) and src_shape == [128, 64] and indices_shape == [256, 64]: + # This could be solved by reducing vectorization in general swizzling algorithm. + # We will do this if any relevant workload suffers from large LDS consumption of the algorithm. + pytest.skip('Not enough LDS.') + + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + + if len(src_shape) == 1: + gather_test_kernel_1d[(1, )](src, indices, output, axis, src.shape[0], indices.shape[0], output.shape[0]) + else: + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1)) + + return output + + src = torch.randn(src_shape, device=device) + indices = torch.randint(0, src.shape[axis], indices_shape, device=device) + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +@triton.jit +def mul_jit_function(x, y): + return x * y + + +@triton.jit +def apply_binary_op(x, combine_op): + return combine_op(x, x) + + +def test_jit_function_arg(device): + + @triton.jit + def square_kernel_jit_function(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + in_data = tl.load(in_ptr + offsets) + out_data = apply_binary_op(in_data, mul_jit_function) # pass a JITFunction into another JITFunction + tl.store(out_ptr + offsets, out_data) + + BLOCK_SIZE = 16 + x = torch.full((BLOCK_SIZE, ), 3.0, device=device) + out = torch.empty((BLOCK_SIZE, ), device=device) + expect = torch.full((BLOCK_SIZE, ), 9.0, dtype=x.dtype, device=device) + + square_kernel_jit_function[(1, )](x, out, BLOCK_SIZE) + + torch.testing.assert_close(out, expect) + + +@pytest.mark.interpreter +def test_zero_strided_tensors(device): + + @triton.jit + def _simple_add( + X, + stride_x_a, + stride_x_b, + ): + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + + # doesn't directly index c dim, so relies on 0-strided c dim to affect every element + x_ptr = X + pid_a * stride_x_a + pid_b * stride_x_b + + tl.atomic_add(x_ptr, 1) + + x = torch.zeros((2, 2, 1), device='cpu').to(device) + c_dim = 3 + x = x.expand((2, 2, c_dim)) + + a, b, c = x.shape + grid = (a, b, c) + with torch.musa.device(x.device.index): + _simple_add[grid](x, x.stride(0), x.stride(1)) + + assert torch.allclose(x, torch.ones_like(x) * c_dim) + + +@pytest.mark.interpreter +def test_aliasing(device): + + @triton.jit + def aliasing_kernel(buffer, buffer2): + triton.language.store(buffer, 1) + + buffer = torch.zeros(1, device=device) + aliasing_kernel[(1, )](buffer, buffer) + assert buffer[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_strided_load(dtype, device): + + @triton.jit + def take_every_second_element(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): + strided_offsets = tl.arange(0, BLOCK_SIZE) * 2 + linear_offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + strided_offsets) + tl.store(output_ptr + linear_offsets, x) + + STRIDE = 2 + SIZE = 512 + OUT_SIZE = SIZE // STRIDE + + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + out_tri = torch.empty(OUT_SIZE, device=device) + take_every_second_element[(1, 1)](x_tri, out_tri, OUT_SIZE) + + # Test that every second element (starting from [0]) from x is stored in out_tri + np.testing.assert_allclose(x[::2], to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_strided_store(dtype, device): + + @triton.jit + def store_into_every_second(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): + strided_offsets = tl.arange(0, BLOCK_SIZE) * 2 + linear_offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + linear_offsets) + tl.store(output_ptr + strided_offsets, x) + + STRIDE = 2 + SIZE = 512 + OUT_SIZE = SIZE * STRIDE + + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + out_tri = torch.zeros(OUT_SIZE, device=device) + store_into_every_second[(1, 1)](x_tri, out_tri, SIZE) + + # Test that every second element (starting from [0]) is the same as in x + np.testing.assert_allclose(x, to_numpy(out_tri)[::2]) + # Test that every second element (starting from [1]) is still zero + np.testing.assert_allclose(np.zeros_like(x), to_numpy(out_tri)[1::2]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_indirect_load(dtype, device): + + @triton.jit + def indirect_load(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr): + linear_offsets = tl.arange(0, SIZE) + offsets = tl.load(offset_ptr + linear_offsets) + x = tl.load(x_ptr + offsets) + tl.store(output_ptr + linear_offsets, x) + + SIZE = 512 + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + # Flip the range to load the tensor in reverse order + ptr = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0) + out_tri = torch.empty(SIZE, device=device) + indirect_load[(1, 1)](ptr, x_tri, out_tri, SIZE) + + np.testing.assert_allclose(np.flip(x), to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_indirect_store(dtype, device): + + @triton.jit + def indirect_store(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr): + linear_offsets = tl.arange(0, SIZE) + offsets = tl.load(offset_ptr + linear_offsets) + x = tl.load(x_ptr + linear_offsets) + tl.store(output_ptr + offsets, x) + + SIZE = 512 + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + # Flip the range to store the tensor in reverse order + ptr = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0) + out_tri = torch.empty(SIZE, device=device) + indirect_store[(1, 1)](ptr, x_tri, out_tri, SIZE) + + np.testing.assert_allclose(np.flip(x), to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", map(tl.dtype, tl.dtype.SINT_TYPES + tl.dtype.UINT_TYPES + tl.dtype.STANDARD_FP_TYPES)) +def test_dtype_tensor(device, dtype): + + @triton.jit + def dtype_tensor_kernel(dtype: tl.constexpr): + tensor = tl.zeros((1, ), dtype) + + dtype_tensor_kernel[(1, )](dtype) + + +@pytest.mark.interpreter +def test_short_circuiting(device): + + @triton.jit + def short_circuiting_kernel(x): + if (x is not None) and hasattr(x, "dtype") and isinstance( + x.dtype, tl.pointer_type) and (x.dtype.element_ty == tl.int32) and (tl.load(x) > 42): + tl.store(x, 42) + + def f(x): + short_circuiting_kernel[(1, )](x, num_warps=1) + + f(None) # should succeed with NoneType + f(1) # should succeed with tl.constexpr type + f(2) # should succeed with integer type + + def g(y, dtype): + x = torch.full((1, ), y, device=device, dtype=dtype) + f(x) + return x.item() + + assert g(37.5, torch.float32) == 37.5 + assert g(84.0, torch.float32) == 84.0 + assert g(-76893, torch.int32) == -76893 + assert g(100000, torch.int32) == 42 + assert g(100000, torch.int64) == 100000 + + +@pytest.mark.interpreter +@pytest.mark.filterwarnings("ignore:If conditional called with multidimensional Tensor*") +def test_unsplat(device): + + @triton.jit + def unsplat_kernel(x, explicit: tl.constexpr): + + # this is a single-element tensor: + condition = tl.load(x + tl.arange(0, 1)) > 42 + + if explicit: + condition = condition.item() + + if condition: + tl.store(x, 42) + + def g(y, explicit): + x = torch.full((1, ), y, device=device, dtype=torch.int32) + unsplat_kernel[(1, )](x, explicit, num_warps=1) + return x.item() + + assert g(41, False) == 41 + assert g(43, False) == 42 + assert g(41, True) == 41 + assert g(43, True) == 42 + + +@pytest.mark.interpreter +def test_cumsum_dtype(device): + + @triton.jit + def kernel(Z): + x = tl.full((4, ), True, dtype=tl.int1) + z = tl.cumsum(x, axis=0) + tl.store(Z + tl.arange(0, 4), z) + + z = torch.zeros(4, dtype=torch.int32, device=device) + kernel[(1, )](z) + expected = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device) + assert torch.equal(z, expected) + + +@pytest.mark.interpreter +def test_tensor_member(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 16) + tl.device_assert(tl.abs(x) == x.abs()) + tl.device_assert(tl.sum(x) == x.sum()) + + kernel[(1, )]() + + +@pytest.mark.parametrize("dtype_str", ["float32", "float64"]) +def test_libdevice_rint(dtype_str, device): + iinfo32 = np.iinfo(np.int32) + iinfo64 = np.iinfo(np.int64) + size = 1000 + x0_np = np.random.uniform(iinfo32.min, iinfo32.max + 1, size) + x1_np = np.random.uniform(iinfo64.min, iinfo64.max + 1, size) + x2_np = np.array([-2.5, -1.5, -0.5, -0., 0., 0.5, 1.5, 2.5, float("inf"), -float("inf"), float("nan")]) + x_np = np.concatenate((x0_np, x1_np, x2_np)) + x_tri = to_triton(x_np, device=device, dst_type=dtype_str) + + @triton.jit + def rint_kernel(outp, inp, n, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offset < n + inp_tile = tl.load(inp + offset, mask=mask) + outp_tile = tl.extra.libdevice.rint(inp_tile) + tl.store(outp + offset, outp_tile, mask=mask) + + res_out = torch.empty_like(x_tri) + numel = x_tri.numel() + BLOCK_SIZE = 512 + rint_kernel[(triton.cdiv(numel, BLOCK_SIZE), )](res_out, x_tri, numel, BLOCK_SIZE) + ref_out = np.rint(x_np) + np.testing.assert_allclose(to_numpy(res_out), ref_out, rtol=0, atol=0, equal_nan=True) diff --git a/third_party/mthreads/python/test/unit/language/test_decorator.py b/third_party/mthreads/python/test/unit/language/test_decorator.py new file mode 100644 index 0000000000..42207cc1fa --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_decorator.py @@ -0,0 +1,50 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. + @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/mthreads/python/test/unit/language/test_frontend.py b/third_party/mthreads/python/test/unit/language/test_frontend.py new file mode 100644 index 0000000000..0ed15da2b3 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_frontend.py @@ -0,0 +1,618 @@ +import functools +import triton +import triton.language as tl +from triton._filecheck import filecheck_test, run_filecheck_test, run_parser +from triton.compiler.errors import CompilationError +import pytest +from typing import NamedTuple + +# ===-----------------------------------------------------------------------===# +# Unit Tests +# ===-----------------------------------------------------------------------===# + + +def doesnt_compile(kernel): + + @functools.wraps(kernel) + def test_fn(): + with pytest.raises(triton.CompilationError): + run_parser(kernel) + + return test_fn + + +@triton.jit +def anchor(v): + pass + + +@tl.core._aggregate +class Pair: + first: tl.tensor + second: tl.tensor + + def __init__(self, first, second): + self.first = first + self.second = second + + @triton.jit + def get_first(self): + return self.first + + def get_second(self, _semantic=None): + return self.second + + @triton.jit + def unpack(self): + return self.get_first(), self.get_second() + + def __getitem__(self, ind: tl.constexpr, _semantic=None): + if ind == 0: + return self.first + assert ind == 1 + return self.second + + def __setitem__(self, ind: tl.constexpr, value, _semantic=None): + if ind == 0: + self.first = value + assert ind == 1 + self.second = value + + +@doesnt_compile +@triton.jit +def test_assign_attribute(): + scalar = 11 + pair = Pair(tl.arange(0, 4), scalar) + pair.second = 42 + + +@doesnt_compile +@triton.jit +def test_augassign_attribute(): + scalar = 11 + pair = Pair(tl.arange(0, 4), scalar) + pair.second += 42 + + +@filecheck_test +@triton.jit +def test_retrieve_item(): + # CHECK-LABEL: test_retrieve_item + # CHECK: %c11_i32 = arith.constant 11 : i32 + # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} + scalar = 11 + pair = Pair(tl.arange(0, 4), scalar) + # CHECK-NEXT: call @{{.*}}anchor{{.*}}(%c11_i32) + anchor(pair[1]) + + +@doesnt_compile +@triton.jit +def test_assign_item(): + scalar = 11 + pair = Pair(tl.arange(0, 4), scalar) + pair[1] = 42 + + +@doesnt_compile +@triton.jit +def test_augassign_item(): + scalar = 11 + pair = Pair(tl.arange(0, 4), scalar) + pair[1] += 42 + + +@filecheck_test +@triton.jit +def test_jit_method(): + # CHECK-LABEL: test_jit_method + # CHECK: %c11_i32 = arith.constant 11 : i32 + # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} + scalar = 11 + # CHECK: [[V:%.*]]:2 = tt.call @{{.*}}unpack{{.*}}([[RANGE]], %c11_i32) + pair = Pair(tl.arange(0, 4), scalar) + a, b = pair.unpack() + # CHECK: call @{{.*}}anchor{{.*}}([[V]]#0) + anchor(a) + # CHECK: call @{{.*}}anchor{{.*}}([[V]]#1) + anchor(b) + + +@tl.core._aggregate +class TypeWithJitGetItem: + value: tl.tensor + + def __init__(self, value): + self.value = value + + @triton.jit + def __getitem__(self, ind): + return self.value + + +@filecheck_test +@triton.jit +def test_jit_getitem(): + # CHECK-LABEL: test_jit_getitem + # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} + v = TypeWithJitGetItem(tl.arange(0, 4)) + # CHECK: [[V:%.*]] = tt.call [[METHOD:@.*__getitem__.*]]([[RANGE]]) + a = v[0] + # CHECK: call @{{.*}}anchor{{.*}}([[V]]) + anchor(a) + # CHECK: tt.func private [[METHOD]]([[ARG0:%.*]]: + # CHECK: tt.return [[ARG0]] + + +@tl.core._aggregate +class TypeWithBuiltinInitializer: + value: tl.tensor + + def __init__(self, _semantic=None): + self.value = tl.arange(0, 4, _semantic=_semantic) + + +@filecheck_test +@triton.jit +def test_aggregate_initializers(): + # CHECK-LABEL: test_aggregate_initializers + value = TypeWithBuiltinInitializer() + # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} + # CHECK: call @{{.*}}anchor{{.*}}([[RANGE]]) + anchor(value) + + +@triton.jit +def forward(arg): + return arg + + +@triton.jit +def list_of_functions_constexpr(arg, fns: tl.constexpr): + for i in tl.static_range(len(fns)): + fns[i](arg) + + +@filecheck_test +@triton.jit +def test_list_of_functions(): + # CHECK-LABEL: test_list_of_functions + # CHECK: call @{{.*}}list_of_functions_constexpr{{.*}}cJITFunction(test_frontend:anchor){{.*}}cJITFunction(test_frontend:forward) + + # CHECK: tt.func private @{{.*}}list_of_functions_constexpr + # CHECK-NEXT: call @{{.*}}anchor + # CHECK-NEXT: call @{{.*}}forward + list_of_functions_constexpr(tl.arange(0, 4), [anchor, forward]) + + +@triton.jit +def accumulate(a, b): + return a + b + + +# Check that we can call a function returning a value from a loop. +@filecheck_test +@triton.jit +def test_call_in_loop(): + # CHECK-LABEL: test_call_in_loop + acc = 0 + # CHECK: scf.for + # CHECK: call @{{.*}}accumulate + for i in range(10): + acc = accumulate(acc, i) + + +@tl.core._aggregate +class FunctionParent: + + @triton.jit + def function_with_name(): + pass + + +@triton.jit +def function_with_name(): + pass + + +@filecheck_test +@triton.jit +def test_function_name_mangling(): + # CHECK-LABEL: test_function_name_mangling + # CHECK: call @test_frontend.function_with_name + # CHECK: call @test_frontend.FunctionParent.function_with_name + function_with_name() + FunctionParent.function_with_name() + + +@tl.core._aggregate +class AggregateWithConstexpr: + a: tl.tensor + b: tl.constexpr + + def __init__(self, a, b): + self.a = a + self.b = b + + @staticmethod + def create(a): + return AggregateWithConstexpr(a, tl.constexpr(42)) + + @triton.jit + def modify(self, a): + self.a = a + return self + + +@triton.jit +def add_rhs_constexpr(agg): + _ = agg.a + agg.b + + +@filecheck_test +@triton.jit +def test_aggregate_with_constexpr(): + # CHECK-LABEL: test_aggregate_with_constexpr + # CHECK: tt.call @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr + agg = AggregateWithConstexpr.create(tl.arange(0, 4)) + add_rhs_constexpr(agg) + + # CHECK: tt.func private @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr + # CHECK: %cst = arith.constant dense<42> : tensor<4xi32> + # CHECK: arith.addi %arg0, %cst : tensor<4xi32> + + +@tl.core._aggregate +class AggregateWithTuple: + a: tl.tuple + + @triton.constexpr_function + def __init__(self, a): + self.a = tl.tuple((a, )) + + @staticmethod + @triton.jit + def create(a): + return AggregateWithTuple(a) + + +@triton.jit +def pass_tuple_aggregate(agg): + pass + + +@filecheck_test +@triton.jit +def test_aggregate_with_tuple(): + # CHECK-LABEL: test_aggregate_with_tuple + # CHECK: tt.call @"test_frontend.pass_tuple_aggregate__test_frontend.AggregateWithTuple" + agg = AggregateWithTuple.create(tl.arange(0, 4)) + pass_tuple_aggregate(agg) + # CHECK: tt.func private @"test_frontend.pass_tuple_aggregate__test_frontend.AggregateWithTuple" + + +@triton.constexpr_function +def constexpr_function(x): + return x + 1 + + +@filecheck_test +@triton.jit +def test_constexpr_function_from_jit(): + # CHECK-LABEL: test_constexpr_function + x: tl.constexpr = constexpr_function(7) + # CHECK: make_range {end = 8 : i32, start = 0 : i32} + tl.arange(0, x) + + +def test_constexpr_function_from_python(): + assert constexpr_function(7) == 8 + + +@triton.jit +def swap(pair): + return pair.second, pair.first + + +@doesnt_compile +@triton.jit +def test_assign_tuple_attrs_kernel(): + p = Pair(tl.arange(0, 4), tl.arange(4, 8)) + p.first, p.second = swap(p) + + +@doesnt_compile +@triton.jit +def test_reassign_aggregate_with_constexpr(): + agg = AggregateWithConstexpr.create(tl.arange(0, 4)) + agg = agg.modify(tl.arange(4, 8)) + + +@triton.constexpr_function +def make_shape(m, n): + return (m, n) + + +@triton.constexpr_function +def add_shape_dims(m, n): + return m + n + + +@filecheck_test +@triton.jit +def test_constexpr_getitem(): + # CHECK-LABEL: test_constexpr_getitem + # CHECK: make_range {end = 12 : i32, start = 4 : i32} + shape: tl.constexpr = make_shape(4, 8) + sum: tl.constexpr = add_shape_dims(shape[0], shape[1]) + tl.arange(4, sum) + + +@triton.constexpr_function +def Box(T): + + @tl.core._aggregate + class BoxImpl: + value: T + + @triton.jit + def create(value): + return BoxImpl(value) + + def __init__(self, value): + self.value = value + + return BoxImpl + + +def test_late_bound_class_reference(): + TensorBox = Box(tl.tensor) + + @triton.jit + def kernel(): + # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} + # CHECK: call @{{.*}}anchor{{.*}}([[RANGE]]) + value = TensorBox(tl.arange(0, 4)) + anchor(value) + + run_filecheck_test(kernel) + + +@triton.jit +def recursive_reduce(x): + if x.shape[0] == 1: + return x + else: + x0, x1 = x.reshape((x.shape[0] // 2, 2)).split() + return recursive_reduce(x0) + recursive_reduce(x1) + + +@filecheck_test +@triton.jit +def test_specialized_recursion(): + # CHECK-LABEL: test_specialized_recursion + # CHECK: call {{.*}}recursive_reduce__i32S16S + x = tl.arange(0, 16) + recursive_reduce(x) + + # CHECK: func {{.*}}recursive_reduce__i32S16S + # CHECK-COUNT-2: call {{.*}}recursive_reduce__i32S8S + + # CHECK: func {{.*}}recursive_reduce__i32S8S + # CHECK-COUNT-2: call {{.*}}recursive_reduce__i32S4S + + # CHECK: func {{.*}}recursive_reduce__i32S4S + # CHECK-COUNT-2: call {{.*}}recursive_reduce__i32S2S + + +@triton.jit +def trivial_return(): + return + + +@filecheck_test +@triton.jit +def test_call_in_while(): + # CHECK-LABEL: test_call_in_while + i = 0 + while i < 10: + if i == 5: + trivial_return() + else: + trivial_return() + + +def test_return_in_while(): + + @triton.jit + def kernel(): + i = 0 + while i < 10: + if i == 5: + return + i += 1 + + with pytest.raises(CompilationError) as e: + run_parser(kernel) + + assert "Cannot have `return` statements inside `while` or `for` statements in triton" in str(e.value) + + +class TensorPtr(NamedTuple): + test: tl.constexpr + + +class TestTuple(NamedTuple): + __test__ = False + test: TensorPtr + + +@triton.jit +def foo(test: TestTuple): + x: tl.constexpr = tl.constexpr(1) + for i in tl.range(x): + # Tests that it compiles and is usable. + tl.static_assert(test.test.test == 1) + + +def test_tuple_constexpr(): + test = TestTuple(test=TensorPtr(tl.constexpr(1))) + run_parser(foo, args=(test, )) + + +@tl.core._aggregate +class AggregateWithConstexprFunction: + val: tl.constexpr + val_squared: tl.constexpr + + def __init__(self, val): + self.val = tl.constexpr(val) + self.val_squared = tl.constexpr(self.square_val()) + + @triton.constexpr_function + def square_val(self): + return self.val * self.val + + +@filecheck_test +@triton.jit +def test_aggregate_constexpr_function(): + agg = AggregateWithConstexprFunction(4) + # CHECK: call @{{.*}}anchor{{.*}}c4 + anchor(agg.val) + + # CHECK: call @{{.*}}anchor{{.*}}c16 + anchor(agg.val_squared) + + # CHECK: call @{{.*}}anchor{{.*}}c16 + anchor(agg.square_val()) + + +@tl.core.builtin +def make_list(*args, _semantic=None): + return list(args) + + +@triton.constexpr_function +def function_taking_list(arg): + return arg[1] + + +@filecheck_test +@triton.jit +def test_constexpr_function_taking_list(): + a: tl.constexpr = function_taking_list(make_list(4, 8, 16)) + # CHECK: call @{{.*}}anchor{{.*}}c8 + anchor(a) + + +@filecheck_test +@triton.jit +def test_constexpr_min_max(): + a: tl.constexpr = min(1, 2) + # CHECK: call @{{.*}}anchor{{.*}}c1 + anchor(a) + + b: tl.constexpr = min(1, 2, -3) + # CHECK: call @{{.*}}anchor{{.*}}c-3 + anchor(b) + + c: tl.constexpr = max(3, 4) + # CHECK: call @{{.*}}anchor{{.*}}c4 + anchor(c) + + d: tl.constexpr = max(3, 4, 5) + # CHECK: call @{{.*}}anchor{{.*}}c5 + anchor(d) + + +def test_constexpr_min_error(): + + @triton.jit + def min_kernel(a: tl.constexpr, b: tl.constexpr): + min(a, b) + + with pytest.raises(CompilationError): + run_parser(min_kernel, args=(1.0, float("nan"))) + + with pytest.raises(CompilationError): + run_parser(min_kernel, args=(1.0, -0.0)) + + +def test_constexpr_max_error(): + + @triton.jit + def max_kernel(a: tl.constexpr, b: tl.constexpr): + max(a, b) + + with pytest.raises(CompilationError): + run_parser(max_kernel, args=(1.0, float("nan"))) + + with pytest.raises(CompilationError): + run_parser(max_kernel, args=(1.0, -0.0)) + + +@filecheck_test +@triton.jit +def test_for_loop_iv_modification(): + # CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step {{.*}} : i32 { + for i in range(4): + # CHECK: anchor{{.*}}%[[I]] + anchor(i) + # CHECK: %[[I2:.*]] = arith.addi %[[I]], %{{.*}} : i32 + i += 1 + # CHECK: anchor{{.*}}%[[I2]] + anchor(i) + + +@pytest.mark.interpreter +def test_constexpr_return(): + + @triton.jit + def get_constexpr_value(): + return tl.constexpr(42) + + @triton.jit + def test(): + x: tl.constexpr = get_constexpr_value() + tl.static_assert(x == 42) + + run_parser(test) + + +@pytest.mark.interpreter +def test_return_promotion(): + + @triton.jit + def signbit(x): + if x < 0: + return 1 + else: + return 0 + + @triton.jit + def tuple_return(x): + if x < 0: + return 1, x + else: + return 0, x + + @triton.jit + def kernel(): + # constexpr if -> constexpr returned + a: tl.constexpr = signbit(-1) + tl.static_assert(a == 1) + + # dynamic if -> promote to tensor + tmp = -1 + tl.static_assert(signbit(tmp).type == tl.int32) + + # constexpr if -> single return + b: tl.constexpr = tuple_return(-1) + tl.static_assert(b[0] == 1 and b[1] == -1) + + c = tuple_return(tmp) + tl.static_assert(c.type == tl.tuple_type([tl.int32, tl.int32])) + + run_parser(kernel) diff --git a/third_party/mthreads/python/test/unit/language/test_libdevice.py b/third_party/mthreads/python/test/unit/language/test_libdevice.py new file mode 100644 index 0000000000..7c68857739 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_libdevice.py @@ -0,0 +1,52 @@ +import pytest +import torch + +import triton +import triton.language as tl + +from triton.language.extra import libdevice +from triton.language.extra.libdevice import fast_dividef as my_fast_dividef + + +def test_libdevice_rename(device): + # mark the import as used by this test + _ = my_fast_dividef + + @triton.jit + def triton_copy(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.load(in_ptr + offsets) + tl.store(out_ptr + offsets, data) + + BLOCK_SIZE = 256 + inp = torch.randn(BLOCK_SIZE, device=device) + out = torch.empty_like(inp) + + triton_copy[(1, )](inp, out, BLOCK_SIZE) + + +@pytest.mark.parametrize("dtype_str", ["float32", "float64"]) +def test_isinf(device, dtype_str): + + @triton.jit + def triton_isinf(in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < numel + in_tile = tl.load(in_ptr + offsets, mask=mask) + if in_ptr.dtype.element_ty == tl.float32: + out_tile = libdevice.finitef(in_tile) + else: + out_tile = libdevice.isfinited(in_tile) + tl.store(out_ptr + offsets, out_tile, mask=mask) + + x = torch.tensor( + [float(1), -float(1), + float(0), -float(0), + float("inf"), -float("inf"), + float("nan"), -float("nan")], device=device, dtype=getattr(torch, dtype_str)) + res = torch.tensor([True, True, True, True, False, False, False, False]) + numel = x.numel() + y = torch.empty_like(x, dtype=torch.bool) + BLOCK_SIZE = 256 + triton_isinf[(triton.cdiv(numel, BLOCK_SIZE), )](x, y, numel, BLOCK_SIZE) + assert torch.equal(y.cpu(), res) diff --git a/third_party/mthreads/python/test/unit/language/test_line_info.py b/third_party/mthreads/python/test/unit/language/test_line_info.py new file mode 100644 index 0000000000..f90c542943 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_line_info.py @@ -0,0 +1,447 @@ +import inspect +import subprocess +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl +from triton._internal_testing import is_interpreter +from triton._filecheck import run_filecheck + + +@triton.jit +def kernel_single(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}, num_warps=4), + ], + key=[], +) +@triton.jit +def kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr): + for i in range(0, SIZE, BLOCK): + x = tl.load(X + i + tl.arange(0, BLOCK)) + tl.store(Y + i + tl.arange(0, BLOCK), x) + + +# AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +# Since the + symbol will take effect in the dot op after combination, +# it seems making sense to annotate with the same line as dot. +@triton.jit +def kernel_dot_combine(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + a = (tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :]).to(tl.int8) + d = tl.dot(a, a) + d = d + c + tl.device_print("", d) + + +# Call another jit function (cdiv) not in this file +@triton.jit +def kernel_cdiv(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + d = tl.cdiv(c, 4) + tl.device_print("", d) + + +def get_disassembler_command_and_debug_line_format(): + """Gets backend specific disassembler information. + + Returns a tuple: (object file kind, disassembler tool command, + debug line anchor, debug line file and line number separator). + """ + backend = triton.runtime.driver.active.get_current_target().backend + + if backend == "cuda": + nvdisasm = triton.knobs.nvidia.nvdisasm.path + return ("cubin", [nvdisasm, "-g"], "## File", ",") + + if backend == "hip": + import shutil + # Try to find llvm-objdump from the current PATH to disassmble hsaco. + tool = shutil.which("llvm-objdump") + if tool is not None: + return ("hsaco", [tool, "-D", "-l", "--arch=amdgcn"], ";", ":") + raise RuntimeError("llvm-objdump not found in PATH") + + raise RuntimeError(f"unknown backend {backend}") + + +def extract_file_lines(command, anchor, separator, asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(command + [path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + # We are looking for an anchor string and a separator between the file name and line number. + if anchor in line and separator in line: + entries = line[line.index(anchor):].split(separator) + if len(entries) == 2 and all(len(e) != 0 for e in entries): + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno, should_contain=True): + """ + Check if the file name and line number is in the file_lines + + Args: + file_lines: list of (file_name, line_number) + file_name: file name + lineno: line number, -1 means do not check line number + should_contain: whether the file name and line number should be in the file_lines + """ + for file, line in file_lines: + if lineno == -1 and file_name in file: + return True + if file_name in file and str(lineno) in line: + return should_contain + return not should_contain + + +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine", "cdiv"] + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + kernel_info = {} + if func == "single": + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + elif func == "call": + kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + elif func == "call_noinline": + kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + elif func == "autotune": + kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1, ))[0] + elif func == "dot_combine": + kernel_info = kernel_dot_combine.warmup(20, grid=(1, )) + elif func == "cdiv": + kernel_info = kernel_cdiv.warmup(20, grid=(1, )) + + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 16)) + assert (check_file_lines(file_lines, "test_line_info.py", 17)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 27)) + assert (check_file_lines(file_lines, "test_line_info.py", 29)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 41)) + assert (check_file_lines(file_lines, "test_line_info.py", 34)) + assert (check_file_lines(file_lines, "test_line_info.py", 34)) + elif func == "autotune": + assert (check_file_lines(file_lines, "test_line_info.py", 52)) + assert (check_file_lines(file_lines, "test_line_info.py", 53)) + assert (check_file_lines(file_lines, "test_line_info.py", 54)) + elif func == "dot_combine": + assert (check_file_lines(file_lines, "test_line_info.py", 64)) + assert (check_file_lines(file_lines, "test_line_info.py", 65, should_contain=False)) + elif func == "cdiv": + assert (check_file_lines(file_lines, "test_line_info.py", 74)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func", func_types) +def test_line_info_interpreter(func: str): + if not is_interpreter(): + pytest.skip("interpreter is not enabled") + + kernel = None + expected_def_lineno = 0 + if func == "single": + kernel = kernel_single + expected_def_lineno = 15 + elif func == "call": + kernel = kernel_call + expected_def_lineno = 26 + elif func == "call_noinline": + kernel = kernel_call_noinline + expected_def_lineno = 40 + elif func == "autotune": + kernel = kernel_autotune.fn + expected_def_lineno = 51 + elif func == "dot_combine": + kernel = kernel_dot_combine + expected_def_lineno = 61 + elif func == "cdiv": + kernel = kernel_cdiv + expected_def_lineno = 71 + kernel.rewrite() + assert kernel.rewriter.def_file_lineno == expected_def_lineno + + +@pytest.mark.parametrize("status", ["0", "1"]) +def test_line_info_env(monkeypatch, status: str): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + monkeypatch.setenv("TRITON_DISABLE_LINE_INFO", status) + kernel_single.device_caches.clear() + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + assert len(file_lines) == 0 if status == "1" else len(file_lines) > 0 + + +@pytest.mark.parametrize("status", ["ttir", ""]) +def test_line_info_ir_source(monkeypatch, status, tmp_path, fresh_triton_cache): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + src = """ + #loc = loc("/path/test.py":7:0) + module { + tt.func public @test(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/path/test.py":7:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/path/test.py":7:0)) attributes {noinline = false} { + %0 = tt.load %arg0 : !tt.ptr loc(#loc1) + tt.store %arg1, %0 : !tt.ptr loc(#loc2) + tt.return loc(#loc3) + } loc(#loc) + } loc(#loc) + #loc1 = loc("/path/test.py":8:16) + #loc2 = loc("/path/test.py":9:20) + #loc3 = loc("/path/test.py":9:4) + """ + monkeypatch.setenv("USE_IR_LOC", status) + temp_file = tmp_path / "test.ttir" + temp_file.write_text(src) + kernel_info = triton.compile(str(temp_file)) + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if status == "ttir": + assert check_file_lines(file_lines, "/path/test.py", 8, should_contain=False) + assert check_file_lines(file_lines, str(temp_file), -1, should_contain=True) + else: + assert check_file_lines(file_lines, "/path/test.py", 8, should_contain=True) + + +def test_use_name_loc_as_prefix(fresh_triton_cache): + + @triton.jit + def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr): + # CHECK: #loc = loc("{{.*}}":261:0) + # CHECK-LABEL: tt.func public @kernel_basic( + # CHECK-SAME: %src: !tt.ptr loc("src"(#loc)), %N: i32 loc("N"(#loc))) + # CHECK: %x_plus_1 = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc14) + # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc2) + # CHECK: %pid = tt.get_program_id x : i32 loc(#loc15) + # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc16) + # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc17) + # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc18) + # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc18) + # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc19) + # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc19) + # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc20) + # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc20) + # CHECK: %x_plus_1_4 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr> loc(#loc21) + # CHECK: %x_plus_1_5 = arith.addf %x_plus_1_4, %x_plus_1 : tensor<16xf32> loc(#loc14) + # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_5, %mask_3 : tensor<16x!tt.ptr> loc(#loc10) + # CHECK: tt.return loc(#loc11) + # CHECK: } loc(#loc) + # CHECK: } loc(#loc) + + # CHECK: #loc1 = loc({{.*}}) + # CHECK: #loc2 = loc(unknown) + # CHECK: #loc3 = loc({{.*}}) + # CHECK: #loc4 = loc({{.*}}) + # CHECK: #loc5 = loc({{.*}}) + # CHECK: #loc6 = loc({{.*}}) + # CHECK: #loc7 = loc({{.*}}) + # CHECK: #loc8 = loc({{.*}}) + # CHECK: #loc9 = loc({{.*}}) + # CHECK: #loc10 = loc({{.*}}) + # CHECK: #loc11 = loc({{.*}}) + # CHECK: #loc14 = loc("x_plus_1"(#loc1)) + # CHECK: #loc15 = loc("pid"(#loc3)) + # CHECK: #loc16 = loc("offset"(#loc4)) + # CHECK: #loc17 = loc("offsets"(#loc5)) + # CHECK: #loc18 = loc("offsets"(#loc6)) + # CHECK: #loc19 = loc("load_src_store_dst"(#loc7)) + # CHECK: #loc20 = loc("mask"(#loc8)) + # CHECK: #loc21 = loc("x_plus_1"(#loc9)) + + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + offsets = offset + tl.arange(0, BLOCK_SIZE) + load_src_store_dst = src + offsets + mask = offsets < N + x_plus_1 = tl.load(load_src_store_dst, mask=mask) + 1 + tl.store(load_src_store_dst, x_plus_1, mask=mask) + + h = triton.compile( + triton.compiler.ASTSource(fn=kernel_basic, signature={"src": "*fp32", "N": "i32", "BLOCK_SIZE": "constexpr"}, + constexprs={"BLOCK_SIZE": 16})) + + check_template = inspect.getsource(kernel_basic.fn) + run_filecheck("placeholder", h.asm["ttir"], check_template) + + @triton.jit + def kernel_basic_for_loop(N): + # CHECK-LABEL: tt.func public @kernel_basic_for_loop + + # CHECK: scf.for %ivar = %c0_i32 to %N step %c1_i32 + for ivar in range(N): + tl.device_print("", ivar) + + h = triton.compile(triton.compiler.ASTSource(fn=kernel_basic_for_loop, signature={"N": "i32"}, constexprs={})) + + check_template = inspect.getsource(kernel_basic_for_loop.fn) + run_filecheck("placeholder", h.asm["ttir"], check_template) + + @triton.jit + def kernel_basic_for_loop_with_block_args(N): + # CHECK-LABEL: tt.func public @kernel_basic_for_loop_with_block_args + + # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + arange = tl.arange(0, 16) + # CHECK: %arange_0 = scf.for %ivar = %c0_i32 to %N step %c1_i32 iter_args(%arange_1 = %arange) -> (tensor<16xi32>) + for ivar in range(N): + # CHECK: %arange_2 = arith.addi %arange_1, %arange_1 : tensor<16xi32> + arange += arange + # scf.yield %arange_2 : tensor<16xi32> + + tl.device_print("", arange) + + h = triton.compile( + triton.compiler.ASTSource(fn=kernel_basic_for_loop_with_block_args, signature={"N": "i32"}, constexprs={})) + + check_template = inspect.getsource(kernel_basic_for_loop_with_block_args.fn) + run_filecheck("placeholder", h.asm["ttir"], check_template) + + @triton.jit + def kernel_basic_if(N): + # CHECK-LABEL: tt.func public @kernel_basic_if + + # CHECK-DAG: %cst = arith.constant dense<4> : tensor<16xi32> + # CHECK-DAG: %cst_0 = arith.constant dense<2> : tensor<16xi32> + + # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + arange = tl.arange(0, 16) + + if N > 2: + # CHECK: %arange_1 = arith.muli %arange, %cst_0 : tensor<16xi32> + arange *= 2 + # CHECK: scf.yield %arange_1 : tensor<16xi32> + else: + # CHECK: %arange_1 = arith.muli %arange, %cst : tensor<16xi32> + arange *= 4 + # CHECK: scf.yield %arange_1 : tensor<16xi32> + + tl.device_print("", arange) + + h = triton.compile(triton.compiler.ASTSource(fn=kernel_basic_if, signature={"N": "i32"}, constexprs={})) + + check_template = inspect.getsource(kernel_basic_if.fn) + run_filecheck("placeholder", h.asm["ttir"], check_template) + + @triton.jit + def kernel_basic_if_top_level(N): + # CHECK-LABEL: tt.func public @kernel_basic_if_top_level + + # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + arange = tl.arange(0, 16) + if N == 0: + # CHECK: %arange_0 = arith.addi %arange, %arange : tensor<16xi32> + arange += tl.arange(0, 16) + tl.device_print("", arange) + return + else: + # CHECK: %new_arange = tt.make_range {end = 32 : i32, start = 16 : i32} : tensor<16xi32> + new_arange = tl.arange(16, 32) + # CHECK: %arange_1 = arith.addi %arange, %new_arange : tensor<16xi32> + arange += new_arange + tl.device_print("", arange) + return + + h = triton.compile(triton.compiler.ASTSource(fn=kernel_basic_if_top_level, signature={"N": "i32"}, constexprs={})) + + check_template = inspect.getsource(kernel_basic_if_top_level.fn) + run_filecheck("placeholder", h.asm["ttir"], check_template) + + @triton.jit + def kernel_basic_while(N): + # CHECK-LABEL: tt.func public @kernel_basic_while + + # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + arange = tl.arange(0, 16) + ivar = 0 + # CHECK: %ivar_[[IV0:.+]]:2 = scf.while (%arange_[[AR0:.+]] = %arange, %ivar_[[IV1:.+]] = %ivar) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32) + # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_[[IV1]], %N : i32 + # CHECK: scf.condition(%[[COND]]) %arange_[[AR0]], %ivar_[[IV1]] : tensor<16xi32>, i32 + while ivar < N: + # CHECK: ^bb0(%arange_[[AR0]]: tensor<16xi32> loc("arange"), %ivar_[[IV1]]: i32 + + # CHECK: %ivar_[[IV2:.+]] = arith.addi %ivar_[[IV1]], %c1_i32 : i32 + ivar += 1 + # CHECK: %arange_[[AR1:.+]] = tt.splat %ivar_[[IV2]] : i32 -> tensor<16xi32> + # CHECK: %arange_[[AR2:.+]] = arith.muli %arange_[[AR0]], %arange_[[AR1]] : tensor<16xi32> + # CHECK: scf.yield %arange_[[AR2]], %ivar_[[IV2]] : tensor<16xi32>, i32 + arange *= ivar + + # CHECK: tt.print ": " {hex = false, isSigned = array} : %ivar_[[IV0]]#0 : tensor<16xi32> + tl.device_print("", arange) + + h = triton.compile(triton.compiler.ASTSource(fn=kernel_basic_while, signature={"N": "i32"}, constexprs={})) + check_template = inspect.getsource(kernel_basic_while.fn) + run_filecheck("placeholder", h.asm["ttir"], check_template) + + +def test_map_elementwise_has_lineinfo(): + + @triton.jit + def compare(x, y): + if x < y: + return x + return y + + @triton.jit + def kernel(X, Y): + # CHECK-NOT: loc(unknown) + x = tl.load(X + tl.arange(0, 4)) + y = tl.load(Y + tl.arange(0, 4)) + z = tl.map_elementwise(compare, x, y) + tl.device_print("", z) + + kernel_info = kernel.warmup(torch.float32, torch.float32, grid=(1, )) + check_template = inspect.getsource(kernel.fn) + run_filecheck("test", kernel_info.asm["ttir"], check_template) diff --git a/third_party/mthreads/python/test/unit/language/test_module.py b/third_party/mthreads/python/test/unit/language/test_module.py new file mode 100644 index 0000000000..27a49efd1d --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_module.py @@ -0,0 +1,6 @@ +import triton + + +@triton.jit +def function_with_name(): + pass diff --git a/third_party/mthreads/python/test/unit/language/test_musa_ut_056.py b/third_party/mthreads/python/test/unit/language/test_musa_ut_056.py new file mode 100644 index 0000000000..c8cbef4fe4 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_musa_ut_056.py @@ -0,0 +1,152 @@ +import os + +os.environ.setdefault("TRITON_BACKENDS_IN_TREE", "1") + +import pytest +import triton +import triton.language as tl +from triton._C import libtriton +from pathlib import Path + +if not hasattr(libtriton, "musa"): + pytest.skip("musa backend not built in libtriton", allow_module_level=True) + +from triton.backends import backends +from triton.backends.compiler import GPUTarget +from triton.compiler import ASTSource +from triton._C.libtriton import ir + + +def _get_musa_backend(): + if "musa" not in backends: + pytest.skip("musa backend not discovered") + target = GPUTarget("musa", "ph1", 32) + return backends["musa"].compiler(target) + + +def _compile_to_llir(fn, signature, constexprs=None): + target = GPUTarget("musa", "ph1", 32) + backend = _get_musa_backend() + + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + options = backend.parse_options({}) + module_map = backend.get_module_map() + codegen_fns = backend.get_codegen_implementation(options) + src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) + + ttir = src.make_ir(target, options, codegen_fns, module_map, context) + stages = {} + backend.add_stages(stages, options, src.language) + meta = {} + ttir = stages["ttir"](ttir, meta) + ttgir = stages["ttgir"](ttir, meta) + llir = stages["llir"](ttgir, meta) + return llir, meta + + +def test_musa_056_default_libdevice_path(fresh_knobs): + backend = _get_musa_backend() + from triton.backends.musa import compiler as musa_compiler + + with fresh_knobs.musa.scope(): + del fresh_knobs.musa.libdevice_path + options = backend.parse_options({}) + + expected = Path(musa_compiler.__file__).resolve().parent / "lib" / "libdevice.31.bc" + assert Path(dict(options.extern_libs)["libdevice"]).resolve() == expected + + +def test_musa_056_libdevice_path_override(fresh_knobs, tmp_path): + backend = _get_musa_backend() + override = tmp_path / "libdevice.override.bc" + override.write_bytes(b"") + + with fresh_knobs.musa.scope(): + fresh_knobs.musa.libdevice_path = str(override) + options = backend.parse_options({}) + + assert dict(options.extern_libs)["libdevice"] == str(override) + + +def test_musa_056_cast_compile_only(): + + @triton.jit + def kernel_cast(inp, out): + offs = tl.arange(0, 64) + x = tl.load(inp + offs) + y = x.to(tl.float16) + z = y.to(tl.float32) + tl.store(out + offs, z) + + llir, _ = _compile_to_llir(kernel_cast, {"inp": "*fp32", "out": "*fp32"}) + assert "fptrunc" in llir + assert "fpext" in llir + + +def test_musa_056_chained_dot_compile_only(): + + @triton.jit + def kernel_chained_dot(out): + a = tl.full((16, 16), 1.0, tl.float16) + b = tl.full((16, 16), 2.0, tl.float16) + c = tl.dot(a, b) + d = tl.dot(c.to(tl.float16), a) + row = tl.sum(d, axis=1) + offs = tl.arange(0, 16) + tl.store(out + offs, row.to(tl.float32)) + + llir, meta = _compile_to_llir(kernel_chained_dot, {"out": "*fp32"}) + assert "target datalayout" in llir + assert "shared" in meta + + +@pytest.mark.parametrize("input_precision", ["bf16x3", "bf16x6"]) +def test_musa_056_bf16xN_dot_compile_only(input_precision): + + @triton.jit + def kernel_bf16_dot(out, INPUT_PRECISION: tl.constexpr): + a = tl.full((16, 16), 1.0, tl.float32) + b = tl.full((16, 16), 2.0, tl.float32) + c = tl.dot(a, b, input_precision=INPUT_PRECISION, out_dtype=tl.float32) + row = tl.sum(c, axis=1) + offs = tl.arange(0, 16) + tl.store(out + offs, row) + + llir, _ = _compile_to_llir( + kernel_bf16_dot, + {"out": "*fp32", "INPUT_PRECISION": "constexpr"}, + constexprs={"INPUT_PRECISION": input_precision}, + ) + assert "target datalayout" in llir + + +def test_musa_056_functional_vecmat_compile_only(): + + @triton.jit + def kernel_vecmat(inp, out): + offs = tl.arange(0, 16) + vec = tl.load(inp + offs) + mat = tl.full((16, 16), 0.5, tl.float32) + prod = mat * tl.expand_dims(vec, 0) + red = tl.sum(prod, axis=1) + tl.store(out + offs, red) + + llir, _ = _compile_to_llir(kernel_vecmat, {"inp": "*fp32", "out": "*fp32"}) + assert "fadd" in llir + assert "fmul" in llir + + +def test_musa_056_constexpr_annotation_compile_only(): + + @triton.jit + def kernel_constexpr(inp, out, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + x = tl.load(inp + offs) + tl.store(out + offs, x) + + llir, _ = _compile_to_llir(kernel_constexpr, {"inp": "*fp32", "out": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 32}) + assert "target datalayout" in llir diff --git a/third_party/mthreads/python/test/unit/language/test_mxfp.py b/third_party/mthreads/python/test/unit/language/test_mxfp.py new file mode 100644 index 0000000000..3e0d6c050e --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_mxfp.py @@ -0,0 +1,127 @@ +import pytest +import torch +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor + + +class MXBaseTest: + + @pytest.fixture + def device(self): + return "cpu" + + +class TestMXFP4Tensor(MXBaseTest): + + @pytest.mark.parametrize("K, N", [(64, 128), (128, 256)]) + def test_roundtrip(self, K, N, device): + tensor = MXFP4Tensor(size=(K, N), device=device).random() + tensor2 = MXFP4Tensor(tensor.to(torch.float32)) + torch.testing.assert_close(tensor.data, tensor2.data) + + @pytest.mark.parametrize("K, N, dim", [(64, 128, 0), (64, 128, 1)]) + def test_packed_tensor(self, K, N, dim, device): + tensor = MXFP4Tensor(size=(K, N), device=device).random() + packed = tensor.to_packed_tensor(dim=dim) + unpacked = tensor.unpack_packed_tensor(packed, dim=dim, original_shape=(K, N)) + torch.testing.assert_close(tensor.data, unpacked) + + def test_padding(self, device): + tensor_pad = MXFP4Tensor(torch.tensor([4], device=device)) + pad_packed = tensor_pad.to_packed_tensor(dim=0) + torch.testing.assert_close(tensor_pad.data, + tensor_pad.unpack_packed_tensor(pad_packed, dim=0, original_shape=(1, ))) + + def test_zero_values(self, device): + test_values = torch.tensor([0.0, -0.0], device=device) + tensor = MXFP4Tensor(test_values) + expected_encodings = torch.tensor([0b0000, 0b1000], dtype=torch.uint8, device=device) + assert torch.equal(tensor.data, expected_encodings), "Zero values should be encoded as 0" + torch.testing.assert_close(tensor.to(torch.float32), test_values) + + def test_out_of_range_values(self, device): + test_values = torch.tensor([7.0, -7.0, float('inf'), float('-inf')], device=device) + tensor = MXFP4Tensor(test_values) + expected_values = torch.tensor([6.0, -6.0, 6.0, -6.0], device=device) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + def test_subnormal_numbers(self, device): + test_values = torch.tensor([0.1, 0.2, 0.3, 0.4], device=device) + tensor = MXFP4Tensor(test_values) + expected_values = torch.tensor([0.0, 0.0, 0.5, 0.5], device=device) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + def test_rounding_edge_cases(self, device): + test_values = torch.tensor([0.75, 1.25, 1.75, 2.5, 3.5, 5.0], device=device) + expected_values = torch.tensor([1.0, 1.0, 2.0, 2.0, 4.0, 4.0], device=device) + tensor = MXFP4Tensor(test_values) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + def test_negative_values(self, device): + test_values = torch.tensor([-0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], device=device) + tensor = MXFP4Tensor(test_values) + torch.testing.assert_close(tensor.to(torch.float32), test_values) + + def test_negative_out_of_range(self, device): + tensor = MXFP4Tensor(torch.tensor([-7.0, -8.0, -10.0], device=device)) + expected_values = torch.tensor([-6.0, -6.0, -6.0], device=device) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + @pytest.mark.parametrize("shape, dim", [ + ((1024, ), 0), + ((128, 256), 0), + ((128, 256), 1), + ((64, 64, 64), 2), + ]) + def test_packing(self, shape, dim, device): + tensor = MXFP4Tensor(size=shape, device=device).random() + packed = tensor.to_packed_tensor(dim=dim) + unpacked = tensor.unpack_packed_tensor(packed, dim=dim, original_shape=shape) + torch.testing.assert_close(tensor.data, unpacked) + + def test_packing_with_padding(self, device): + shape = (7, 5) + dim = 1 + tensor = MXFP4Tensor(size=shape, device=device).random() + packed = tensor.to_packed_tensor(dim=dim) + unpacked = tensor.unpack_packed_tensor(packed, dim=dim, original_shape=shape) + torch.testing.assert_close(tensor.data, unpacked) + + def test_invalid_packing_dimension(self, device): + tensor = MXFP4Tensor(size=(4, 4), device=device).random() + with pytest.raises(AssertionError): + tensor.to_packed_tensor(dim=2) # Invalid dimension + + def test_empty_tensor(self, device): + tensor = MXFP4Tensor(torch.tensor([], device=device)) + assert tensor.to(torch.float32).numel() == 0 + + +class TestMXScaleTensor(MXBaseTest): + + def test_positive_values(self, device): + values = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device) + data = MXScaleTensor(values) + torch.testing.assert_close(data.to(torch.float32), values) + + def test_special_values(self, device): + values = torch.tensor([0.0, -1.0, float('nan'), float('inf'), float('-inf')], device=device) + tensor = MXScaleTensor(values) + expected_data = torch.tensor([255, 255, 255, 255, 255], dtype=torch.uint8, device=device) + assert torch.equal(expected_data, tensor.data), "Special values should be encoded as NaN (255)" + + def test_e8m0_nan_to_float_nan(self, device): + tensor = MXScaleTensor(size=(1, ), device=device) + tensor.data = torch.tensor([255], device=device, dtype=torch.uint8) + assert torch.isnan(tensor.to(torch.float32)), "E8M0 NaN encoding should convert to float32 NaN" + + def test_random_generation(self, device): + data = MXScaleTensor(size=(1000, ), device=device).random() + data = data.data + assert ((data >= 0) & (data <= 254)).all(), "Generated data should be between 0 and 254" + assert (data != 255).all(), "Generated data should not include NaN encoding (255)" + + @pytest.mark.parametrize("K, N", [(64, 128), (128, 256)]) + def test_roundtrip(self, K, N, device): + tensor = MXScaleTensor(size=(K, N), device=device).random() + tensor2 = MXScaleTensor(tensor.to(torch.float32)) + torch.testing.assert_close(tensor.data, tensor2.data) diff --git a/third_party/mthreads/python/test/unit/language/test_random.py b/third_party/mthreads/python/test/unit/language/test_random.py new file mode 100644 index 0000000000..79a4e3842f --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_random.py @@ -0,0 +1,273 @@ +import numpy as np +import pytest +import scipy.stats +import torch + +import triton +import triton.language as tl + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + bits = np.dtype(self._dtype).itemsize * 8 + while len(res) < pad: + res.append(np.array((n & ((1 << bits) - 1)), dtype=self._dtype)) + n >>= bits + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK = tl.constexpr(1024) + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['10', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randint(size, seed, device, dtype, const_seed): + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = PHILOX_32 + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK.value), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_rand(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK.value), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +def test_seed_is_int(device): + + @triton.jit + def kernel(X, seed): + offset = tl.arange(0, 1) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand) + + x = torch.empty(1, dtype=torch.float32, device=device) + with pytest.raises(triton.compiler.errors.CompilationError): + seed0 = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](x, seed0) + with pytest.raises(triton.compiler.errors.CompilationError): + seed1 = 2.3 + kernel[(1, )](x, seed1) + + +# test normal PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randn(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK.value), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('dtype', ['int32', 'int64']) +def test_rand_limits(dtype, device): + + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint_to_uniform_float(x) + tl.store(output + idx, y) + + torch_dtype = getattr(torch, dtype) + min_max_int = torch.tensor([ + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, + ], dtype=torch_dtype, device=device) + output = torch.empty(2, dtype=torch.float32, device=device) + kernel[(1, )](min_max_int, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/mthreads/python/test/unit/language/test_reproducer.py b/third_party/mthreads/python/test/unit/language/test_reproducer.py new file mode 100644 index 0000000000..75ef14f8f5 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_reproducer.py @@ -0,0 +1,38 @@ +import triton +import re +import os + + +def test_triton_reproducer_path(monkeypatch, tmp_path): + # If we get a cache hit there will be no reproducer generated + monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1") + + @triton.jit + def triton_(): + return + + # We need an temp empty file for MLIR to write the reproducer to, and then + # the TRITON_REPRODUCER_PATH env var enables crash the reproduction + # generation in MLIR. + repro_path = tmp_path / "repro_prefix" + monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path)) + + # Run the kernel so MLIR will generate a crash reproducer. It doesn't really + # matter what the kernel does, just that the PassManager runs its passes. + triton_[(1, )]() + + stages = { + 'make_ttir': "triton-combine", + 'make_ttgir': "triton.*-coalesce", + 'make_llir': "convert-triton-.*gpu-to-llvm", + } + + for stage_name, stage_pipeline_check in stages.items(): + assert os.path.exists(str(repro_path) + '.' + stage_name + '.repro.mlir') + curr_repro_path = tmp_path / ("repro_prefix." + stage_name + ".repro.mlir") + repro = curr_repro_path.read_text() + assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {curr_repro_path}. Got:\n{repro}" + m = re.search(r"pipeline: \"(.*" + stage_pipeline_check + ".*)\"", repro) + assert m, "Expected to match pass pipeline after \"pipeline:\" in MLIR reproducer" + pipeline_str = m.group(1) + assert pipeline_str, "Expected non-empty pass pipeline in MLIR reproducer" diff --git a/third_party/mthreads/python/test/unit/language/test_standard.py b/third_party/mthreads/python/test/unit/language/test_standard.py new file mode 100644 index 0000000000..2c4c7639ec --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_standard.py @@ -0,0 +1,193 @@ +import triton +import pytest +import torch +import triton.language as tl + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minium(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 1], [1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("k", [None, 8]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) +def test_sort(M, N, k, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k: tl.constexpr, + descending: tl.constexpr): + offs_m = tl.arange(0, M) + offs_x_n = tl.arange(0, N) + offs_z_n = offs_x_n if k is None else tl.arange(0, k) + offs_x = offs_m[:, None] * stride_xm + offs_x_n[None, :] + x = tl.load(X + offs_x) + if k is None or x.numel < k: + z = tl.sort(x, descending=descending) + else: + z = tl.topk(x, k) + offs_z = offs_m[:, None] * stride_zm + offs_z_n[None, :] + tl.store(Z + offs_z, z) + + z_shape = (M, N if k is None else k) + x = numpy_random((M, N), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + z = torch.empty(z_shape, dtype=x.dtype, device=x.device) + if k is None or x.numel() < k: + y = torch.sort(x, descending=descending)[0] + else: + try: + y = torch.topk(x, k=k).values + except RuntimeError as exc: + # MUSA torch.topk currently rejects integral dtypes even though + # Triton's bitonic top-k path handles them correctly. + if device == "musa" and "Dtype of input tensor of topk only support" in str(exc): + y = torch.topk(x.cpu(), k=k).values.to(device) + else: + raise + sort_kernel[(1, )](x, x.stride(0), z, z.stride(0), M, N, k, descending, num_warps=8) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N, K", [[1, 16, 64], [8, 2, 256], [32, 1, 2], [128, 8, 1]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) +@pytest.mark.parametrize("dim", [0, 1, 2, -2]) +def test_flip(M, N, K, dtype_str, dim, device): + + @triton.jit + def flip_kernel(X, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, dim: tl.constexpr): + offx = tl.arange(0, M) * N * K + offy = tl.arange(0, N) * K + offz = tl.arange(0, K) + off3d = offx[:, None, None] + offy[None, :, None] + offz[None, None, :] + x = tl.load(X + off3d) + x = tl.flip(x, dim) + tl.store(Z + off3d, x) + + x = numpy_random((M, N, K), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.flip(x, (dim, )) + z = torch.empty_like(x, device=device) + flip_kernel[(1, )](x, z, M, N, K, dim, num_warps=8) + assert (y == z).all(), (y, z) + + +@pytest.mark.interpreter +def test_flip_inf(device): + # Reproducer for https://github.com/triton-lang/triton/issues/5439 + + @triton.jit + def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr): + pid = tl.program_id(0) + x = tl.load(x_ptr + pid * N + tl.arange(0, N)) + shape: tl.constexpr = (N // 2, 2) + y = x.reshape(shape) + y = tl.flip(y, dim=1).reshape(x.shape) + tl.store(out_ptr + pid * N + tl.arange(0, N), y) + + x = torch.arange(0, 16, device=device).unsqueeze(0).float() + x[:, -1] = float('inf') + + expect = x.reshape(-1, 8, 2).flip(-1).reshape(-1, 16) + actual = torch.empty_like(x) + triton_flip_kernel[(x.shape[0], )](actual, x, x.shape[1]) + + torch.testing.assert_close(expect, actual) + + +@pytest.mark.interpreter +def test_ravel(device): + + @triton.jit + def triton_ravel(out_ptr): + a = tl.arange(0, 256) + a = tl.reshape(a, (32, 8)) + a = tl.ravel(a) + tl.store(out_ptr + tl.arange(0, 256), a) + + out = torch.empty((256, ), device=device, dtype=torch.int32) + triton_ravel[(1, )](out) + + assert (out == torch.arange(0, 256, device=device)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]]) +def test_swizzle2d(size_i, size_j, size_g, device): + + @triton.jit + def swizzle2d_kernel(output, size_i, size_j, size_g): + for i in tl.range(0, size_i, 1): + for j in tl.range(0, size_j, 1): + new_i, new_j = tl.swizzle2d(i, j, size_i, size_j, size_g) + tl.store(output + new_i * size_j + new_j, i * size_j + j) + + output = torch.zeros(size_i, size_j).to(device) + swizzle2d_kernel[(1, )](output, size_i, size_j, size_g) + expected_order = torch.tensor([[0, 3, 6, 9, 12, 15, 18], [1, 4, 7, 10, 13, 16, 19], [2, 5, 8, 11, 14, 17, 20], + [21, 23, 25, 27, 29, 31, 33], [22, 24, 26, 28, 30, 32, 34]]).to(device) + assert (output == expected_order).all(), (output, expected_order) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, dim", [((1, 2, 4), 0), ((2, 1, 4), 1), ((2, 4, 1), 2)]) +def test_squeeze(shape, dim, device): + + @triton.jit + def triton_squeeze(out_ptr, dim: tl.constexpr, s0: tl.constexpr, s1: tl.constexpr, s2: tl.constexpr): + a = tl.arange(0, 8) + a = tl.reshape(a, (s0, s1, s2)) + a = tl.squeeze(a, dim) + a = tl.ravel(a) + tl.store(out_ptr + tl.arange(0, 8), a) + + out = torch.empty((8, ), device=device, dtype=torch.int32) + triton_squeeze[(1, )](out, dim, shape[0], shape[1], shape[2]) + + expected = torch.arange(0, 8, device=device, dtype=torch.int32) + expected = expected.reshape(shape).squeeze(dim).reshape(-1) + assert (out == expected).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dim", [0, 1, 2]) +def test_unsqueeze(dim, device): + + @triton.jit + def triton_unsqueeze(out_ptr, dim: tl.constexpr): + a = tl.arange(0, 8) + a = tl.reshape(a, (2, 4)) + a = tl.unsqueeze(a, dim) + a = tl.ravel(a) + tl.store(out_ptr + tl.arange(0, 8), a) + + out = torch.empty((8, ), device=device, dtype=torch.int32) + triton_unsqueeze[(1, )](out, dim) + + expected = torch.arange(0, 8, device=device, dtype=torch.int32) + expected = expected.reshape(2, 4).unsqueeze(dim).reshape(-1) + assert (out == expected).all() diff --git a/third_party/mthreads/python/test/unit/language/test_tuple.py b/third_party/mthreads/python/test/unit/language/test_tuple.py new file mode 100644 index 0000000000..49a2e52e11 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_tuple.py @@ -0,0 +1,359 @@ +import pytest +import triton +import triton.language as tl +from typing import NamedTuple +import torch + + +@triton.jit +def _tuple_increment(values): + return tl.tuple([v + 1 for v in values]) + + +@triton.jit +def _tuple_index_func(Ptrs, values): + for i in tl.static_range(len(values)): + tl.store(Ptrs[i], values[i]) + + +@triton.jit +def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): + values = _tuple_increment(values) + _tuple_index_func(Ptrs, values) + + +@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) +def test_index(size, device): + vals = tuple([i + 1 for i in range(size)]) + rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) + _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) + assert vals == tuple([x.item() - 1 for x in rets]) + + +# ---- + + +@triton.jit +def _tuple_assign(XPtrs, YPtrs, values): + # assign from tuple + X0, X1 = XPtrs + x0, x1, _ = values + tl.store(X0, x0) + tl.store(X1, x1) + # assign to tuple + Y0, Y1, Y2 = YPtrs + Y = Y0, Y1, Y2 + y = x0, 10, x1 + tl.store(Y[0], y[0]) + tl.store(Y[1], y[1]) + tl.store(Y[2], y[2]) + + +@pytest.mark.interpreter +def test_assign(device): + vals = (2., 3., None) + x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) + y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) + _tuple_assign[(1, )](x, y, vals) + assert x[0] == vals[0] + assert x[1] == vals[1] + assert y[0] == vals[0] + assert y[1] == 10 + assert y[2] == vals[1] + + +@triton.jit +def _tuple_ret(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_assign_return(device): + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = _tuple_ret(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +# ------- + + +@triton.jit +def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): + tl.static_assert(tuple1[1] is None) + tl.store(Ptr + 5, cst2) + tl.store(Ptr + 6, tuple1[0]) + tl.store(Ptr + 7, tl.load(tuple1[2][0])) + tl.store(Ptr + 8, tuple1[2][1][0]) + tl.store(Ptr + 9, tl.load(tuple1[2][1][2])) + + +# test serialization/deserialization of tuple arguments in +# the frontend. +@triton.jit +def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2): + tl.static_assert(N1 is None) + tl.static_assert(tuple1[1][1] is None) + tl.static_assert(tuple1[1][3] == 4) + tl.store(Ptr + 0, tl.load(tuple1[0])) + tl.store(Ptr + 1, tuple1[1][0]) + tl.store(Ptr + 2, tl.load(tuple1[1][2])) + tl.store(Ptr + 3, cst1 + val1) + tl.store(Ptr + 4, tl.load(tuple2[0])) + _tuple_fn0(Ptr, 15, (-1, None, tuple1)) + + +@pytest.mark.interpreter +def test_serialize(device): + x0 = torch.tensor([8], dtype=torch.int32, device=device) + x1 = torch.tensor([12], dtype=torch.int32, device=device) + y0 = torch.tensor([10], dtype=torch.int32, device=device) + z = torch.empty((10, ), dtype=torch.int32, device=device) + # we want to check that JIT specialization propagates to tuples: + _tuple_serialize[(1, )](z, None, (x0, (1, None, x1, tl.constexpr(4))), 20, 1, (y0, )) + ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device) + assert torch.equal(z, ref) + + +class Function(NamedTuple): + fn: tl.constexpr + captured: tuple + + +class Tensor(NamedTuple): + ptr: any + shape: tuple + stride: tuple + + +@triton.jit +def _namedtuple_create_func0(shape, ptr, stride): + return Tensor(shape=shape, ptr=ptr, stride=stride) + + +@triton.jit +def _namedtuple_create_func1(shape, ptr, stride): + tensor = Tensor(shape=shape, ptr=ptr, stride=stride) + return tensor + + +@triton.jit +def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1]) + return mask + + +@triton.jit +def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride) + Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride) + Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1] + Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1] + x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0) + y = closure.fn(x, *closure.captured) + tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N)) + + +@pytest.mark.interpreter +def test_namedtuple(device): + x = torch.randn((32, 32), dtype=torch.float32, device=device) + y = torch.empty((16, 16), dtype=torch.float32, device=device) + a = torch.tensor([5.2], dtype=torch.float32, device=device) + + @triton.jit + def mul(x, a): + return x * tl.load(a) + + function = Function(mul, (a, )) + tx = Tensor(x, x.shape, x.stride()) + ty = Tensor(y, y.shape, y.stride()) + _namedtuple_kernel[(1, )](function, tx, ty, 64, 64) + assert torch.allclose(y, x[:16, :16] * a) + + +@pytest.mark.interpreter +def test_eq(device): + + @triton.jit + def fn(ret_ptrs): + tl.store(ret_ptrs + 0, (1, 2) == (1, 2)) + tl.store(ret_ptrs + 1, (1, 2) == (1, 1)) + tl.store(ret_ptrs + 2, tl.tuple((1, 2)) == (1, 2)) + tl.store(ret_ptrs + 3, tl.tuple((1, 2)) == (1, 3)) + + rets = torch.zeros((4, ), dtype=torch.int32, device=device) + fn[(1, )](rets) + assert rets[0].item() == 1 + assert rets[1].item() == 0 + assert rets[2].item() == 1 + assert rets[3].item() == 0 + + +@pytest.mark.interpreter +def test_add(device): + + @triton.jit + def fn(ret_ptrs): + tuple0 = ((0, 1)) + (2, 3) + for i in tl.static_range(4): + tl.store(ret_ptrs + i, tuple0[i]) + tuple1 = tl.tuple((4, 5)) + (6, 7) + for i in tl.static_range(4): + tl.store(ret_ptrs + 4 + i, tuple1[i]) + + rets = torch.zeros((8, ), dtype=torch.int32, device=device) + fn[(1, )](rets) + torch.testing.assert_close(rets.cpu(), torch.arange(8, dtype=torch.int32)) + + +def test_passing_tuple_with_constexpr(device): + + @triton.jit + def m_to_the_n(X, shape: tl.constexpr, strides, m_n): + Xs = X + tl.arange(0, shape[0])[:, None] * strides[0] + tl.arange(0, shape[1])[None, :] * strides[1] + # Include a for loop to ensure strides[1] is lifted into a constexpr + # (otherwise cloning the local scope will fail). + data = tl.load(Xs) + for i in tl.range(0, m_n[1]): + data = m_n[0] * data + tl.store(Xs, data) + + x = torch.arange(0, 64, device=device).reshape(8, 8) + expected_x = 8 * x.clone() + m_to_the_n[(1, )](x, x.shape, x.stride(), (2, 3)) + torch.testing.assert_close(x, expected_x, rtol=0, atol=0) + + +@triton.jit +def _nested_tuple_kernel(x): + # This creates a new scope, which will force a copy of liveins. It's + # important for this to happen as it forces IR flattening/unflattening, + # which relies on the types being correct for the roundtrip to succeed. + for _ in range(1): + tl.static_assert(x[1][0] == 2) + + +def test_passing_nested_tuple_with_constexpr(device): + _nested_tuple_kernel[(1, )](((1, ), (tl.constexpr(2), ))) + + +def test_passing_nested_tuple_with_constexpr_and_jit_hook(device, fresh_knobs): + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + fresh_knobs.runtime.jit_cache_hook = cache_hook + + device = getattr(torch, device).current_device() + + # Clear the existing cache for this device to ensure that the hook is called; + # This is needed because the kernel is shared between multiple tests and may + # already have been compiled for this device. + _nested_tuple_kernel.device_caches[device][0].clear() + + warmup_run = _nested_tuple_kernel.warmup(((1, ), (tl.constexpr(2), )), grid=(1, )) + assert warmup_run is not None + + assert specialization_data is not None + + preload_run = _nested_tuple_kernel.preload(specialization_data) + assert preload_run is not None + + assert warmup_run.hash == preload_run.hash + + +def test_modifying_tuples(): + + @triton.jit + def set_tuple_value_at_idx(): + t = tl.tuple([5, 6, 7]) + t[0] = 0 + + with pytest.raises(triton.CompilationError): + set_tuple_value_at_idx[(1, )]() + + +@pytest.mark.interpreter +def test_tuple_logic(): + + @triton.jit + def tuple_logic_kernel(): + + # arity-2 BoolOps: + tl.static_assert(((3, 4) or (5, 6)) == (3, 4)) + tl.static_assert(((3, 4) and (5, 6)) == (5, 6)) + tl.static_assert(((3, 4) and ()) == ()) + tl.static_assert((() or (5, 6)) == (5, 6)) + + # arity-3 BoolOps: + tl.static_assert(((1, 2) and (3, 4) and (5, 6)) == (5, 6)) + tl.static_assert(((1, 2) or (3, 4) or (5, 6)) == (1, 2)) + + # constexpr short-circuiting over dynamic argument: + tl.static_assert((() and tl.program_id(0)) == ()) + + tuple_logic_kernel[(1, )]() + + +@pytest.mark.interpreter +def test_tuple_float(): + + @triton.jit + def _namedtuple_float_tuple_kernel(): + x, y = float("-inf"), float("inf") # noqa: F841 + + _namedtuple_float_tuple_kernel[(1, )]() + + +@triton.constexpr_function +def passthrough_constexpr(x): + return x + + +class TrivialTuple(NamedTuple): + foo: tl.constexpr + + +@pytest.mark.interpreter +def test_tuple_constexpr_function(): + + @triton.jit + def kernel(): + tl.static_assert(passthrough_constexpr(TrivialTuple(0)).foo == 0) + + kernel[(1, )]() diff --git a/third_party/mthreads/python/test/unit/runtime/test_autotuner.py b/third_party/mthreads/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 0000000000..f413d77e30 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,436 @@ +import torch + +import triton +import triton.language as tl +import pytest + +import pathlib +import uuid +from triton._internal_testing import is_cuda + + +def do_bench(kernel_call, quantiles, use_cuda_graph=False): + if use_cuda_graph: + return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles) + return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1) + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool, device: str): + if use_cuda_graph and not torch.cuda.is_available(): + pytest.xfail("CUDA is not available") + + M, N = 1024, 16 + src = torch.randn(M * N, device=device) + dst = torch.empty(M * N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})] + + @triton.autotune(configs=configs, key=["M"], + do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph)) + @triton.jit + def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M) + offsets_n = tl.arange(0, BLOCK_SIZE_N) + x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :]) + tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), ) + _kernel[grid](dst, src, N, M, N) + # the key word args could be in arbitrary order. + _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N) + assert len(_kernel.cache) == 2 + + +def test_no_do_bench(device: str): + M, N = 1024, 16 + src = torch.randn(M * N, device=device) + dst = torch.empty(M * N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})] + + @triton.autotune(configs=configs, key=["M"]) + @triton.jit + def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M) + offsets_n = tl.arange(0, BLOCK_SIZE_N) + x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :]) + tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), ) + _kernel[grid](dst, src, N, M, N) + assert len(_kernel.cache) == 1 + + +@pytest.mark.parametrize('pass_kwargs_to_kernel', [False, True]) +def test_restore(pass_kwargs_to_kernel, device): + N = 1024 + src = torch.zeros(N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + if pass_kwargs_to_kernel: + _kernel[grid](src=src, N=N) + else: + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool, device: str): + N = 1024 + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, + reason="Requires compute capability >= 9 for NV") +def test_override_ttir(device): + N = 1024 + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) + + ir_src = r""" +module { + tt.func public @_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+01> : tensor<32xf32> + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = arith.mulf %9, %cst : tensor<32xf32> + %11 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10, %6 : tensor<32x!tt.ptr> + tt.return + } +} + """ + temp_file = pathlib.Path(f"/tmp/test_override_{str(uuid.uuid4())}.ttir") + temp_file.write_text(ir_src) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})] + + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + + # Change the behavior of kernel by overriding PTX + torch.testing.assert_close(src * 10, dst) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, + reason="Requires compute capability >= 9 for NV") +def test_override_ttgir(device): + N = 1024 + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) + + ir_src = r""" +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+01> : tensor<32xf32, #blocked> + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<32xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<32xi32, #blocked> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32, #blocked> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr, #blocked>, tensor<32xi32, #blocked> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr, #blocked> + %10 = arith.mulf %9, %cst : tensor<32xf32, #blocked> + %11 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr, #blocked> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr, #blocked>, tensor<32xi32, #blocked> + tt.store %12, %10, %6 : tensor<32x!tt.ptr, #blocked> + tt.return + } +} + """ + temp_file = pathlib.Path(f"/tmp/test_override_{str(uuid.uuid4())}.ttgir") + temp_file.write_text(ir_src) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})] + + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + + # Change the behavior of kernel by overriding PTX + torch.testing.assert_close(src * 10, dst) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, + reason="PTX file in this unit test is only for SM90") +def test_override_ptx(device): + N = 1024 + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) + + ir_src = r""" +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl _kernel // -- Begin function _kernel + // @_kernel +.visible .entry _kernel( + .param .u64 .ptr .global .align 1 _kernel_param_0, + .param .u64 .ptr .global .align 1 _kernel_param_1, + .param .u32 _kernel_param_2, + .param .u64 .ptr .global .align 1 _kernel_param_3 +) +.reqntid 128 +{ + .reg .pred %p<4>; + .reg .b32 %r<10>; + .reg .b32 %f<3>; + .reg .b64 %rd<6>; + .loc 1 180 0 +$L__func_begin0: + .loc 1 180 0 + +// %bb.0: + ld.param.u64 %rd3, [_kernel_param_0]; + ld.param.u64 %rd4, [_kernel_param_1]; +$L__tmp0: + .loc 1 181 28 + mov.u32 %r3, %ctaid.x; + .loc 1 181 33 + shl.b32 %r4, %r3, 5; + ld.param.u32 %r5, [_kernel_param_2]; + .loc 1 181 59 + mov.u32 %r6, %tid.x; + and.b32 %r7, %r6, 31; + .loc 1 181 46 + or.b32 %r8, %r4, %r7; + .loc 1 182 46 + setp.lt.s32 %p1, %r8, %r5; + .loc 1 182 22 + mul.wide.s32 %rd5, %r8, 4; + add.s64 %rd1, %rd4, %rd5; + .loc 1 182 16 + // begin inline asm + mov.u32 %r1, 0x0; + @%p1 ld.global.b32 { %r1 }, [ %rd1 + 0 ]; + // end inline asm + mov.b32 %f1, %r1; + .loc 1 183 12 + mul.f32 %f2, %f1, 0f41200000; + .loc 1 184 19 + add.s64 %rd2, %rd3, %rd5; + .loc 1 184 28 + and.b32 %r9, %r6, 96; + setp.eq.s32 %p3, %r9, 0; + mov.b32 %r2, %f2; + and.pred %p2, %p3, %p1; + // begin inline asm + @%p2 st.global.b32 [ %rd2 + 0 ], { %r2 }; + // end inline asm + .loc 1 184 4 + ret; +$L__tmp1: +$L__func_end0: + // -- End function +} + """ + temp_file = pathlib.Path(f"/tmp/test_override_{str(uuid.uuid4())}.ptx") + temp_file.write_text(ir_src) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32, 'ir_override': str(temp_file)})] + + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + x = x * 10 + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + + # Change the behavior of kernel by overriding PTX + torch.testing.assert_close(src * 10, dst) + + +def test_exceed_tmem(device): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10: + pytest.skip("Test requires tensor memory.") + N = 512 + dst = torch.empty((N, ), device=device, dtype=torch.float32) + configs = [triton.Config(kwargs={'BLOCK_SIZE': 128}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + exception_out_of_resource = None + + def _post_hook(*args, exception): + nonlocal exception_out_of_resource + if exception is not None: + exception_out_of_resource = exception + + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=None, post_hook=_post_hook) + @triton.jit + def dot_kernel(dst, BLOCK_SIZE: tl.constexpr): + a = tl.full((BLOCK_SIZE, BLOCK_SIZE), 0.0, tl.float16) + b = tl.full((BLOCK_SIZE, BLOCK_SIZE), 0.0, tl.float16) + c0 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c1 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c2 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c3 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c4 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + for i in range(0, 100): + c0 = tl.dot(a, b, c0) + c1 = tl.dot(a, b, c1) + c2 = tl.dot(a, b, c2) + c3 = tl.dot(a, b, c3) + c4 = tl.dot(a, b, c4) + c = c4 + c3 + c2 + c1 + c0 + c = c.reshape([BLOCK_SIZE * BLOCK_SIZE]) + tl.store(dst + tl.arange(0, BLOCK_SIZE * BLOCK_SIZE), c) + + dot_kernel[(1, )](dst) + assert exception_out_of_resource is not None and str( + exception_out_of_resource + ) == "out of resource: tensor memory, Required: 640, Hardware limit: 512. Reducing block sizes or `num_stages` may help." + + +def test_exceed_threads(device): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + x = torch.empty(1024, device=device, dtype=torch.float32) + y = torch.empty_like(x) + output = torch.empty_like(x) + + configs = [ + triton.Config({}, num_warps=128), + triton.Config({}, num_warps=4), + ] + + exception_out_of_resource = None + + def _post_hook(*args, exception): + nonlocal exception_out_of_resource + if exception is not None: + exception_out_of_resource = exception + + @triton.autotune(configs=configs, key=['BLOCK_SIZE'], do_bench=do_bench, post_hook=_post_hook) + @triton.jit + def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + def grid(meta): + return (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) + + add_kernel[grid](x, y, output, x.numel(), BLOCK_SIZE=128) + + warp_size = triton.runtime.driver.active.get_current_target().warp_size + assert exception_out_of_resource is not None and f"out of resource: threads, Required: {128 * warp_size}" in str( + exception_out_of_resource) + + +def test_prune_all_configs(device): + N = 1024 + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) + + def early_config_prune(configs, named_args, **kwargs): + return [] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + try: + _kernel[grid](dst, src, N=N) + pytest.fail("Expected exception was not thrown.") + except triton.TritonError as e: + assert e is not None and str( + e + ) == "Autotuner error: No valid autotuner configs after pruning. `early_config_prune` should return at least one config." diff --git a/third_party/mthreads/python/test/unit/runtime/test_bindings.py b/third_party/mthreads/python/test/unit/runtime/test_bindings.py new file mode 100644 index 0000000000..de9c1dc9c7 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,112 @@ +import triton +import triton.language as tl + +import torch +import math + +_BLOCK_SIZE = 16 + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(device): + """ + Test the MLIR bindings exposed for the out-of-tree walk. + """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + if name == "tt.make_range": + assert 0 == op.get_int_attr("start") + assert _BLOCK_SIZE == op.get_int_attr("end") + if name == "arith.constant": + val = op.get_int_attr("value") + assert val is None or isinstance(val, int) + + kernel = add_kernel + args = [ + torch.empty((32, 32), device=device), # in_ptr0 + torch.empty((32, 32), device=device), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device=device), # out_ptr + _BLOCK_SIZE, # BLOCK_SIZE + ] + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={kernel.arg_names[i]: triton.runtime.jit.mangle_type(arg) + for i, arg in enumerate(args)}, + constexprs={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + ) + + context = triton._C.libtriton.ir.context() + options = backend.parse_options(dict()) + codegen_fns = dict() + module_map = backend.get_module_map() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) + ttir_module.walk(walk_fn) + + +def test_python_func_in_visit_call(device): + + @triton.jit + def test_py_call_const_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + log2e: tl.constexpr = math.log2(math.e) + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x * log2e + tl.store(out_ptr + offsets, output, mask=mask) + + x = torch.randn(4, device=device) + out = torch.zeros_like(x) + test_py_call_const_kernel[(4, )](x, out, 4, 4) diff --git a/third_party/mthreads/python/test/unit/runtime/test_blaslt.py b/third_party/mthreads/python/test/unit/runtime/test_blaslt.py new file mode 100644 index 0000000000..850efe7be1 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_blaslt.py @@ -0,0 +1,196 @@ +import pytest +import torch +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4 +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor + + +def supports_block_scaling(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 10 + + +@pytest.mark.parametrize("m, n, k", [(16, 16, 16), (32, 16, 16), (16, 32, 16), (16, 16, 32)]) +@pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) +def test_blaslt(m, n, k, dtype_str, device): + dtype = getattr(torch, dtype_str) + + if is_cuda(): + from triton._C.libtriton import nvidia as vendor + if dtype_str == "float8_e4m3fnuz": + pytest.skip("float8_e4m3fnuz is not supported on CUDA") + if dtype == torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("fp8 is only supported on CUDA with cc >= 90") + c_dtype = dtype + make_handle = lambda workspace: vendor.cublas.CublasLt(workspace) + elif is_hip(): + from triton._C.libtriton import amd as vendor + if dtype_str == "float8_e4m3fnuz" and not is_hip_cdna3(): + pytest.skip("float8_e4m3fnuz is only supported on HIP CDNA3") + if dtype_str == "float8_e4m3fn" and not is_hip_cdna4(): + pytest.skip("float8_e4m3fn is only supported on HIP CDNA4") + c_dtype = torch.float16 if dtype_str in ("float8_e4m3fnuz", "float8_e4m3fn") else dtype + make_handle = lambda workspace: vendor.hipblas.HipblasLt(workspace) + else: + pytest.skip("test_blaslt is only supported on CUDA or HIP") + + torch.manual_seed(123) + workspace_size = 32 * 1024 * 1024 + + def limited_rand(elements, shape): + total_elems = torch.prod(torch.tensor(shape)).item() + indices = torch.randint(0, len(elements), (total_elems, ), device=device) + return elements[indices].view(shape) + + elements = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32, device=device) + a = limited_rand(elements, (m, k)).to(dtype) + b = limited_rand(elements, (k, n)).to(dtype) + + c = torch.zeros((m, n), dtype=c_dtype, device=device) + + b = b.T.contiguous() + + workspace = torch.empty(workspace_size, dtype=torch.int8, device=device) + handle = make_handle(workspace) + + handle.matmul(a, b, c) + + ref = torch.matmul(a.to(torch.float16), b.to(torch.float16).T) + + assert torch.allclose(c.to(torch.float16), ref, atol=2.0) + + +@pytest.mark.parametrize("m, n, k", [(256, 256, 512), (512, 512, 512), (1024, 1024, 1024)]) +def test_block_scaled_matmul_mxfp8(m, n, k, device): + """Test block-scaled matmul with MXFP8 format (FP8 E4M3 inputs, E8M0 scales).""" + if not is_cuda(): + pytest.skip("block_scaled_matmul is only supported on CUDA") + if not supports_block_scaling(): + pytest.skip("block_scaled_matmul requires compute capability 10.0 (Blackwell)") + + from triton._C.libtriton import nvidia + + torch.manual_seed(42) + + # Constants for MXFP8 + VEC_SIZE = 32 # 32-element groups for E8M0 scales + + # Create workspace and cuBLAS handle + workspace_size = 32 * 1024 * 1024 + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + handle = nvidia.cublas.CublasLt(workspace) + + # Generate random FP8 inputs + a_fp32 = torch.randn(m, k, device=device, dtype=torch.float32) + b_fp32 = torch.randn(n, k, device=device, dtype=torch.float32) + + # Convert to FP8 E4M3 + a = a_fp32.to(torch.float8_e4m3fn) + b = b_fp32.to(torch.float8_e4m3fn) + + # Generate scales in the expected 4D layout, then reshape to 5D and flatten + # Scale shape: [M // 128, K // VEC_SIZE // 4, 32, 16] + a_scale_shape = [m // 128, k // VEC_SIZE // 4, 32, 16] + b_scale_shape = [n // 128, k // VEC_SIZE // 4, 32, 16] + + epsilon = 1e-8 + a_scale_raw = torch.rand(a_scale_shape, device=device) + epsilon + b_scale_raw = torch.rand(b_scale_shape, device=device) + epsilon + + # Convert to MXScaleTensor (E8M0 format) + a_scale_mx = MXScaleTensor(a_scale_raw) + b_scale_mx = MXScaleTensor(b_scale_raw) + a_scale = a_scale_mx.data + b_scale = b_scale_mx.data + + # Reshape to 5D for TMA and flatten for cuBLAS + a_scale_5d = a_scale.reshape(1, a_scale_shape[0], a_scale.shape[1], 2, 256) + b_scale_5d = b_scale.reshape(1, b_scale_shape[0], b_scale.shape[1], 2, 256) + a_scale_cublas = a_scale_5d.contiguous().flatten() + b_scale_cublas = b_scale_5d.contiguous().flatten() + + # Prepare output tensor + output = torch.empty((m, n), dtype=torch.float16, device=device) + + # Call cuBLAS block-scaled matmul + handle.block_scaled_matmul_mxfp8(a, b, output, a_scale_cublas, b_scale_cublas) + + # Compute reference using PyTorch + def unpack_scale(packed): + packed = packed.reshape(*packed.shape[:-2], 32, 4, 4) + num_chunk_m, num_chunk_k, _, _, _ = packed.shape + return packed.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous() + + a_scale_ref = a_scale_mx.to(torch.float32) + b_scale_ref = b_scale_mx.to(torch.float32) + a_scale_ref = unpack_scale(a_scale_ref).repeat_interleave(VEC_SIZE, dim=1)[:m, :k] + b_scale_ref = unpack_scale(b_scale_ref).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:k, :n] + + ref = torch.matmul(a.to(torch.float32) * a_scale_ref, b.to(torch.float32).T * b_scale_ref) + + torch.testing.assert_close(output.to(torch.float32), ref, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("m, n, k", [(256, 256, 512), (512, 512, 512), (1024, 1024, 1024)]) +def test_block_scaled_matmul_nvfp4(m, n, k, device): + """Test block-scaled matmul with NVFP4 format (packed FP4 inputs, FP8 E4M3 scales).""" + if not is_cuda(): + pytest.skip("block_scaled_matmul is only supported on CUDA") + if not supports_block_scaling(): + pytest.skip("block_scaled_matmul requires compute capability 10.0 (Blackwell)") + + from triton._C.libtriton import nvidia + + torch.manual_seed(42) + + # Constants for NVFP4 + VEC_SIZE = 16 # 16-element groups for FP8 E4M3 scales + + # Create workspace and cuBLAS handle + workspace_size = 32 * 1024 * 1024 + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + handle = nvidia.cublas.CublasLt(workspace) + + # Generate random MXFP4 tensors + a_ref = MXFP4Tensor(size=(m, k), device=device).random() + b_ref = MXFP4Tensor(size=(n, k), device=device).random() + + # Pack two FP4 elements per byte along K dimension + a = a_ref.to_packed_tensor(dim=1) # (M, K//2) in uint8 + b = b_ref.to_packed_tensor(dim=1) # (N, K//2) in uint8 + + # Generate scales in the expected 4D layout + # Scale shape: [M // 128, K // VEC_SIZE // 4, 32, 16] + a_scale_shape = [m // 128, k // VEC_SIZE // 4, 32, 16] + b_scale_shape = [n // 128, k // VEC_SIZE // 4, 32, 16] + + epsilon = 1e-8 + a_scale_raw = torch.rand(a_scale_shape, device=device) + epsilon + b_scale_raw = torch.rand(b_scale_shape, device=device) + epsilon + + # For NVFP4, scales are FP8 E4M3 + a_scale = a_scale_raw.to(torch.float8_e4m3fn) + b_scale = b_scale_raw.to(torch.float8_e4m3fn) + + # Flatten for cuBLAS (use original 4D layout, not 5D reshaped) + a_scale_cublas = a_scale.contiguous().flatten() + b_scale_cublas = b_scale.contiguous().flatten() + + # Prepare output tensor + output = torch.empty((m, n), dtype=torch.float16, device=device) + + # Call cuBLAS block-scaled matmul + handle.block_scaled_matmul_nvfp4(a, b, output, a_scale_cublas, b_scale_cublas) + + # Compute reference using PyTorch + def unpack_scale(packed): + packed = packed.reshape(*packed.shape[:-2], 32, 4, 4) + num_chunk_m, num_chunk_k, _, _, _ = packed.shape + return packed.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous() + + a_scale_ref = a_scale.to(torch.float32) + b_scale_ref = b_scale.to(torch.float32) + a_scale_ref = unpack_scale(a_scale_ref).repeat_interleave(VEC_SIZE, dim=1)[:m, :k] + b_scale_ref = unpack_scale(b_scale_ref).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:k, :n] + + ref = torch.matmul(a_ref.to(torch.float32) * a_scale_ref, b_ref.to(torch.float32).T * b_scale_ref) + + torch.testing.assert_close(output.to(torch.float32), ref, atol=1e-1, rtol=1e-1) diff --git a/third_party/mthreads/python/test/unit/runtime/test_build.py b/third_party/mthreads/python/test/unit/runtime/test_build.py new file mode 100644 index 0000000000..6e3b7cbb69 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_build.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest +import tempfile + +from pathlib import Path + +import triton + +from triton.runtime.build import compile_module_from_src + +TEST_MODULE_C = """ +#include +#include + +static PyObject* go(PyObject* self, PyObject* args) { + const char *command; + if (!PyArg_ParseTuple(args, "s", &command)) + return NULL; + + const char* res; + if (strcmp(command, "hello") == 0) { + res = "hiya"; + } else { + res = "huh"; + } + return PyUnicode_FromString(res); +} + +static PyMethodDef ModuleMethods[] = { + {"go", go, METH_VARARGS, "test_module.go for testing"}, + {NULL, NULL, 0, NULL} +}; + +static struct PyModuleDef ModuleDef = { + PyModuleDef_HEAD_INIT, + "test_module", + NULL, //documentation + -1, //size + ModuleMethods +}; + +PyMODINIT_FUNC PyInit_test_module(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} +""" + + +def test_compile_module(fresh_triton_cache): + mod = compile_module_from_src(TEST_MODULE_C, "test_module") + + with pytest.raises(Exception): + mod.go() + + assert mod.go("huh") == "huh" + assert mod.go("hello") == "hiya" + + # Make sure the module is cached + mod2 = compile_module_from_src(TEST_MODULE_C, "test_module") + assert mod2.__file__ == mod.__file__ + + +def test_compile_module_bad_cache(fresh_knobs): + with tempfile.TemporaryDirectory() as tmpd: + tmp = Path(tmpd) + called_get_file = False + + class InvalidFileCacheManager(triton.runtime.cache.FileCacheManager): + + def get_file(self, filename: str) -> str | None: + nonlocal called_get_file + called_get_file = True + (tmp / filename).write_text("not an so") + return str(tmp / filename) + + # First corrupt the cache + fresh_knobs.cache.manager_class = InvalidFileCacheManager + + mod = compile_module_from_src(TEST_MODULE_C, "test_module") + assert called_get_file + + with pytest.raises(Exception): + mod.go() + + assert mod.go("huh") == "huh" + assert mod.go("hello") == "hiya" diff --git a/third_party/mthreads/python/test/unit/runtime/test_cache.py b/third_party/mthreads/python/test/unit/runtime/test_cache.py new file mode 100644 index 0000000000..a5c66382a4 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_cache.py @@ -0,0 +1,893 @@ +import expecttest +import importlib.util +import itertools +import os +import re +import shutil +import pathlib +from concurrent.futures import Executor, Future, ThreadPoolExecutor + +import pytest +import torch + +import triton +import triton.language as tl +from triton._internal_testing import is_hip + + +@triton.jit +def function_0(i): + return i + 1 + + +@triton.jit +def function_1(i): + i = i + 1 + cond: tl.constexpr = True + if cond: + FN: tl.constexpr = function_2 + else: + FN: tl.constexpr = function_0 + return FN(i) + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize_on_alignment=["i"]) +def kernel_nospec_on_alignment(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new, to_modify): + kernel.hash = None + function_0.hash = None + function_1.hash = None + function_2.hash = None + to_modify._unsafe_update_src(to_modify.src.replace(old, new)) + ret = target.cache_key + to_modify._unsafe_update_src(to_modify.src.replace(new, old)) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1', function_1) + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_1) + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_2) + assert baseline != updated + + +def test_nested2_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_0) + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn._unsafe_update_src(orig_combine_fn_src.replace("COMBINE_OP", combine_op)) + kernel_with_combine_fn._unsafe_update_src(orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan)) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn._unsafe_update_src(orig_combine_fn_src) + kernel_with_combine_fn._unsafe_update_src(orig_kernel_src) + + assert key not in seen_keys + seen_keys.add(key) + + +@triton.constexpr_function +def constexpr_flag_fn(): + return False + + +@triton.jit +def constexpr_fn_user(out): + a: tl.constexpr = constexpr_flag_fn() + tl.store(out, a) + + +def test_constexpr_fn_change(): + baseline = constexpr_fn_user.cache_key + + orig_src = constexpr_flag_fn.src + new_src = orig_src.replace("False", "True") + constexpr_flag_fn._unsafe_update_src(new_src) + constexpr_fn_user.hash = None + updated = constexpr_fn_user.cache_key + assert baseline != updated + + constexpr_flag_fn._unsafe_update_src(orig_src) + constexpr_fn_user.hash = None + assert constexpr_fn_user.cache_key == baseline + + +@triton.constexpr_function +def invalid_constexpr_fn(): + return torch.cuda.get_device_capability() + + +def test_invalid_constexpr_fn(): + with pytest.raises(RuntimeError): + invalid_constexpr_fn.cache_key + + +def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines): + temp_file.write_text(('# extra line\n' * num_extra_lines) + code) + spec = importlib.util.spec_from_file_location("module.name", str(temp_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py" + orig_mod = write_and_load_module(temp_file0, code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py" + updated_mod = write_and_load_module(temp_file1, code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def test_reuse(device, fresh_triton_cache): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + triton.knobs.runtime.jit_cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device=device) + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable', 'disable_on_alignment']) +def test_specialize(mode, device, fresh_triton_cache): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + triton.knobs.runtime.jit_cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device=device) + function = {'enable': kernel, 'disable': kernel_nospec, 'disable_on_alignment': kernel_nospec_on_alignment}[mode] + target = {'enable': 3, 'disable': 1, 'disable_on_alignment': 2}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(device): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + + device = getattr(torch, device).current_device() + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.device_caches[device][0]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(device): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. + GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = getattr(torch, device).current_device() + assert len(kernel.device_caches[device][0]) == 1 + + +GLOBAL_VAR = tl.constexpr(1) + + +def test_kernel_global_var_change(device): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = tl.constexpr(42) + kernel[(1, )]() + CONSTEXPR_GLOBAL = tl.constexpr(43) + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL = tl.constexpr(0) + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +def test_cache_closure(): + + def make_closure(cst): + + @triton.jit + def closure(): + tl.full((16, ), cst, dtype=tl.int32) + + return closure + + cst = tl.constexpr(42) + closure = make_closure(cst) + + closure[(1, )]() + cst.value = 43 + with pytest.raises(RuntimeError) as e: + closure[(1, )]() + + assert "cst has changed since we compiled this kernel, from constexpr[42] to constexpr[43]" in str(e.value) + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. + assert not kernel.used_global_vals + + +def test_constexpr_cache_invalidation_recreated(device): + + def test_run(val): + VAL = tl.constexpr(val) + + @triton.jit + def kernel(out): + tl.store(out, VAL) + + out = torch.zeros(1, device=device) + kernel[(1, )](out) + return out.item() + + assert test_run(123) == 123 + assert test_run(123) == 123 + assert test_run(1234) == 1234 + assert test_run(1234) == 1234 + + +def test_jit_warmup_cache(device) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + 32, + ] + device = getattr(torch, device).current_device() + assert len(kernel_add.device_caches[device][0]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + + +def test_jit_debug(device) -> None: + + @triton.jit + def kernel(tmp): + tl.device_assert(tl.load(tmp) == 1, "tmp == 1") + + device = getattr(torch, device).current_device() + tmp = torch.tensor([1], dtype=torch.int32, device=device) + assert len(kernel.device_caches[device][0]) == 0 + kernel[(1, )](tmp, debug=False) + assert len(kernel.device_caches[device][0]) == 1 + kernel[(1, )](tmp, debug=True) + assert len(kernel.device_caches[device][0]) == 2 + bins = list(kernel.device_caches[device][0].values()) + assert bins[0].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +def test_jit_noinline(device) -> None: + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = getattr(torch, device).current_device() + assert len(kernel_add_device.device_caches[device][0]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.device_caches[device][0]) == 1 + bins = list(kernel_add_device.device_caches[device][0].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.device_caches[device][0].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.device_caches[device][0]) == 1 + bins = list(kernel_add_device.device_caches[device][0].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_preload(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = getattr(torch, device).current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + triton.knobs.runtime.jit_cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + shutil.rmtree(fresh_triton_cache) + kernel_add.device_caches[device][0].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.device_caches[device][0]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + triton.knobs.runtime.jit_cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + assert counter == 0 + assert len(kernel_add.device_caches[device][0]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) + + specialization_data_unknown_target = re.sub(r'("target"\s*:\s*\{[^{}]*"backend"\s*:\s*)"(.*?)"', + r'\1"unknown_target"', specialization_data, count=1) + + with pytest.raises(RuntimeError, match="Specialization data is for {'backend': 'unknown_target'"): + kernel_add.preload(specialization_data_unknown_target) + + +def test_hooks(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + # get the serialized specialization data + specialization_data = None + is_warmup = False + key = 0 + name = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + nonlocal is_warmup + is_warmup = kwargs["compile"]["is_warmup"] + nonlocal key + key = kwargs["compile"]["key"] + nonlocal name + name = kwargs["fn"].name + + specialization_data_compiled = None + + def compiled_hook(*args, **kwargs): + nonlocal specialization_data_compiled + specialization_data_compiled = kwargs["compile"]["specialization_data"] + + triton.knobs.runtime.jit_cache_hook = cache_hook + triton.knobs.runtime.jit_post_compile_hook = compiled_hook + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + assert specialization_data is not None and specialization_data_compiled == specialization_data + assert is_warmup is True + assert key in kernel_add.device_caches[getattr(torch, device).current_device()][0] + assert name == "test_hooks..kernel_add" + + +@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) +def test_within_2gb(device, fresh_triton_cache) -> None: + default_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") + try: + use_buffer_ops_opts = ["1", "0"] + # The ranges should only be available when buffer ops are enabled + pointer_ranges = [[(0, )], []] + for use_buffer_ops, pointer_range in zip(use_buffer_ops_opts, pointer_ranges): + # Set AMDGCN_USE_BUFFER_OPS + os.environ["AMDGCN_USE_BUFFER_OPS"] = use_buffer_ops + + @triton.jit + def kernel_add(a): + tl.load(a) + + # This is the attribute we want to test + pointer_range_32 = None + + def cache_hook(*args, **kwargs): + nonlocal pointer_range_32 + pointer_range_32 = [ + k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v + ] + + triton.knobs.runtime.jit_cache_hook = cache_hook + # In warmup we assume that the pointer range is 32 bits + kernel_add.warmup(torch.float32, grid=(1, )) + assert pointer_range_32 == pointer_range + # Torch tensor > 2GB + kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) + assert len(pointer_range_32) == 0 + # Torch tensor <= 2GB + kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) + assert pointer_range_32 == pointer_range + finally: + os.environ["AMDGCN_USE_BUFFER_OPS"] = default_buffer_ops + + +def test_function_arguments(device): + + @triton.jit + def func1(): + return 1 + + @triton.jit + def func2(): + return 2 + + @triton.jit + def func3(x): + return x + + @triton.jit + def func4(x, y): + return x + y + + @triton.jit + def kernel(Y, fn: tl.constexpr, fn_args): + tl.store(Y, fn(*fn_args)) + + y = torch.zeros((5, ), dtype=torch.int32, device=device) + kernel[(1, )](y[0], func1, tuple()) + kernel[(1, )](y[1], func2, tuple()) + kernel[(1, )](y[2], func3, (3, )) + kernel[(1, )](y[3], func4, (3, 4)) + kernel[(1, )](y[4], func1, tuple()) + assert len(kernel.device_caches[0][0]) == 4 + assert y.tolist() == [1, 2, 3, 7, 1] + + +class MockThreadPool(Executor): + + def __init__(self): + self.work_queue = [] + + def submit(self, fn, *args, **kwargs): + future = Future() + + def task(): + if not future.set_running_or_notify_cancel(): + return + + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + + self.work_queue.append(task) + return future + + def run_one(self): + task = self.work_queue.pop(0) + task() + + def run_all(self): + while self.work_queue: + self.run_one() + + def shutdown(self, wait=True, *, cancel_futures=False): + self.run_all() + + +def test_async_compile_mock(device, fresh_triton_cache): + + @triton.jit + def kernel(Y, a: tl.constexpr): + tl.store(Y, a) + + with ( + MockThreadPool() as pool, + triton.AsyncCompileMode(pool), + ): + a = torch.empty((16, 16), device=device) + b = torch.empty((16, 16), dtype=torch.int32, device=device) + kernel.warmup(a, 0, grid=(1, )) + kernel.warmup(a, 1, grid=(1, )) + kernel.warmup(b, 0, grid=(1, )) + kernel.warmup(b, 1, grid=(1, )) + + # Nothing has actually compiled yet + assert len(kernel.device_caches[0][0]) == 4 + assert len(pool.work_queue) == 4 + + # Duplicates are only submitted once + kernel.warmup(a, 0, grid=(1, )) + kernel.warmup(a, 1, grid=(1, )) + assert len(kernel.device_caches[0][0]) == 4 + assert len(pool.work_queue) == 4 + + pool.run_one() + kernel[(1, )](a, 0) + assert len(kernel.device_caches[0][0]) == 4 + assert a[0, 0] == 0.0 + + pool.run_all() + + +def test_async_compile(device, fresh_triton_cache): + + @triton.jit + def kernel(Y, a: tl.constexpr): + tl.store(Y, a) + + with ( + ThreadPoolExecutor(2) as pool, + triton.AsyncCompileMode(pool), + ): + a = torch.empty((16, 16), device=device) + b = torch.empty((16, 16), dtype=torch.int32, device=device) + kernel.warmup(a, 0, grid=(1, )) + kernel.warmup(a, 1, grid=(1, )) + kernel.warmup(b, 0, grid=(1, )) + kernel.warmup(b, 1, grid=(1, )) + + assert len(kernel.device_caches[0][0]) == 4 + + kernel[(1, )](b, 1) + assert b[0, 0] == 1 + kernel[(1, )](b, 0) + assert b[0, 0] == 0 + kernel[(1, )](a, 0) + assert a[0, 0] == 0 + kernel[(1, )](a, 1) + assert a[0, 0] == 1 + kernel[(1, )](a, 2) + assert a[0, 0] == 2 + + +def test_higher_order_kernel(device, fresh_triton_cache, capsys): + + @triton.jit + def fn_a(): + tl.static_print("Compiling with fn_a") + return 0 + + @triton.jit + def kernel(out_ptr, FUNC: tl.constexpr) -> None: + val = FUNC() + tl.store(out_ptr, val) + + output = torch.empty((), device=device, dtype=torch.int32) + kernel[(1, )](output, fn_a) + assert output.item() == 0 + + # Test we can update src in-place + orig_src = fn_a.src + new_src = orig_src.replace("with fn_a", "with fn_a after modification") + new_src = new_src.replace("0", "1") + fn_a._unsafe_update_src(new_src) + kernel[(1, )](output, fn_a) + assert output.item() == 1 + + # Test that the on disc cache works + kernel.device_caches.clear() + kernel[(1, )](output, fn_a) + assert output.item() == 1 + + fn_a._unsafe_update_src(orig_src) + kernel[(1, )](output, fn_a) + assert output.item() == 0 + + expecttest.assert_expected_inline(capsys.readouterr().out, """\ +Compiling with fn_a +Compiling with fn_a after modification +""") + + +def test_preload_higher_order_kernels(device, fresh_triton_cache) -> None: + + @triton.jit + def fn_a(): + return 17 + + @triton.jit + def fn_b(): + return 31 + + @triton.jit + def kernel(out_ptr, FUNC: tl.constexpr) -> None: + val = FUNC() + tl.store(out_ptr, val) + + device = getattr(torch, device).current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + triton.knobs.runtime.jit_cache_hook = cache_hook + output = torch.empty((), device=device, dtype=torch.int32) + compiled_kernel = kernel[(1, )](output, fn_a) + assert output.item() == 17 + hash = compiled_kernel.hash + assert specialization_data is not None + + # clear the cache + shutil.rmtree(fresh_triton_cache) + kernel.device_caches[device][0].clear() + + # preload the kernel + kernel_preload = kernel.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel.device_caches[device][0]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + triton.knobs.runtime.jit_cache_hook = inc_counter + final_kernel = kernel[(1, )](output, fn_a) + assert counter == 0 + assert len(kernel.device_caches[device][0]) == 1 + assert final_kernel.hash == hash + + # different function should compile and not hit the cache + kernel[(1, )](output, fn_b) + assert counter == 1 + assert output.item() == 31 diff --git a/third_party/mthreads/python/test/unit/runtime/test_compilation_listener.py b/third_party/mthreads/python/test/unit/runtime/test_compilation_listener.py new file mode 100644 index 0000000000..18091a7ccd --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_compilation_listener.py @@ -0,0 +1,66 @@ +import triton +import triton.language as tl + +from triton.backends.compiler import GPUTarget +from triton.knobs import CompileTimes +from triton.compiler.compiler import ASTSource, IRSource + +from typing import Any, Union + +import torch + + +@triton.jit +def cumsum_kernel(ptr): + block = ptr + tl.arange(0, 4) + x = tl.load(block) + tl.store(block, tl.cumsum(x, 0)) + + +def test_compile_stats(device: str, fresh_knobs: Any, fresh_triton_cache: str) -> None: + captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], dict[str, Any], CompileTimes, bool], None] = None + + def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str], metadata_group: dict[str, Any], + times: CompileTimes, cache_hit: bool) -> None: + nonlocal captured + assert captured is None + captured = (src, metadata, metadata_group, times, cache_hit) + + fresh_knobs.compilation.listener = compile_listener + + x = torch.randn(4, device=device) + cumsum_kernel[(1, )](x) + + assert captured is not None + + # No cache hit at first + assert not captured[4] + + # Expected metadata + assert len(captured[1]["hash"]) > 0 + assert isinstance(captured[1]["target"], GPUTarget) + + # It in fact did take some time to do compilation + assert captured[3].ir_initialization > 0 + assert captured[3].total_lowering > 0 + assert captured[3].store_results > 0 + assert captured[3].total > 0 + + # Now lets create a new instance of the same kernel to pick up cache_hit=True + cumsum_kernel.device_caches.clear() + captured = None + cumsum_kernel[(1, )](x) + + assert captured is not None + # Cache hit! + assert captured[4] + + # Expected metadata + assert len(captured[1]["hash"]) > 0 + assert isinstance(captured[1]["target"], GPUTarget) + + # It in fact did take some time to do compilation + assert captured[3].ir_initialization > 0 + assert captured[3].total_lowering == 0 + assert captured[3].store_results == 0 + assert captured[3].total > 0 diff --git a/third_party/mthreads/python/test/unit/runtime/test_driver.py b/third_party/mthreads/python/test/unit/runtime/test_driver.py new file mode 100644 index 0000000000..4de8b017a9 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_driver.py @@ -0,0 +1,150 @@ +import sys +from concurrent.futures import ThreadPoolExecutor +from types import SimpleNamespace +import torch +import pytest + +import triton +import triton.language as tl + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + assert triton.runtime.driver._active is None + assert triton.runtime.driver._default is None + assert isinstance(triton.runtime.driver.active, getattr(triton.backends.driver, "DriverBase")) + assert isinstance(triton.runtime.driver.default, getattr(triton.backends.driver, "DriverBase")) + utils = triton.runtime.driver.active.utils # noqa: F841 + + +def test_kernel_in_thread(device): + # Test calling in a new thread sets a valid device context + buf = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device) + + @triton.jit + def _kernel(P, BLOCK: tl.constexpr): + pid = tl.program_id(0).to(tl.int64) + offset = pid * BLOCK + tl.arange(0, BLOCK) + + p = tl.load(P + offset) + tl.store(P + offset, p) + + def call_triton(): + N = buf.numel() + grid = lambda meta: (triton.cdiv(N, meta["BLOCK"]), ) + _kernel[grid](buf, BLOCK=1024) + getattr(torch, device).synchronize() + + call_triton() + with ThreadPoolExecutor(1) as pool: + future = pool.submit(call_triton) + future.result() + + +def test_default_backend_env_selects_driver(monkeypatch): + from importlib import import_module + driver_mod = import_module("triton.runtime.driver") + + class FakeNvidiaDriver: + + @staticmethod + def is_active(): + return True + + class FakeMusaDriver: + + @staticmethod + def is_active(): + return True + + monkeypatch.setenv("TRITON_DEFAULT_BACKEND", "mthreads") + monkeypatch.setattr( + driver_mod, + "backends", + { + "nvidia": SimpleNamespace(driver=FakeNvidiaDriver), + "mthreads": SimpleNamespace(driver=FakeMusaDriver), + }, + ) + + selected = driver_mod._create_driver() + assert isinstance(selected, FakeMusaDriver) + + +def test_do_bench_device_type_selects_requested_driver(monkeypatch): + import triton.testing as testing + + counters = {"active_clear": 0, "musa_clear": 0, "fn": 0} + + class FakeEvent: + + def __init__(self, enable_timing=True): + self.enable_timing = enable_timing + + def record(self): + return None + + def elapsed_time(self, other): + return 1.0 + + class FakeDeviceInterface: + Event = FakeEvent + + @staticmethod + def synchronize(): + return None + + class FakeActiveDriver: + + def get_device_interface(self): + return FakeDeviceInterface() + + def get_empty_cache_for_benchmark(self): + return object() + + def clear_cache(self, cache): + counters["active_clear"] += 1 + + class FakeMusaDriver: + + @staticmethod + def is_active(): + return True + + def get_device_interface(self): + return FakeDeviceInterface() + + def get_empty_cache_for_benchmark(self): + return object() + + def clear_cache(self, cache): + counters["musa_clear"] += 1 + + monkeypatch.setattr(testing.runtime.driver, "_active", FakeActiveDriver(), raising=False) + monkeypatch.setattr( + testing, + "_available_backends", + {"mthreads": SimpleNamespace(driver=FakeMusaDriver)}, + ) + testing._get_backend_driver.cache_clear() + + def fn(): + counters["fn"] += 1 + + testing.do_bench(fn, warmup=1, rep=1, device_type="musa") + + assert counters["fn"] > 0 + assert counters["active_clear"] == 0 + assert counters["musa_clear"] > 0 + + +def test_do_bench_device_type_unknown_backend(monkeypatch): + import triton.testing as testing + + monkeypatch.setattr(testing, "_available_backends", {}) + testing._get_backend_driver.cache_clear() + + with pytest.raises(RuntimeError, match="Unsupported device_type/backend"): + testing.do_bench(lambda: None, warmup=1, rep=1, device_type="unknown") diff --git a/third_party/mthreads/python/test/unit/runtime/test_launch.py b/third_party/mthreads/python/test/unit/runtime/test_launch.py new file mode 100644 index 0000000000..e0f55b65ff --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_launch.py @@ -0,0 +1,234 @@ +import gc +import tracemalloc +import pytest +import pathlib +import os + +import torch +import triton +import triton.language as tl +from triton._internal_testing import is_cuda, is_hip, is_musa + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.knobs.runtime.launch_enter_hook.add(hook) + kernel[(1, 3, 2)](6) + triton.knobs.runtime.launch_enter_hook.remove(hook) + assert used_hook + + +def test_memory_leak(device) -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device=device) + out = torch.randn(10, device=device) + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +def test_load_hook() -> None: + + used_start_hook = False + start_hash = None + + def hook_start(module, function, name, metadata_group, hash): + nonlocal used_start_hook + nonlocal start_hash + start_hash = hash + used_start_hook = True + + used_end_hook = False + end_hash = None + + def hook_end(module, function, name, metadata_group, hash): + nonlocal used_end_hook + nonlocal end_hash + end_hash = hash + used_end_hook = True + + @triton.jit + def kernel(x): + pass + + # launch kernel + triton.knobs.runtime.kernel_load_start_hook.add(hook_start) + triton.knobs.runtime.kernel_load_end_hook.add(hook_end) + kernel[(1, 3, 2)](6) + assert used_start_hook + assert used_end_hook + assert start_hash == end_hash + triton.knobs.runtime.kernel_load_start_hook.remove(hook_start) + triton.knobs.runtime.kernel_load_end_hook.remove(hook_end) + + +def test_multiple_hooks() -> None: + + start0 = False + end0 = False + start1 = False + end1 = False + + def hook_start0(module, function, name, metadata_group, hash): + nonlocal start0 + start0 = True + + def hook_end0(module, function, name, metadata_group, hash): + nonlocal end0 + end0 = True + + def hook_start1(module, function, name, metadata_group, hash): + nonlocal start1 + start1 = True + + def hook_end1(module, function, name, metadata_group, hash): + nonlocal end1 + end1 = True + + triton.knobs.runtime.kernel_load_start_hook.add(hook_start0) + triton.knobs.runtime.kernel_load_end_hook.add(hook_end0) + triton.knobs.runtime.kernel_load_start_hook.add(hook_start1) + triton.knobs.runtime.kernel_load_end_hook.add(hook_end1) + + @triton.jit + def kernel(x): + pass + + kernel[(1, )](6) + + assert start0 + assert end0 + assert start1 + assert end1 + + triton.knobs.runtime.kernel_load_start_hook.remove(hook_start0) + triton.knobs.runtime.kernel_load_end_hook.remove(hook_end0) + triton.knobs.runtime.kernel_load_start_hook.remove(hook_start1) + triton.knobs.runtime.kernel_load_end_hook.remove(hook_end1) + + +@pytest.mark.parametrize("options", [ + {"num_warps": 1}, + {"enable_fp_fusion": False}, + {"extern_libs": {}}, +]) +def test_launch_with_options(options) -> None: + if "extern_libs" in options: + # copied from tutorials/07-extern-functions.py + current_dir = pathlib.Path(os.path.dirname(os.path.abspath(__file__))) + if is_cuda(): + libdir = current_dir.parent.parent.parent.parent / 'third_party/nvidia/backend/lib' + options["extern_libs"] = {"libdevice": str(libdir / 'libdevice.10.bc')} + elif is_hip(): + libdir = current_dir.parent.parent.parent.parent / 'third_party/amd/backend/lib' + options["extern_libs"] = {"ocml": str(libdir / 'ocml.bc'), "ockl": str(libdir / 'ockl.bc')} + elif is_musa(): + libdir = current_dir.parent.parent.parent.parent / 'backend/lib' + options["extern_libs"] = {"libdevice": triton.knobs.musa.libdevice_path or str(libdir / 'libdevice.31.bc')} + + compile_info = {} + counter = 0 + + def compile_info_hook(key, repr, fn, compile, is_manual_warmup, already_compiled): + nonlocal compile_info + compile_info = compile + + def cache_hook(*args, **kwargs): + nonlocal counter + counter += 1 + + @triton.jit + def kernel(x): + pass + + triton.knobs.runtime.jit_post_compile_hook = compile_info_hook + triton.knobs.runtime.jit_cache_hook = cache_hook + + # run first without options + kernel[(1, 1, 1)](6) + assert counter == 1 + + # run with options, should lead to new compilation + kernel[(1, 1, 1)](6, **options) + assert counter == 2 + + # run a second time for testing kernel-cache look-up + kernel[(1, 1, 1)](6, **options) + assert counter == 2 + + # check the options are passed on to compile_info correctly + option_key, option_val = next(iter(options.items())) + if option_key == "extern_libs": + # HIPOptions overwrite the extern_libs option, so we skip the test + # passing and specializing options still is tested + if not is_hip(): + assert compile_info[option_key] == tuple(option_val.items()) + else: + assert compile_info[option_key] == option_val + + triton.knobs.runtime.jit_post_compile_hook = None + triton.knobs.runtime.jit_cache_hook = None + + +@pytest.mark.interpreter +def test_pre_run_hooks(device): + + @triton.jit + def add_kernel(a_ptr, n_elements: tl.constexpr): + offsets = tl.arange(0, n_elements) + a = tl.load(a_ptr + offsets) + a += 2 + tl.store(a_ptr + offsets, a) + + def my_hook(*args, **kwargs): + args[0].zero_() + + add_kernel.add_pre_run_hook(my_hook) + + n_elements = 4 + a = torch.ones(n_elements, device=device, dtype=torch.int32) + add_kernel[(1, )](a, n_elements) + assert torch.all(a == 2) + + a = torch.ones(n_elements, device=device, dtype=torch.int32) + add_kernel.run(a, n_elements, grid=(1, ), warmup=False) + assert torch.all(a == 2) diff --git a/third_party/mthreads/python/test/unit/runtime/test_out_of_resources.py b/third_party/mthreads/python/test/unit/runtime/test_out_of_resources.py new file mode 100644 index 0000000000..723785b174 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_out_of_resources.py @@ -0,0 +1,96 @@ +import types + +import pytest +import triton +from triton.compiler import compiler as triton_compiler +from triton.runtime.autotuner import Autotuner + + +class _DummyKernel: + + def __init__(self, error): + self.fn = lambda: None + self._error = error + + def run(self, *args, **kwargs): + raise self._error + + +def _make_compiled_kernel(shared, num_warps=1): + kernel = triton_compiler.CompiledKernel.__new__(triton_compiler.CompiledKernel) + kernel.module = None + kernel.function = None + kernel._run = None + kernel.src = object() + kernel.metadata = types.SimpleNamespace(shared=shared, num_warps=num_warps, tmem_size=None) + kernel.metadata_group = {} + kernel.hash = "dummy_hash" + kernel.name = "dummy_kernel" + kernel.kernel = b"\x00" + return kernel + + +def test_compiled_kernel_raises_out_of_resources(monkeypatch): + + def load_binary(*args, **kwargs): + raise AssertionError("load_binary should not be called") + + active_driver = triton_compiler.driver.active + monkeypatch.setattr(active_driver, "get_current_device", lambda: 0, raising=False) + monkeypatch.setattr(active_driver, "get_current_target", lambda: types.SimpleNamespace(warp_size=32), raising=False) + monkeypatch.setattr(active_driver, "launcher_cls", lambda src, metadata: lambda *args, **kwargs: None, + raising=False) + monkeypatch.setattr(active_driver.utils, "get_device_properties", lambda device: {"max_shared_mem": 0}, + raising=False) + monkeypatch.setattr(active_driver.utils, "load_binary", load_binary, raising=False) + triton_compiler.max_shared_mem.cache_clear() + + kernel = _make_compiled_kernel(shared=1) + with pytest.raises(triton.OutOfResources): + kernel._init_handles() + + +def test_compiled_kernel_loads_within_shared_limit(monkeypatch): + calls = {} + + def load_binary(name, data, shared, device): + calls["args"] = (name, data, shared, device) + return ("mod", "func", 0, 0, 1024) + + active_driver = triton_compiler.driver.active + monkeypatch.setattr(active_driver, "get_current_device", lambda: 0, raising=False) + monkeypatch.setattr(active_driver, "get_current_target", lambda: types.SimpleNamespace(warp_size=32), raising=False) + monkeypatch.setattr(active_driver, "launcher_cls", lambda src, metadata: lambda *args, **kwargs: None, + raising=False) + monkeypatch.setattr(active_driver.utils, "get_device_properties", lambda device: {"max_shared_mem": 1024}, + raising=False) + monkeypatch.setattr(active_driver.utils, "load_binary", load_binary, raising=False) + triton_compiler.max_shared_mem.cache_clear() + + kernel = _make_compiled_kernel(shared=1, num_warps=1) + kernel._init_handles() + assert calls["args"] == ("dummy_kernel", b"\x00", 1, 0) + assert kernel.module == "mod" + assert kernel.function == "func" + + +def test_autotuner_drops_out_of_resources(): + err = triton.OutOfResources(128, 64, "shared memory") + fn = _DummyKernel(err) + + def fake_do_bench(kernel_call, quantiles): + kernel_call() + return [0.0, 0.0, 0.0] + + autotuner = Autotuner( + fn=fn, + arg_names=[], + configs=[triton.Config(kwargs={})], + key=[], + reset_to_zero=None, + restore_value=None, + do_bench=fake_do_bench, + ) + autotuner.nargs = {} + result = autotuner._bench(config=autotuner.configs[0]) + assert result == [float("inf"), float("inf"), float("inf")] diff --git a/third_party/mthreads/python/triton/_C/libtriton/linear_layout.pyi b/third_party/mthreads/python/triton/_C/libtriton/linear_layout.pyi new file mode 100644 index 0000000000..e1b4599dd0 --- /dev/null +++ b/third_party/mthreads/python/triton/_C/libtriton/linear_layout.pyi @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple + + +class LinearLayout: + def __init__(self) -> None: ... + + @staticmethod + def identity_1d(size: int, inDim: str, outDim: str) -> LinearLayout: ... + + @staticmethod + def strided_1d( + size: int, stride: int, inDim: str, outDim: str + ) -> LinearLayout: ... + + @staticmethod + def zeros_1d( + size: int, inDim: str, outDim: str, outDimSize: int + ) -> LinearLayout: ... + + @staticmethod + def from_bases( + bases: Sequence[Tuple[str, Sequence[Sequence[int]]]], + out_dim_names: Sequence[str], + out_dim_sizes: Optional[Sequence[int]] = ..., + require_surjective: bool = ..., + ) -> LinearLayout: ... + + def compose(self, other: LinearLayout) -> LinearLayout: ... + + def invert_and_compose(self, other: LinearLayout) -> LinearLayout: ... + + def invert(self) -> LinearLayout: ... + + def pseudoinvert(self) -> LinearLayout: ... + + def is_surjective(self) -> bool: ... + + def is_injective(self) -> bool: ... + + def is_invertible(self) -> bool: ... + + def get_in_dim_names(self) -> List[str]: ... + + def get_out_dim_names(self) -> List[str]: ... + + @property + def bases(self) -> List[Tuple[str, List[List[int]]]]: ... + + @property + def out_dims(self) -> List[Tuple[str, int]]: ... + + @property + def num_in_dims(self) -> int: ... + + @property + def num_out_dims(self) -> int: ... + + def __mul__(self, other: LinearLayout) -> LinearLayout: ... + + def __imul__(self, other: LinearLayout) -> LinearLayout: ... + + def get_shared_view(self, useHWPointOfView: bool) -> str: ... + + def get_distributed_view(self, useHWPointOfView: bool) -> str: ... + + def get_matrix_view(self) -> List[List[int]]: ... + + def apply( + self, inputs: Sequence[Tuple[str, int]] + ) -> List[Tuple[str, int]]: ... + + def __eq__(self, other: object) -> bool: ... + + def __ne__(self, other: object) -> bool: ... + + def __repr__(self) -> str: ... + + def __str__(self) -> str: ... diff --git a/third_party/mthreads/python/triton/__init__.py b/third_party/mthreads/python/triton/__init__.py new file mode 100644 index 0000000000..a1e768c975 --- /dev/null +++ b/third_party/mthreads/python/triton/__init__.py @@ -0,0 +1,83 @@ +"""isort:skip_file""" +__version__ = '3.6.0' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import constexpr_function, jit +from .runtime._async_compile import AsyncCompileMode, FutureKernel +from .compiler import compile, CompilationError +from .errors import TritonError +from .runtime._allocation import set_allocator + +from . import language +from . import testing +from . import tools + +must_use_result = language.core.must_use_result + +__all__ = [ + "AsyncCompileMode", + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "constexpr_function", + "FutureKernel", + "heuristics", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "max_shared_mem", + "MockTensor", + "must_use_result", + "next_power_of_2", + "OutOfResources", + "reinterpret", + "runtime", + "set_allocator", + "TensorWrapper", + "TritonError", + "testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +@constexpr_function +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +@constexpr_function +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/mthreads/python/triton/_filecheck.py b/third_party/mthreads/python/triton/_filecheck.py new file mode 100644 index 0000000000..3bb74970a5 --- /dev/null +++ b/third_party/mthreads/python/triton/_filecheck.py @@ -0,0 +1,116 @@ +import functools +import os +import inspect +import subprocess +import tempfile + +import triton +from triton.backends import backends +from triton.compiler import ASTSource, make_backend +from triton.backends.compiler import GPUTarget +from triton.experimental.gluon._runtime import GluonASTSource +from triton.runtime.jit import create_function_from_signature +from triton._C.libtriton import ir + +# ===-----------------------------------------------------------------------===# +# filecheck_test +# ===-----------------------------------------------------------------------===# + + +def _get_stub_target() -> GPUTarget: + backend_name = os.environ.get("TRITON_DEFAULT_BACKEND") + if backend_name is None and len(backends) == 1: + backend_name = next(iter(backends)) + + if backend_name in ("nvidia", "cuda"): + return GPUTarget("cuda", 100, 32) + if backend_name in ("amd", "hip"): + return GPUTarget("hip", "gfx942", 64) + if backend_name in ("mthreads", "musa"): + arch = os.environ.get("TRITON_OVERRIDE_ARCH") or os.environ.get("TRITON_MUSA_ARCH") or "ph1" + return GPUTarget("musa", arch, 32) + + # Preserve the legacy frontend parser target when no backend is selected. + return GPUTarget("cuda", 100, 32) + + +triton_dir = os.path.dirname(__file__) +filecheck_path = os.path.join(triton_dir, "FileCheck") + + +class MatchError(ValueError): + + def __init__(self, message, module_str): + super().__init__(message) + self.module_str = module_str + + def __str__(self): + return f"{super().__str__()}\n{self.module_str}" + + +def run_filecheck(name, module_str, check_template): + with tempfile.TemporaryDirectory() as tempdir: + temp_module = os.path.join(tempdir, "module") + with open(temp_module, "w") as temp: + temp.write(module_str) + + temp_expected = os.path.join(tempdir, "expected") + with open(temp_expected, "w") as temp: + temp.write(check_template) + + try: + subprocess.check_output( + [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"], + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as error: + decoded = error.output.decode('unicode_escape') + raise ValueError(decoded) + + +def run_parser(kernel_fn, args=(), kwargs=None, target=None): + if kwargs is None: + kwargs = {} + if target is None: + target = _get_stub_target() + if "sanitize_overflow" not in kwargs: + kwargs = dict(kwargs) + kwargs["sanitize_overflow"] = False + backend = make_backend(target) + binder = create_function_from_signature( + kernel_fn.signature, + kernel_fn.params, + backend, + ) + + bound_args, specialization, options = binder(*args, **kwargs) + options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options) + source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource + src = source_cls(kernel_fn, signature, constexprs, attrs) + + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + codegen_fns = backend.get_codegen_implementation(options) + module_map = backend.get_module_map() + module = src.make_ir(target, options, codegen_fns, module_map, context) + return module + + +def run_filecheck_test(kernel_fn): + assert isinstance(kernel_fn, triton.runtime.JITFunction) + check_template = inspect.getsource(kernel_fn.fn) + if check_template is None: + raise ValueError("kernel function must have a docstring with FileCheck template") + mlir_module = run_parser(kernel_fn) + + run_filecheck("placeholder", mlir_module.str_nodebug(), check_template) + + +def filecheck_test(fn): + + @functools.wraps(fn) + def test_fn(): + run_filecheck_test(fn) + + return test_fn diff --git a/third_party/mthreads/python/triton/_internal_testing.py b/third_party/mthreads/python/triton/_internal_testing.py new file mode 100644 index 0000000000..111841250f --- /dev/null +++ b/third_party/mthreads/python/triton/_internal_testing.py @@ -0,0 +1,293 @@ +import os +import re +import numpy as np +import torch +import triton +import triton.language as tl +from triton import knobs +from typing import Optional, Set, Union +import pytest + +from numpy.random import RandomState +from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +integral_dtypes = int_dtypes + uint_dtypes +float_dtypes = ['float16', 'float32', 'float64'] +float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16'] +dtypes = integral_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] +tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"}) + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def get_current_target(): + if is_interpreter(): + return None + return triton.runtime.driver.active.get_current_target() + + +def is_cuda(): + target = get_current_target() + return False if target is None else target.backend == "cuda" + + +def is_ampere_or_newer(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 8 + + +def is_blackwell(): + return is_cuda() and torch.cuda.get_device_capability()[0] in [10, 11] + + +def is_blackwell_ultra(): + return is_cuda() and torch.cuda.get_device_capability()[0:2] == (10, 3) + + +def is_hopper_or_newer(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def is_hopper(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 9 + + +def is_sm12x(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 12 + + +def is_hip(): + target = get_current_target() + return False if target is None else target.backend == "hip" + + +def is_musa(): + target = get_current_target() + return False if target is None else target.backend == "musa" + + +def is_hip_cdna2(): + target = get_current_target() + return target is not None and target.backend == 'hip' and target.arch == 'gfx90a' + + +def is_hip_cdna3(): + target = get_current_target() + return target is not None and target.backend == 'hip' and target.arch == 'gfx942' + + +def is_hip_cdna4(): + target = get_current_target() + return target is not None and target.backend == 'hip' and target.arch == 'gfx950' + + +def is_hip_rdna3(): + target = get_current_target() + return target is not None and target.backend == 'hip' and 'gfx11' in target.arch + + +def is_hip_rdna4(): + target = get_current_target() + return target is not None and target.backend == 'hip' and 'gfx12' in target.arch + + +def is_hip_gfx1250(): + target = get_current_target() + return target is not None and target.backend == 'hip' and 'gfx1250' in target.arch + + +def is_hip_cdna(): + return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4() + + +def get_hip_lds_size(): + return 163840 if is_hip_cdna4() else 65536 + + +def is_musa(): + target = get_current_target() + return False if target is None else target.backend == "musa" + + +def is_musa_ph1(): + return is_musa() and torch.musa.get_device_capability() == (3, 1) + + +def is_xpu(): + target = get_current_target() + return False if target is None else target.backend == "xpu" + + +def get_arch(): + target = get_current_target() + return "" if target is None else str(target.arch) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + if is_musa(): + return torch.tensor(x, device='cpu').bfloat16().to(device) + else: + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def str_to_triton_dtype(x: str) -> tl.dtype: + return tl.str_to_ty(type_canonicalisation_dict[x], None) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def supports_tma(byval_only=False): + if is_interpreter(): + return True + if not is_cuda(): + return False + cuda_version = knobs.nvidia.ptxas.version + min_cuda_version = (12, 0) if byval_only else (12, 3) + cuda_version_tuple = tuple(map(int, cuda_version.split("."))) + assert len(cuda_version_tuple) == 2, cuda_version_tuple + return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version + + +def supports_ws(): + if is_interpreter(): + return True + if not is_cuda(): + return False + return torch.cuda.get_device_capability()[0] >= 9 + + +def tma_skip_msg(byval_only=False): + if byval_only: + return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)" + else: + return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)" + + +requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg()) + + +def default_alloc_fn(size: int, align: int, _): + return torch.empty(size, dtype=torch.int8, device="cuda") + + +def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor: + if isinstance(t, triton.runtime.jit.TensorWrapper): + return t.base + return t + + +def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None): + from triton import knobs + + if skipped_attr is None: + skipped_attr = set() + + monkeypatch = pytest.MonkeyPatch() + + knobs_map = { + name: knobset + for name, knobset in knobs.__dict__.items() + if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr + } + + # We store which variables we need to unset below in finally because + # monkeypatch doesn't appear to reset variables that were never set + # before the monkeypatch.delenv call below. + env_to_unset = [] + prev_propagate_env = knobs.propagate_env + + def fresh_function(): + nonlocal env_to_unset + for name, knobset in knobs_map.items(): + setattr(knobs, name, knobset.copy().reset()) + for knob in knobset.knob_descriptors.values(): + if knob.key in os.environ: + monkeypatch.delenv(knob.key, raising=False) + else: + env_to_unset.append(knob.key) + knobs.propagate_env = True + return knobs + + def reset_function(): + for name, knobset in knobs_map.items(): + setattr(knobs, name, knobset) + # `undo` should be placed before `del os.environ` + # Otherwise, it may restore environment variables that monkeypatch deleted + monkeypatch.undo() + for k in env_to_unset: + if k in os.environ: + del os.environ[k] + knobs.propagate_env = prev_propagate_env + + return fresh_function, reset_function diff --git a/third_party/mthreads/python/triton/_utils.py b/third_party/mthreads/python/triton/_utils.py new file mode 100644 index 0000000000..aac0e676fe --- /dev/null +++ b/third_party/mthreads/python/triton/_utils.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from functools import reduce +from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict + +if TYPE_CHECKING: + from .language import core + IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] + ObjPath = tuple[int, ...] + +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any: + return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index] + + +def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any): + from .language import core + assert len(path) != 0 + prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1]) + assert isinstance(prev, core.tuple) + prev._setitem(path[-1], val) + + +def is_iterable(x): + from .language import core + return isinstance(x, (list, tuple, core.tuple, core.tuple_type)) + + +def apply_with_path(value: Any, fn: Callable[[ObjPath, Any], None], _path=None) -> None: + if _path is None: + _path = () + + if is_iterable(value): + for idx, item in enumerate(value): + apply_with_path(item, fn, _path=(*_path, idx)) + else: + fn(_path, value) + + +def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]: + # We need to use dict so that ordering is maintained, while set doesn't guarantee order + ret: dict[ObjPath, None] = {} + + def _impl(path: tuple[int, ...], current: Any): + if is_iterable(current): + for idx, item in enumerate(current): + _impl((*path, idx), item) + elif pred(path, current): + ret[path] = None + + _impl((), iterable) + + return list(ret.keys()) + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel + + +type_canonicalisation_dict = { + # we canonicalise all bools to be unsigned: + "bool": "u1", + "int1": "u1", + "uint1": "u1", + "i1": "u1", + # floating-point dtypes: + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "half": "fp16", + "float16": "fp16", + "bfloat16": "bf16", + "float": "fp32", + "float32": "fp32", + "double": "fp64", + "float64": "fp64", + # signed integers: + "int8": "i8", + "int16": "i16", + "int": "i32", + "int32": "i32", + "int64": "i64", + # unsigned integers: + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + "void": "void", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +def canonicalize_dtype(dtype): + dtype_str = str(dtype).split(".")[-1] + return type_canonicalisation_dict[dtype_str] + + +def canonicalize_ptr_dtype(dtype, is_const): + return f"{'*k' if is_const else '*'}{canonicalize_dtype(dtype)}" + + +BITWIDTH_DICT: Dict[str, int] = { + **{f"u{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"i{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"fp{n}": n + for n in (16, 32, 64)}, + **{f"fp8{suffix}": 8 + for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, + "bf16": 16, + "void": 0, +} + +for k, v in type_canonicalisation_dict.items(): + BITWIDTH_DICT[k] = BITWIDTH_DICT[v] + + +def get_primitive_bitwidth(dtype: str) -> int: + return BITWIDTH_DICT[dtype] + + +def is_namedtuple(val): + return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") + + +def _tuple_create(arg, contents): + # NamedTuples and tuples have different construction semantics. NamedTuple + # has a constructor that takes individual arguments, while tuple takes an + # iterable. Both have type "tuple" making it difficult to distinguish + # between them, but only NamedTuple has "_fields" and apparently this is how + # everyone does the check. + return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents) diff --git a/third_party/mthreads/python/triton/backends/__init__.py b/third_party/mthreads/python/triton/backends/__init__.py new file mode 100644 index 0000000000..97092379e0 --- /dev/null +++ b/third_party/mthreads/python/triton/backends/__init__.py @@ -0,0 +1,66 @@ +import importlib +import os +import inspect +import sys +from dataclasses import dataclass +from typing import Type, TypeVar, Union +from types import ModuleType +from .driver import DriverBase +from .compiler import BaseBackend + +if sys.version_info >= (3, 10): + from importlib.metadata import entry_points +else: + from importlib_metadata import entry_points + +T = TypeVar("T", bound=Union[BaseBackend, DriverBase]) + + +def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]: + ret: list[Type[T]] = [] + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): + ret.append(attr) + if len(ret) == 0: + raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") + if len(ret) > 1: + raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") + return ret[0] + + +@dataclass(frozen=True) +class Backend: + compiler: Type[BaseBackend] + driver: Type[DriverBase] + + +def _discover_backends() -> dict[str, Backend]: + backends = dict() + # Fast path: optionally skip entry point discovery (which can be slow) and + # discover only in-tree backends under the `triton.backends` namespace. + skip_entrypoints_env = os.environ.get("TRITON_BACKENDS_IN_TREE", "") + + if skip_entrypoints_env == "1": + root = os.path.dirname(__file__) + for name in os.listdir(root): + if not os.path.isdir(os.path.join(root, name)): + continue + if name.startswith('__'): + continue + compiler = importlib.import_module(f"triton.backends.{name}.compiler") + driver = importlib.import_module(f"triton.backends.{name}.driver") + backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + _find_concrete_subclasses(driver, DriverBase)) + return backends + + # Default path: discover via entry points for out-of-tree/downstream plugins. + for ep in entry_points().select(group="triton.backends"): + compiler = importlib.import_module(f"{ep.value}.compiler") + driver = importlib.import_module(f"{ep.value}.driver") + backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), # type: ignore + _find_concrete_subclasses(driver, DriverBase)) # type: ignore + return backends + + +backends: dict[str, Backend] = _discover_backends() diff --git a/third_party/mthreads/python/triton/backends/compiler.py b/third_party/mthreads/python/triton/backends/compiler.py new file mode 100644 index 0000000000..10754e7157 --- /dev/null +++ b/third_party/mthreads/python/triton/backends/compiler.py @@ -0,0 +1,92 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Union +from types import ModuleType + + +@dataclass(frozen=True) +class GPUTarget(object): + # Target backend, e.g., cuda, hip + backend: str + # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip) + arch: Union[int, str] + warp_size: int + + +class Language(Enum): + """The input language being compiled by the backend.""" + TRITON = 0 + GLUON = 1 + + +class BaseBackend(metaclass=ABCMeta): + supports_native_tensor_specialization = True + + def __init__(self, target: GPUTarget) -> None: + self.target = target + assert self.supports_target(target) + + @staticmethod + @abstractmethod + def supports_target(target: GPUTarget): + raise NotImplementedError + + @abstractmethod + def hash(self) -> str: + """Returns a unique identifier for this backend""" + raise NotImplementedError + + @abstractmethod + def parse_options(self, options: dict) -> object: + """ + Converts an `options` dictionary into an arbitrary object and returns it. + This function may contain target-specific heuristics and check the legality of the provided options + """ + raise NotImplementedError + + @abstractmethod + def add_stages(self, stages: dict, options: object) -> None: + """ + Populates `stages` dictionary with entries of the form: + ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] + The value of each entry may populate a `metadata` dictionary. + Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. + All stages are expected to return a `str` object, except for the last stage which returns + a `bytes` object for execution by the launcher. + """ + raise NotImplementedError + + @abstractmethod + def load_dialects(self, context): + """ + Load additional MLIR dialects into the provided `context` + """ + raise NotImplementedError + + @abstractmethod + def get_module_map(self) -> Dict[str, ModuleType]: + """ + Return a map of interface modules to their device-specific implementations + """ + raise NotImplementedError + + @staticmethod + def parse_attr(desc): + assert isinstance(desc, str) + ret = [] + if "D" in desc: + ret += [["tt.divisibility", 16]] + return ret + + @staticmethod + def get_int_specialization(arg, **kwargs): + if arg % 16 == 0 and kwargs.get("align", False): + return "D" + return "" + + @staticmethod + def get_tensor_specialization(arg, **kwargs): + if arg.data_ptr() % 16 == 0 and kwargs.get("align", False): + return "D" + return "" diff --git a/third_party/mthreads/python/triton/backends/driver.py b/third_party/mthreads/python/triton/backends/driver.py new file mode 100644 index 0000000000..13a658b47e --- /dev/null +++ b/third_party/mthreads/python/triton/backends/driver.py @@ -0,0 +1,66 @@ +from abc import ABCMeta, abstractmethod +from typing import Callable, List, Protocol, Sequence + + +class Benchmarker(Protocol): + + def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]: + pass + + +class DriverBase(metaclass=ABCMeta): + + @classmethod + @abstractmethod + def is_active(self): + pass + + @abstractmethod + def map_python_to_cpp_type(self, ty: str) -> str: + """ + Converts a Triton type string to its corresponding C++ type string for this backend. + + Args: + ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'. + + Returns: + str: The C++ type string. + """ + pass + + @abstractmethod + def get_current_target(self): + pass + + @abstractmethod + def get_active_torch_device(self): + pass + + @abstractmethod + def get_benchmarker(self) -> Benchmarker: + """ + Return the benchmarking function that this backend should use by default. + """ + raise NotImplementedError + + def __init__(self) -> None: + pass + + +class GPUDriver(DriverBase): + + def __init__(self): + # TODO: support other frameworks than torch + import torch + self.get_device_capability = torch.cuda.get_device_capability + try: + from torch._C import _cuda_getCurrentRawStream + self.get_current_stream = _cuda_getCurrentRawStream + except ImportError: + self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + + # TODO: remove once TMA is cleaned up + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args diff --git a/third_party/mthreads/python/triton/compiler/__init__.py b/third_party/mthreads/python/triton/compiler/__init__.py new file mode 100644 index 0000000000..89dc58907f --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/__init__.py @@ -0,0 +1,7 @@ +from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict, get_cache_key, max_shared_mem +from .errors import CompilationError + +__all__ = [ + "compile", "make_backend", "ASTSource", "IRSource", "CompiledKernel", "CompilationError", "LazyDict", + "get_cache_key", "max_shared_mem" +] diff --git a/third_party/mthreads/python/triton/compiler/code_generator.py b/third_party/mthreads/python/triton/compiler/code_generator.py new file mode 100644 index 0000000000..df09a1cfc0 --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/code_generator.py @@ -0,0 +1,1670 @@ +import ast +import builtins +import contextlib +import copy +import functools +import inspect +import re +import warnings +import textwrap +from dataclasses import dataclass +from types import ModuleType +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List + +from .. import knobs, language +from .._C.libtriton import ir, gluon_ir +from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple +from ..language.core import _unwrap_if_constexpr, base_value, base_type +# ideally we wouldn't need any runtime component +from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction +from .._utils import apply_with_path, set_iterable_path, is_namedtuple + +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) + + +def check_identifier_legality(name, type): + pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' + if not re.match(pattern, name): + raise CompilationError(f"invalid {type} identifier: {name}", name) + return name + + +def mangle_fn(name, arg_tys, caller_context): + # doesn't mangle ret type, which must be a function of arg tys + mangled_args = '_'.join([ty.mangle() for ty in arg_tys]) + mangled_args = mangled_args.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_args = mangled_args.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_args}' + if caller_context is not None: + ret += caller_context.mangle() + return ret + + +def _is_triton_value(o: Any) -> bool: + return isinstance(o, base_value) + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable)) + + +def _is_non_scalar_tensor(o: Any) -> bool: + return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and _is_non_scalar_tensor(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _check(cond, msg_fn, category=TypeError): + if not cond: + raise category(msg_fn()) + + +def _apply_to_tuple_values(value, fn): + if is_namedtuple(type(value)): + fields = value._fields + elif isinstance(value, language.tuple): + fields = value.type.fields + else: + assert False, f"Unsupported type {type(value)}" + + vals = [fn(v) for v in value] + vals = [constexpr(v) if v is None else v for v in vals] + types = [v.type for v in vals] + return language.tuple(vals, language.tuple_type(types, fields)) + + +def flatten_values_to_ir(values: Iterable[base_value]): + handles = [] + for v in values: + v._flatten_ir(handles) + return handles + + +def unflatten_ir_values(handles: List[ir.value], types: List[base_type]): + cursor = 0 + for ty in types: + value, cursor = ty._unflatten_ir(handles, cursor) + yield value + assert cursor == len(handles) + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = dict(self.generator.lscope) + self.prev_defs = dict(self.generator.local_defs) + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + return any(self.visit(s) for s in body) + + def _visit_function(self, fn) -> bool: + # No need to check within the function as it won't cause an early return. + # If the function itself has unstructured control flow we may not be able to inline it causing poor performance, + # we should check for this and emit a warning. + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) is ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class ASTFunction: + + def __init__(self, ret_types, arg_types, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.attrs = attrs + + def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]: + ir_types = [] + for ty in types: + if ty is None: + continue + ty._flatten_ir_types(builder, ir_types) + return ir_types + + def return_types_ir(self, builder: ir.builder) -> List[ir.type]: + return self.flatten_ir_types(builder, self.ret_types) + + def serialize(self, builder: ir.builder): + # > build mlir function type + arg_types_ir = self.flatten_ir_types(builder, self.arg_types) + ret_types_ir = self.return_types_ir(builder) + return builder.get_function_ty(arg_types_ir, ret_types_ir) + + def deserialize(self, fn): + # create "template" + def make_template(ty): + if isinstance(ty, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in ty], ty) + return language.constexpr(None) + + vals = make_template(self.arg_types) + handles = [fn.args(i) for i in range(fn.get_num_args())] + cursor = 0 + + def build_value(path, ty): + nonlocal cursor, handles + # > set attributes + attr_specs = self.attrs.get(path, []) + for attr_name, attr_val in attr_specs: + fn.set_arg_attr(cursor, attr_name, attr_val) + # > build frontend value + val, cursor = ty._unflatten_ir(handles, cursor) + set_iterable_path(vals, path, val) + + apply_with_path(self.arg_types, build_value) + return vals + + +@dataclass(frozen=True) +class BoundJITMethod: + __self__: base_value + __func__: JITFunction + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns, + module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.is_gluon = is_gluon + if is_gluon: + from triton.experimental.gluon.language._semantic import GluonSemantic + self.builder = gluon_ir.GluonOpBuilder(context) + self.semantic = GluonSemantic(self.builder) + else: + from triton.language.semantic import TritonSemantic + self.builder = ir.builder(context) + self.semantic = TritonSemantic(self.builder) + + self.name_loc_as_prefix = None + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + + self.return_vals: List[base_value | None] = [] + self.return_ips: List[Tuple[ir.InsertPoint, ir.Loc]] = [] + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} + self.jit_fn = jit_fn + # TODO: we currently generate illegal names for non-kernel functions involving constexprs! + if is_kernel: + function_name = function_name[function_name.rfind('.') + 1:] + function_name = check_identifier_legality(function_name, "function") + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.noinline = noinline + self.caller_context = caller_context + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + self.uses_tme = False + + builtin_namespace: Dict[str, Any] = { + _.__name__: _ + for _ in (len, list, range, float, int, isinstance, getattr, hasattr) + } + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.core.builtin_min), + ('max', language.core.builtin_max), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if any([ + val is absent, + name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITCallable), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__triton_aggregate__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), # + isinstance(val, language.dtype), # + is_namedtuple(val), + self._is_constexpr_global(name), # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + self.visiting_arg_default_value, # + knobs.compilation.allow_non_constexpr_globals, + ]): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from + annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + @contextlib.contextmanager + def _name_loc_prefix(self, prefix): + self.name_loc_as_prefix = prefix + yield + self.name_loc_as_prefix = None + + def _maybe_set_loc_to_name(self, val, name): + if isinstance(val, (ir.value, ir.block_argument)): + val.set_loc(self.builder.create_name_loc(name, val.get_loc())) + elif _is_triton_value(val): + handles = [] + val._flatten_ir(handles) + for handle in handles: + handle.set_loc(self.builder.create_name_loc(name, handle.get_loc())) + + def set_value(self, name: str, value: Union[base_value, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + def _find_carries(self, node, liveins, ignore: set[str] = set()): + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) has changed value within the loop, then it's + # a loop-carried variable. (The new and old value must be of the + # same type) + init_tys = [] + init_handles = [] + names = [] + + for name, live_val in liveins.items(): + if name in ignore: + continue + + if _is_triton_value(live_val): + loop_val = self.lscope[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + live_handles = flatten_values_to_ir([live_val]) + loop_handles = flatten_values_to_ir([loop_val]) + if live_handles != loop_handles: + names.append(name) + init_tys.append(live_val.type) + init_handles.extend(live_handles) + else: + assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value' + + # reset local scope to not pick up local defs from the dry run. + self.lscope = liveins.copy() + self.local_defs = {} + + return names, init_handles, init_tys + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = language.tuple([self.visit(elt) for elt in node.elts]) + return elts + + def visit_ListComp(self, node: ast.ListComp): + if len(node.generators) != 1: + raise ValueError("nested comprehensions are not supported") + + comp = node.generators[0] + iter = self.visit(comp.iter) + if not isinstance(iter, tl_tuple): + raise NotImplementedError("only tuple comprehensions are supported") + + results = [] + for item in iter: + self.set_value(comp.target.id, item) + results.append(self.visit(node.elt)) + return tl_tuple(results) + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + if ret_value is None: + ret_value = language.constexpr(None) + self.return_vals.append(ret_value) + self.return_ips.append(self._get_insertion_point_and_loc()) + + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. + post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + + def decide_return_type(self): + assert len(self.return_vals) == len(self.return_ips) + if not self.return_vals: + return language.constexpr_type(None) + + tl = language.core + + def error_msg(a, b): + err = f"Return type mismatch: {a} and {b}. " + err += f"Note all return types were: {return_types}" + return err + + def common_type(a, b): + if isinstance(a, tl.tuple_type): + _check(isinstance(b, tl.tuple_type), lambda: error_msg(a, b)) + _check(a.fields == b.fields, lambda: error_msg(a, b)) + return tl.tuple_type([common_type(ai, bi) for ai, bi in zip(a, b)], fields=a.fields) + if isinstance(a, tl.constexpr_type): + if a == b: + return a + a = self.semantic.to_tensor_type(a) + b = self.semantic.to_tensor_type(b) + elif isinstance(b, tl.constexpr_type): + a = self.semantic.to_tensor_type(a) + b = self.semantic.to_tensor_type(b) + _check(a == b, lambda: error_msg(a, b)) + return a + + return_types = [x.type for x in self.return_vals] + return functools.reduce(common_type, return_types) + + def cast_to(self, value, ty): + if value.type == ty: + return value + + tl = language.core + if isinstance(value, tl.tuple): + assert isinstance(ty, tl.tuple_type) + return tl.tuple( + [self.cast_to(v, t) for v, t in zip(value.values, ty.types)], + ty, + ) + if isinstance(value, tl.constexpr): + if isinstance(ty, tl.constexpr_type): + _check(value.type == ty, lambda: f"Return type mismatch {value.type} and {ty}") + return value + return self.semantic.scalar_constant(value.value, ty) + _check(value.type == ty, lambda: f"Return type mismatch {value.type} and {ty}") + return value + + def handle_returns(self): + return_type = self.decide_return_type() + ip, loc = self._get_insertion_point_and_loc() + + assert len(self.return_vals) == len(self.return_ips) + for ret, ret_ip in zip(self.return_vals, self.return_ips): + self._set_insertion_point_and_loc(*ret_ip) + assert not self.builder.get_insertion_block().has_terminator() + ret = self.cast_to(ret, return_type) + ret_handles = flatten_values_to_ir([ret]) + self.builder.ret(ret_handles) + + self._set_insertion_point_and_loc(ip, loc) + self.ret_type = return_type + assert not self.builder.get_insertion_block().has_terminator() + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = list(self.ret_type.types) + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) + self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)]) + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults[::-1]): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + fn_ty = self.prototype.serialize(self.builder) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = self.prototype.deserialize(self.fn) + if self.caller_context is not None: + self.caller_context.initialize_callee(self.fn, self.builder) + # bind arguments to symbols + for arg_name, arg_value in zip(arg_names, arg_values): + self._maybe_set_loc_to_name(arg_value, arg_name) + self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + + # finalize function + self.handle_returns() + + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + param = next(p for p in self.jit_fn.params if p.name == node.arg) + if param.is_constexpr and (param.do_not_specialize or param.do_not_specialize_on_alignment): + raise CompilationError( + self.jit_fn.src, node, + f"{node.arg} marked as constexpr and listed in do_not_specialize/do_not_specialize_on_alignment. " + "Remove constexpr designation to skip specialization.") + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def assignTarget(self, target, value): + assert isinstance(target.ctx, ast.Store) + if isinstance(target, ast.Subscript): + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + for i, target in enumerate(target.elts): + self.assignTarget(target, value.values[i]) + return + if isinstance(target, ast.Attribute): + raise NotImplementedError("Attribute assignment is not supported in triton") + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + + def visit_Assign(self, node): + # construct values to assign + def _sanitize_value(value): + if isinstance(value, language.tuple): + return _apply_to_tuple_values(value, _sanitize_value) + native_nontensor_types = (language.dtype, language.tuple) + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = self.semantic.to_tensor(value) + return value + + targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets + assert len(targets) == 1 + target = targets[0] + if isinstance(target, ast.Name): + with self._name_loc_prefix(target.id): + values = _sanitize_value(self.visit(node.value)) + else: + values = _sanitize_value(self.visit(node.value)) + self.assignTarget(target, values) + + def visit_AugAssign(self, node): + lhs = copy.deepcopy(node.target) + lhs.ctx = ast.Load() + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + for x in ['lineno', 'col_offset', 'end_lineno', 'end_col_offset']: + if hasattr(node, x): + y = getattr(node, x) + setattr(rhs, x, y) + setattr(assign, x, y) + self.visit(assign) + return self.visit(lhs) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return language.tuple(args) + + def _apply_binary_method(self, node, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _semantic=self.semantic) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic) + if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr): + lhs = constexpr(lhs) + if isinstance(lhs, constexpr): + fn = getattr(lhs, method_name) + else: + fn = self.get_Attribute(lhs, method_name) + return self.call_Function(node, fn, [rhs], {}) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(node, method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + then_vals = self.lscope.copy() + # else block + else_defs = {} + else_vals = liveins.copy() + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + else_vals = self.lscope.copy() + + # update block arguments + names = [] + # variables in livein whose value is updated in `if` + for name, value in liveins.items(): + # livein variable changed value in either then or else + if not _is_triton_value(value): + continue + then_handles = flatten_values_to_ir([then_vals[name]]) + else_handles = flatten_values_to_ir([else_vals[name]]) + if then_handles == else_handles: + continue + names.append(name) + then_defs[name] = then_vals[name] + else_defs[name] = else_vals[name] + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + type_equal = type(defs[name]) == type(value) # noqa: E721 + assert type_equal and defs[name].type == value.type, \ + f'initial value for `{name}` is of type {value}, '\ + f'but the {block_name} block redefines it as {defs[name]}' + + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in sorted(then_defs.keys() & else_defs.keys()): + if name in names: + continue + then_val = then_defs[name] + then_ty = then_val.type + else_val = else_defs[name] + else_ty = else_val.type + type_equal = type(then_val) == type(else_val) # noqa: E721 + assert type_equal and then_ty == else_ty, \ + f'Mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + + return then_defs, else_defs, then_block, else_block, names + + def visit_if_top_level(self, cond, node): + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() + # then terminator + self.builder.set_insertion_point_to_end(then_block) + assert not then_block.has_terminator(), f"{then_block}" + then_handles = flatten_values_to_ir(then_defs[name] for name in names) + self.builder.create_branch(endif_block, then_handles) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + assert not else_block.has_terminator(), f"{else_block}" + else_handles = flatten_values_to_ir(else_defs[name] for name in names) + self.builder.create_branch(endif_block, else_handles) + assert len(then_handles) == len(else_handles) + for then_h, else_h in zip(then_handles, else_handles): + ty = then_h.get_type() + assert ty == else_h.get_type() + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + res_handles = [endif_block.arg(i) for i in range(len(then_handles))] + types = [then_defs[name].type for name in names] + new_values = unflatten_ir_values(res_handles, types) + for name, new_value in zip(names, new_values): + self.set_value(name, new_value) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + then_handles = flatten_values_to_ir(then_defs[name] for name in names) + for name, val in zip(names, then_handles): + self._maybe_set_loc_to_name(val, name) + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op(then_handles) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + else_handles = flatten_values_to_ir(else_defs[name] for name in names) + for name, val in zip(names, else_handles): + self._maybe_set_loc_to_name(val, name) + self.builder.create_yield_op(else_handles) + # update values + res_handles = [if_op.get_result(i) for i in range(len(then_handles))] + types = [then_defs[name].type for name in names] + new_values = unflatten_ir_values(res_handles, types) + for name, new_value in zip(names, new_values): + self.set_value(name, new_value) + + def visit_If(self, node): + cond = self.visit(node.test) + + if _is_triton_tensor(cond): + if _is_non_scalar_tensor(cond): + raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous") + if cond.type.is_block(): + warnings.warn( + "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead" + % ast.unparse(node.test)) + cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self) + cond = cond.to(language.int1, _semantic=self.semantic) + if ContainsReturnChecker(self.gscope).visit(node): + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton.") + self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _semantic=self.semantic) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = self.semantic.to_tensor(self.visit(node.body)) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = self.semantic.to_tensor(self.visit(node.orelse)) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_With(self, node): + # Lower `with` statements by constructing context managers and calling their enter/exit hooks + # Instantiate each context manager with builder injection + cm_list = [] + for item in node.items: + call = item.context_expr + fn = self.visit(call.func) + args = [self.visit(arg) for arg in call.args] + kws = dict(self.visit(kw) for kw in call.keywords) + cm = fn(*args, _semantic=self.semantic, **kws) + cm_list.append(cm) + for cm, item in zip(cm_list, node.items): + res = cm.__enter__() + if item.optional_vars is not None: + var_name = self.visit(item.optional_vars) + self.set_value(var_name, res) + if ContainsReturnChecker(self.gscope).visit(node): + raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ") + self.visit_compound_statement(node.body) + for cm in reversed(cm_list): + cm.__exit__(None, None, None) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) is ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) is ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(node, method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_semantic=self.semantic) + try: + return getattr(operand, fn)() + except AttributeError: + if fn == "__not__": + return constexpr(not operand) + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def _verify_loop_carried_variable(self, name, loop_val, live_val): + assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop' + assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop' + assert type(loop_val) is type(live_val), ( + f'Loop carried variable {name} changed type, was {type(loop_val)} but is now {type(live_val)}') + assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ + f'Loop-carried variable {name} has initial type {live_val.type} '\ + f'but is re-assigned to {loop_val.type} in loop! '\ + f'Please make sure that the type stays consistent.' + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + names, init_handles, init_fe_tys = self._find_carries(node, liveins) + + init_tys = [h.get_type() for h in init_handles] + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op(init_tys, init_handles) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys) + self.builder.set_insertion_point_to_start(before_block) + block_args = [before_block.arg(i) for i in range(len(init_handles))] + condition_args = unflatten_ir_values(block_args, init_fe_tys) + for name, val in zip(names, condition_args): + self.lscope[name] = val + self.local_defs[name] = val + self._maybe_set_loc_to_name(val, name) + cond = self.visit(node.test) + if isinstance(cond, language.condition): + if cond.disable_licm: + while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr()) + cond = cond.condition + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, block_args) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + body_handles = [after_block.arg(i) for i in range(len(init_handles))] + body_args = unflatten_ir_values(body_handles, init_fe_tys) + for name, val in zip(names, body_args): + self.lscope[name] = val + self.local_defs[name] = val + self._maybe_set_loc_to_name(val, name) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + + yield_handles = flatten_values_to_ir(self.lscope[name] for name in names) + self.builder.create_yield_op(yield_handles) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + result_handles = [while_op.get_result(i) for i in range(len(init_handles))] + result_vals = unflatten_ir_values(result_handles, init_fe_tys) + for name, new_def in zip(names, result_vals): + self.lscope[name] = new_def + self.local_defs[name] = new_def + self._maybe_set_loc_to_name(new_def, name) + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript_Load(self, node): + assert isinstance(node.ctx, ast.Load) + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_value(lhs): + return self.call_Method(node, lhs.__getitem__, lhs, [slices], {}) + return lhs[slices] + + def visit_Subscript_Store(self, node, value): + raise NotImplementedError("__setitem__ is not supported in triton") + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + loop_unroll_factor = None + disallow_acc_multi_buffer = False + flatten = False + warp_specialize = False + disable_licm = False + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor + disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer + flatten = iterator.flatten + warp_specialize = iterator.warp_specialize + disable_licm = iterator.disable_licm + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Constant(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Constant(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = self.semantic.to_tensor(lb) + ub = self.semantic.to_tensor(ub) + step = self.semantic.to_tensor(step) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + if _is_non_scalar_tensor(lb): + raise TypeError(f"For lower bound must be a scalar, got {lb.type}") + if _is_non_scalar_tensor(ub): + raise TypeError(f"For upper bound must be a scalar, got {ub.type}") + if _is_non_scalar_tensor(step): + raise TypeError(f"For step must be a scalar, got {step.type}") + iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv_placeholder = self.builder.create_poison(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv_placeholder, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + names, init_handles, init_tys = self._find_carries(node, liveins, ignore={node.target.id}) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, init_handles) + if _unwrap_if_constexpr(num_stages) is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if _unwrap_if_constexpr(loop_unroll_factor) is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if disallow_acc_multi_buffer: + for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr()) + if flatten: + for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) + if warp_specialize: + for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr()) + if disable_licm: + for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr()) + + self.scf_stack.append(node) + for_op_body = for_op.get_body(0) + self.builder.set_insertion_point_to_start(for_op_body) + block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))] + block_args = unflatten_ir_values(block_handles, init_tys) + for name, val in zip(names, block_args): + self._maybe_set_loc_to_name(val, name) + self.set_value(name, val) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yield_handles = flatten_values_to_ir(self.lscope[name] for name in names) + + # create YieldOp + if len(yield_handles) > 0: + self.builder.create_yield_op(yield_handles) + for_op_region = for_op_body.get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op_body) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + iv_placeholder.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + self._maybe_set_loc_to_name(iv, node.target.id) + + # update lscope & local_defs (ForOp defines new values) + result_handles = [for_op.get_result(i) for i in range(len(init_handles))] + result_values = unflatten_ir_values(result_handles, init_tys) + for name, val in zip(names, result_values): + self.set_value(name, val) + self._maybe_set_loc_to_name(val, name) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return language.slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + return language.core.device_assert(test, msg, _semantic=self.semantic) + + def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None): + bound_args = fn.signature.bind(*args, **kwargs) + bound_args.apply_defaults() + args = bound_args.arguments + args = [args[name] for name in fn.arg_names] + for i, arg in enumerate(args): + if not isinstance(arg, base_value) or isinstance(arg, JITCallable): + args[i] = language.core.constexpr(arg) + # mangle + caller_context = caller_context or self.caller_context + arg_types = [arg.type for arg in args] + fn_name = mangle_fn(get_full_name(fn), arg_types, caller_context) + # generate function def if necessary + if not self.module.has_function(fn_name): + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = get_jit_fn_file_line(fn) + prototype = ASTFunction([], arg_types, dict()) + backend_supports_noinline = getattr(self.builder.options, "supports_noinline", True) + generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn, + function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline and backend_supports_noinline, file_name=file_name, + begin_line=begin_line, options=self.builder.options, + codegen_fns=self.builder.codegen_fns, module_map=self.builder.module_map, + caller_context=caller_context, is_gluon=self.is_gluon) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + if knobs.compilation.front_end_debugging: + raise + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + args_val = flatten_values_to_ir(args) + call_op = self.builder.call(symbol, args_val) + handles = [call_op.get_result(i) for i in range(call_op.get_num_results())] + return next(unflatten_ir_values(handles, [callee_ret_type])) + + def call_Function(self, node, fn, args, kws): + if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)): + args.insert(0, fn.__self__) + fn = fn.__func__ + + mur = getattr(fn, '_must_use_result', False) + if mur and getattr(node, '_is_unused', False): + error_message = ["The result of %s is not being used." % ast.unparse(node.func)] + if isinstance(mur, str): + error_message.append(mur) + raise CompilationError(self.jit_fn.src, node, " ".join(error_message)) + + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance( + fn, ConstexprFunction): + extra_kwargs = dict() + + sig = getattr(fn, "signature", None) + if isinstance(fn, ConstexprFunction): + extra_kwargs["_semantic"] = self.semantic + else: + if sig is None: + sig = inspect.signature(fn) + if '_semantic' in sig.parameters: + extra_kwargs["_semantic"] = self.semantic + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret + except Exception as e: + if knobs.compilation.front_end_debugging: + raise + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, str(e)) from e + + if fn in self.builtin_namespace.values() or (hasattr(fn, '__self__') and not _is_triton_value(fn.__self__)): + args = map(_unwrap_if_constexpr, args) + ret = fn(*args, **kws) + + def wrap_constexpr(x): + if _is_triton_value(x): + return x + return constexpr(x) + + if isinstance(ret, (builtins.tuple, language.tuple)): + return _apply_to_tuple_values(ret, wrap_constexpr) + return wrap_constexpr(ret) + + def call_Method(self, node, fn, fn_self, args, kws): + if isinstance(fn, JITFunction): + args.insert(0, fn_self) + return self.call_Function(node, fn, args, kws) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + tme_descriptor_fns = ( + getattr(language.core, "load_tensor_descriptor", None), + getattr(language.core, "store_tensor_descriptor", None), + getattr(language.core, "_experimental_descriptor_load", None), + getattr(language.core, "_experimental_descriptor_store", None), + ) + if fn in tme_descriptor_fns: + self.uses_tme = True + if not isinstance(fn, BoundJITMethod): + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [] + for arg in node.args: + if isinstance(arg, ast.Starred): + arg = self.visit(arg.value) + assert isinstance(arg, language.core.tuple) + args.extend(arg.values) + else: + args.append(self.visit(arg)) + + return self.call_Function(node, fn, args, kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + + nontrivial_values = [] + + for subnode in node.values: + # we visit the values in order, executing their side-effects + # and possibly early-exiting: + value = self.visit(subnode) + if not _is_triton_tensor(value): + # this is a constexpr, so we might be able to short-circuit: + bv = bool(value) + if (bv is False) and (method_name == "logical_and"): + # value is falsey so return that: + return value + if (bv is True) and (method_name == "logical_or"): + # value is truthy so return that: + return value + # otherwise, our constexpr has no effect on the output of the + # expression so we do not append it to nontrivial_values. + else: + if value.type.is_block(): + lineno = getattr(node, "lineno", None) + if lineno is not None: + lineno += self.begin_line + warnings.warn_explicit( + "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead", + category=UserWarning, + filename=self.file_name, + lineno=lineno, + source=ast.unparse(node), + ) + # not a constexpr so we must append it: + nontrivial_values.append(value) + + if len(nontrivial_values) == 0: + # the semantics of a disjunction of falsey values or conjunction + # of truthy values is to return the final value: + nontrivial_values.append(value) + + while len(nontrivial_values) >= 2: + rhs = nontrivial_values.pop() + lhs = nontrivial_values.pop() + res = self._apply_binary_method(node, method_name, lhs, rhs) + nontrivial_values.append(res) + + assert len(nontrivial_values) == 1 + return nontrivial_values[0] + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + def get_Attribute(self, lhs, attr): + if _is_triton_tensor(lhs) and attr == "T": + return self.semantic.permute(lhs, (1, 0)) + # NOTE: special case ".value" for BC + if isinstance(lhs, constexpr) and attr not in ("value", "type"): + lhs = lhs.value + attr = getattr(lhs, attr) + if _is_triton_value(lhs) and isinstance(attr, JITFunction): + return BoundJITMethod(lhs, attr) + return attr + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if isinstance(lhs, ModuleType): + # follow module_map until reaching fixed-point: + while (name := lhs.__name__) in self.builder.module_map: + lhs = self.builder.module_map[name] + if lhs.__name__ == name: + break + return self.get_Attribute(lhs, node.attr) + + def visit_Expr(self, node): + node.value._is_unused = True + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + if self.name_loc_as_prefix is not None: + self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc)) + else: + self.builder.set_loc(here_loc) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + if knobs.compilation.front_end_debugging: + raise + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + from ..experimental.gluon import language as ttgl + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + ttgl.static_assert: execute_static_assert, + ttgl.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None): + arg_types = [None] * len(fn.arg_names) + + for k, v in src.signature.items(): + idx = fn.arg_names.index(k) + arg_types[idx] = str_to_ty(v, None) + + def apply_constexpr_types(argument, indices, value): + index = indices.pop() + if len(indices) == 0: + if isinstance(argument, list): + argument[index] = constexpr(value).type + else: + argument.types[index] = constexpr(value).type + else: + apply_constexpr_types(argument[index], indices, value) + + for path, value in src.constants.items(): + apply_constexpr_types(arg_types, list(path)[::-1], value) + + prototype = ASTFunction([], arg_types, src.attrs) + file_name, begin_line = get_jit_fn_file_line(fn) + # query function representation + from collections import namedtuple + leaves = filter(lambda v: len(v) == 1, src.constants) + constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves} + signature = src.signature + proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature) + generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy), + jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, + codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon()) + generator.visit(fn.parse()) + module = generator.module + module.uses_tme = generator.uses_tme + # module takes ownership of the context + module.context = context + if not module.verify(): + if not fn.is_gluon(): + print(module) + raise RuntimeError("error encountered during parsing") + return module diff --git a/third_party/mthreads/python/triton/compiler/compiler.py b/third_party/mthreads/python/triton/compiler/compiler.py new file mode 100644 index 0000000000..e7979578e8 --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/compiler.py @@ -0,0 +1,513 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import Language +from ..backends.compiler import BaseBackend, GPUTarget +from .. import __version__, knobs +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key +from ..runtime.driver import driver +from ..tools.disasm import get_sass +from pathlib import Path +import re +import functools +import os +import time +import copy + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ptx": ptx_prototype_pattern, +} + +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +class ASTSource: + + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: + self.fn = fn + self.language = Language.TRITON + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = dict() + if constexprs is not None: + for k, v in constexprs.items(): + k = (fn.arg_names.index(k), ) if isinstance(k, str) else k + assert isinstance(k, tuple) + self.constants[k] = v + self.attrs = attrs or dict() + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x) + constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())]) + key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context): + from .code_generator import ast_to_ttir + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path, context, backend): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.language = Language.TRITON + self.src = path.read_text() + ir.load_dialects(context) + backend.load_dialects(context) + + # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. + # TODO - replace with a proper parser + if self.ext == "ptx": + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + else: + self.module = ir.parse_mlir_module(self.path, context) + fn_name = self.module.get_entry_func_name() + self.name = "@" + fn_name + funcOp = self.module.get_function(fn_name) + func_ty = self.module.get_function_signature(funcOp) + self.signature = {k: ty for k, ty in enumerate(func_ty)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context): + self.module.context = context + return self.module + + def parse_options(self): + if self.ext == "ttgir": + num_warps = self.module.get_int_attr("ttg.num-warps") + assert num_warps is not None, "Unable to parse ttg.num-warps attribute" + options = {'num_warps': num_warps} + num_ctas = self.module.get_int_attr("ttg.num-ctas") + if num_ctas is not None: + options['num_ctas'] = num_ctas + return options + return dict() + + +@functools.lru_cache() +def max_shared_mem(device): + return driver.active.utils.get_device_properties(device)["max_shared_mem"] + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx" or ext == "amdgcn": + return Path(full_name).read_text() + if ext == "cubin" or ext == "hsaco" or ext == "mubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if knobs.compilation.front_end_debugging: + return + + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +class CompileTimer: + + def __init__(self) -> None: + self.start: float = time.time() + self.ir_initialization_end: float | None = None + self.lowering_stage_ends: list[tuple[str, float]] = [] + self.store_results_end: float | None = None + + def finished_ir_initialization(self) -> None: + self.ir_initialization_end = time.time() + + def stage_finished(self, stage_name: str) -> None: + self.lowering_stage_ends.append((stage_name, time.time())) + + def end(self) -> knobs.CompileTimes: + timestamp = time.time() + if self.ir_initialization_end is None: + self.ir_initialization_end = timestamp + else: + self.store_results_end = timestamp + + def delta(start: float, end: float | None) -> int: + if end is None: + return 0 + return int((end - start) * 1000000) + + lowering_stage_durations = [] + stage_start = self.ir_initialization_end + for stage_name, stage_end in self.lowering_stage_ends: + lowering_stage_durations.append((stage_name, delta(stage_start, stage_end))) + stage_start = stage_end + + return knobs.CompileTimes( + ir_initialization=delta(self.start, self.ir_initialization_end), + lowering_stages=lowering_stage_durations, + store_results=delta(stage_start, self.store_results_end), + ) + + +def _should_bypass_compilation_cache(target: GPUTarget) -> bool: + # Replace-IR knobs override compiler stages outside Triton's cache hash. + # Bypass cache hits so replacement file edits are always observed. + if target.backend != "musa": + return False + return bool(knobs.musa.replace_llir or knobs.musa.replace_mubin) + + +def compile(src, target=None, options=None, _env_vars=None): + compilation_listener = knobs.compilation.listener + if compilation_listener: + timer = CompileTimer() + + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + context = ir.context() + src = IRSource(src, context, backend) + + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars + key = get_cache_key(src, backend, options, env_vars=env_vars) + if knobs.runtime.add_stages_inspection_hook is not None: + inspect_stages_key, inspect_stages_hash = knobs.runtime.add_stages_inspection_hook() + key += inspect_stages_key + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = knobs.compilation.override + enable_ir_dump = knobs.compilation.dump_ir + store_only_binary = knobs.compilation.store_binary_only + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = knobs.compilation.always_compile + bypass_cache = _should_bypass_compilation_cache(target) + if not always_compile and metadata_path is not None and not bypass_cache: + # cache hit! + res = CompiledKernel(src, metadata_group, hash) + if compilation_listener: + compilation_listener( + src=src, + metadata=res.metadata._asdict(), + metadata_group=metadata_group, + times=timer.end(), + cache_hit=True, + ) + return res + + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + metadata["triton_version"] = __version__ + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options, src.language) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + + # For IRSource, we have already grabbed the context + called both + # ir.load_dialects and backend.load_dialects. + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + codegen_fns = backend.get_codegen_implementation(options) + module_map = backend.get_module_map() + try: + module = src.make_ir(target, options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + + if ir_source: + ir_filename = f"{file_name}.{src.ext}" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + else: + ir_filename = f"{file_name}.source" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + + use_ir_loc = knobs.compilation.use_ir_loc + if ir_source and use_ir_loc: + module.create_location_snapshot(src.path) + print(f"Creating new locations for {src.path}") + + if compilation_listener: + timer.finished_ir_initialization() + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{file_name}.{ext}" + if fn_override_manager is None: + # Users can override kernels at scale by setting `ir_override` in autotune config + # without TRITON_KERNEL_OVERRIDE + if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"): + next_module = parse(ir_override, ext, context) + elif full_name := fn_override_manager.get_file(ir_filename): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json + if (not store_only_binary) or (ext in ("cubin", "hsaco", "mubin", "json")): + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + if ext == "cubin": + sass = get_sass(next_module) + fn_dump_manager.put(sass, file_name + ".sass") + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + if compilation_listener: + timer.stage_finished(ext) + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + + # notify any listener + if compilation_listener: + compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(), + cache_hit=False) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target: GPUTarget) -> BaseBackend: + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self): + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class AsmDict(dict): + + def __missing__(self, key): + + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + +def _raise_error(err, *args, **kwargs): + raise copy.deepcopy(err) + + +class CompiledKernel: + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = AsmDict({ + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + }) + self.metadata_group = metadata_group + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + self._run = None + + def _init_handles(self): + if self.module is not None: + return + + def raise_(err): + # clone the exception object so that the one saved in the closure + # of the partial function below doesn't get assigned a stack trace + # after the subsequent raise. otherwise, the CompiledKernel instance + # saved in the (global) kernel cache will keep references to all the + # locals in the traceback via the exception instance in the closure. + cloned_err = copy.deepcopy(err) + self._run = functools.partial(_raise_error, cloned_err) + raise err + + device = driver.active.get_current_device() + # create launcher + self._run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = max_shared_mem(device) + if self.metadata.shared > max_shared: + raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory")) + if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None: + # Use blackwell max tmem size for now, this should be moved in device properties + max_tmem_size = 512 # tmem size in number of columns + if self.metadata.tmem_size > max_tmem_size: + raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")) + if knobs.runtime.kernel_load_start_hook is not None: + knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash) + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + warp_size = driver.active.get_current_target().warp_size + if self.metadata.num_warps * warp_size > self.n_max_threads: + raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")) + if knobs.runtime.kernel_load_end_hook is not None: + knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash) + + @property + def run(self): + if self._run is None: + self._init_handles() + return self._run + + def launch_metadata(self, grid, stream, *args): + if knobs.runtime.launch_enter_hook is None: + return None + self._init_handles() + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)} + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args) + + return runner diff --git a/third_party/mthreads/python/triton/compiler/errors.py b/third_party/mthreads/python/triton/compiler/errors.py new file mode 100644 index 0000000000..39e6c4dfb0 --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/mthreads/python/triton/compiler/make_launcher.py b/third_party/mthreads/python/triton/compiler/make_launcher.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/mthreads/python/triton/errors.py b/third_party/mthreads/python/triton/errors.py new file mode 100644 index 0000000000..3a0a863553 --- /dev/null +++ b/third_party/mthreads/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/mthreads/python/triton/experimental/__init__.py b/third_party/mthreads/python/triton/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/mthreads/python/triton/experimental/gluon/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/__init__.py new file mode 100644 index 0000000000..7c62d7cb9d --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/__init__.py @@ -0,0 +1,6 @@ +from ._runtime import constexpr_function, jit +from triton.language.core import must_use_result +from . import nvidia +from . import amd + +__all__ = ["constexpr_function", "jit", "must_use_result", "nvidia", "amd"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/_compiler.py b/third_party/mthreads/python/triton/experimental/gluon/_compiler.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/mthreads/python/triton/experimental/gluon/_runtime.py b/third_party/mthreads/python/triton/experimental/gluon/_runtime.py new file mode 100644 index 0000000000..d98bb2098b --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/_runtime.py @@ -0,0 +1,102 @@ +from __future__ import annotations +from triton.compiler.compiler import ASTSource +from triton.backends.compiler import Language +from triton.runtime.jit import JITFunction, constexpr_function +from typing import TypeVar, Optional, Callable, Iterable, Union +from triton._C.libtriton import ir + +T = TypeVar("T") + +__all__ = ["constexpr_function", "jit"] + + +class GluonASTSource(ASTSource): + + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: + super().__init__(fn, signature, constexprs, attrs) + self.language = Language.GLUON + self.ext = "ttgir" + + def make_ir(self, target, options, codegen_fns, module_map, context): + from triton.compiler.compiler import make_backend + from triton.compiler.code_generator import ast_to_ttir + + builder = ir.builder(context) + module = builder.create_module() + + # Assign module attributes eagerly, as they are needed to verify layouts + backend = make_backend(target) + target = backend.get_target_name(options) + + module.set_attr("ttg.target", builder.get_string_attr(target)) + module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps)) + module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas)) + module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size)) + + is_cuda = options.backend_name == "cuda" + if is_cuda and options.maxnreg is not None: + module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg)) + + module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map, module=module) + return module + + +class GluonJITFunction(JITFunction[T]): + + def create_binder(self): + result = super().create_binder() + self.ASTSource = GluonASTSource + return result + + def is_gluon(self): + return True + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + return GluonJITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator diff --git a/third_party/mthreads/python/triton/experimental/gluon/amd/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/amd/__init__.py new file mode 100644 index 0000000000..3271153da6 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/amd/__init__.py @@ -0,0 +1,3 @@ +from . import gfx1250 + +__all__ = ["gfx1250"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/amd/gfx1250.py b/third_party/mthreads/python/triton/experimental/gluon/amd/gfx1250.py new file mode 100644 index 0000000000..ae36b1b124 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/amd/gfx1250.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from typing import List, Any +from triton._utils import validate_block_shape +from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout + +__all__ = ["TensorDescriptor"] + + +@dataclass +class TensorDescriptor: + base: Any + shape: List[int] + strides: List[int] + block_shape: List[int] + layout: PaddedSharedLayout | SwizzledSharedLayout + padding: str = "zero" + + def __post_init__(self): + ndim = len(self.shape) + assert 1 <= ndim <= 5, f"Expected 1-5 dimensions but got {ndim} dimensions" + assert len(self.strides) == ndim, f"Expected {ndim} strides but got {len(self.strides)}" + assert len(self.block_shape) == ndim, \ + f"Expected block_shape to have {ndim} dimensions but got {len(self.strides)}" + validate_block_shape(self.block_shape) + assert self.strides[-1] == 1, "Last dimension must be contiguous" + assert isinstance(self.layout, (PaddedSharedLayout, SwizzledSharedLayout)), \ + "Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout" + if isinstance(self.layout, SwizzledSharedLayout): + assert self.layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout" + assert self.padding == "zero", "Only 'zero' padding is supported" + + @staticmethod + def from_tensor(tensor: Any, block_shape: List[int], layout: PaddedSharedLayout | SwizzledSharedLayout): + """ Create a TensorDescriptor object from a tensor. + + Args: + tensor (torch.Tensor): The input tensor. + block_shape (List[int]): The block shape of the tensor. + layout (PaddedSharedLayout | SwizzledSharedLayout): The layout of the tensor in shared memory. + + Returns: + tensor_descriptor: the created TensorDescriptor object + + """ + return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, layout) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/__init__.py new file mode 100644 index 0000000000..8f0b9bdc80 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/__init__.py @@ -0,0 +1,137 @@ +from ._core import ( + base_value, + base_type, + block_type, + broadcast, + cast, + constexpr, + dtype, + void, + int1, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float8e5, + float8e5b16, + float8e4nv, + float8e4b8, + float8e4b15, + float16, + bfloat16, + float32, + float64, + pointer_type, + shared_memory_descriptor, + tensor, + tuple, + tuple_type, + _unwrap_if_constexpr, + # API Functions + add, + allocate_shared_memory, + arange, + associative_scan, + assume, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bank_conflicts, + convert_layout, + device_assert, + device_print, + dot_fma, + expand_dims, + full, + fp4_to_fp, + gather, + num_warps, + num_ctas, + histogram, + inline_asm_elementwise, + join, + load, + map_elementwise, + max_constancy, + max_contiguous, + maximum, + minimum, + mul, + multiple_of, + num_programs, + permute, + program_id, + reduce, + reshape, + distributed_type, + shared_memory_descriptor_type, + set_auto_layout, + split, + static_assert, + static_print, + static_range, + store, + sub, + barrier, + to_linear_layout, + to_tensor, + warp_specialize, + where, +) +from ._layouts import ( + AutoLayout, + BlockedLayout, + SliceLayout, + DistributedLinearLayout, + DotOperandLayout, + NVMMADistributedLayout, + NVMMASharedLayout, + SwizzledSharedLayout, + PaddedSharedLayout, + SharedLinearLayout, + CoalescedLayout, +) +from ._math import ( + umulhi, + exp, + exp2, + fma, + log, + log2, + cos, + rsqrt, + sin, + sqrt, + sqrt_rn, + abs, + fdiv, + div_rn, + erf, + floor, + ceil, +) +from ._standard import ( + cdiv, + full_like, + max, + min, + ravel, + reduce_or, + sum, + xor_sum, + zeros, + zeros_like, +) + +from . import nvidia +from . import amd +from . import extra diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/_core.py b/third_party/mthreads/python/triton/experimental/gluon/language/_core.py new file mode 100644 index 0000000000..747ac9fb1a --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/_core.py @@ -0,0 +1,642 @@ +from __future__ import annotations +import inspect +import math +from typing import TypeVar, List, TYPE_CHECKING, Tuple +from functools import wraps +import warnings + +if TYPE_CHECKING: + from triton._C.libtriton.gluon_ir import GluonOpBuilder + from ._semantic import GluonSemantic + +from ._layouts import SharedLayout, DistributedLayout, BlockedLayout, DotOperandLayout, AutoLayout, CoalescedLayout +from triton._C.libtriton import ir +import triton.language.core as tl_core +from triton.language.core import ( + constexpr, + base_value, + base_type, + dtype, + block_type, # TODO: block type with layout info + pointer_type, + void, + int1, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float8e5, + float8e5b16, + float8e4nv, + float8e4b8, + float8e4b15, + float16, + bfloat16, + float32, + float64, + _unwrap_if_constexpr, + _unwrap_shape, + static_range, + tensor, + tuple, + tuple_type, +) + +# We define __all__ only to appease the python linter, these are not used in +# this file but we want to import them anyway so they are importable from here. +__all__ = [ + "constexpr", + "pointer_type", + "void", + "int1", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float8e5", + "float8e5b16", + "float8e4nv", + "float8e4b8", + "float8e4b15", + "float16", + "bfloat16", + "float32", + "float64", + "distributed_type", + "shared_memory_descriptor_type", + "static_range", + "tuple", + "tuple_type", + "num_ctas", +] + +T = TypeVar("T") + +# TODO: split these +GLUON_BUILTIN = "__triton_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_semantic" not in kwargs or kwargs["_semantic"] is None: + raise ValueError("Did you forget to add @triton.gluon.jit ? " + "(`_semantic` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, GLUON_BUILTIN, True) + wrapper.signature = inspect.signature(fn) + + return wrapper + + +# Explicitly import forwarded Triton language symbols so mypy sees them. +add = builtin(tl_core.add) +associative_scan = builtin(tl_core.associative_scan) +assume = builtin(tl_core.assume) +atomic_add = builtin(tl_core.atomic_add) +atomic_and = builtin(tl_core.atomic_and) +atomic_cas = builtin(tl_core.atomic_cas) +atomic_max = builtin(tl_core.atomic_max) +atomic_min = builtin(tl_core.atomic_min) +atomic_or = builtin(tl_core.atomic_or) +atomic_xchg = builtin(tl_core.atomic_xchg) +atomic_xor = builtin(tl_core.atomic_xor) +broadcast = builtin(tl_core.broadcast) +cast = builtin(tl_core.cast) +device_assert = builtin(tl_core.device_assert) +device_print = builtin(tl_core.device_print) +expand_dims = builtin(tl_core.expand_dims) +gather = builtin(tl_core.gather) +inline_asm_elementwise = builtin(tl_core.inline_asm_elementwise) +join = builtin(tl_core.join) +load = builtin(tl_core.load) +map_elementwise = builtin(tl_core.map_elementwise) +max_constancy = builtin(tl_core.max_constancy) +max_contiguous = builtin(tl_core.max_contiguous) +maximum = builtin(tl_core.maximum) +minimum = builtin(tl_core.minimum) +mul = builtin(tl_core.mul) +multiple_of = builtin(tl_core.multiple_of) +num_programs = builtin(tl_core.num_programs) +permute = builtin(tl_core.permute) +program_id = builtin(tl_core.program_id) +reduce = builtin(tl_core.reduce) +reshape = builtin(tl_core.reshape) +split = builtin(tl_core.split) +static_assert = builtin(tl_core.static_assert) +static_print = builtin(tl_core.static_print) +store = builtin(tl_core.store) +sub = builtin(tl_core.sub) +to_tensor = builtin(tl_core.to_tensor) +where = builtin(tl_core.where) + + +class distributed_type(block_type): + + def __init__(self, element_ty: dtype, shape: List[int], layout): + layout = _unwrap_if_constexpr(layout) + shape = _unwrap_if_constexpr(shape) + super().__init__(element_ty, shape) + self.layout = layout + self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>" + assert isinstance(layout, DistributedLayout), "tensor layout must be a DistributedLayout" + if not isinstance(layout, (AutoLayout, CoalescedLayout)): + assert len( + shape + ) == layout.rank, f"tensor shape and layout rank mismatch: shape={shape}, layout={layout}, shape rank={len(shape)}, layout rank={layout.rank}" + + def to_ir(self, builder: ir.builder) -> ir.type: + elem_ty = self.element_ty.to_ir(builder) + layout = self.layout._to_ir(builder) + return builder.get_distributed_ty(elem_ty, self.shape, layout) + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = "_".join(map(str, self.shape)) + layout = self.layout.mangle() + return f"{elt}S{shape}SL{layout}L" + + def with_element_ty(self, scalar_ty: dtype) -> block_type: + return distributed_type(scalar_ty, self.shape, self.layout) + + def __eq__(self, other) -> bool: + if not isinstance(other, distributed_type): + return False + return super().__eq__(other) and self.layout == other.layout + + +class shared_memory_descriptor_type(base_type): + + def __init__(self, element_ty, shape, layout, alloc_shape): + shape = _unwrap_if_constexpr(shape) + alloc_shape = _unwrap_if_constexpr(alloc_shape) + layout = _unwrap_if_constexpr(layout) + self.element_ty = element_ty + self.shape = shape + self.layout = layout + self.alloc_shape = alloc_shape + assert isinstance(layout, SharedLayout) + + def to_ir(self, builder: GluonOpBuilder) -> None: + return builder.get_shared_mem_desc_ty( + self.element_ty.to_ir(builder), + self.shape, + self.layout._to_ir(builder), + self.alloc_shape, + ) + + def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[shared_memory_descriptor, int]: + value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def __str__(self) -> str: + return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>" + + def __eq__(self, other) -> bool: + return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout + and self.alloc_shape == other.alloc_shape) + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + shape_str = "_".join([str(s) for s in self.shape]) + alloc_shape_str = "_".join([str(s) for s in self.alloc_shape]) + return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{alloc_shape_str}ASMD" + + +class shared_memory_descriptor(base_value): + """ + Represents a handle to a shared memory allocation in Gluon IR. + """ + + def __init__(self, handle, element_ty, shape, layout, alloc_shape): + self.handle = handle + self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def dtype(self): + return self.type.element_ty + + @property + def shape(self): + return self.type.shape + + @property + def rank(self): + return len(self.shape) + + @property + def numel(self) -> int: + return math.prod(self.shape) + + @property + def layout(self): + return self.type.layout + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, layout, _semantic: GluonSemantic = None) -> tensor: + """ + Load a tensor from shared memory. + + Args: + layout (DistributedLayout): The destination layout of the tensor. + + Returns: + tensor: A Gluon tensor containing the loaded data. + """ + layout = _unwrap_if_constexpr(layout) + return _semantic.shared_load(self, layout) + + @builtin + def store(self, value, _semantic: GluonSemantic = None) -> None: + """ + Store a tensor into shared memory. + + Args: + value (tensor): The tensor whose contents to store. + """ + return _semantic.shared_store(self, value) + + @builtin + def gather(self, indices, axis, _semantic: GluonSemantic = None) -> tensor: + """ + Gather elements from shared memory along a specified axis using an indices tensor. + + For each output position I, the operation reads from src where the coordinate at + the gather axis is replaced by indices[I]: + result[I] = src[I[0], ..., indices[I], ..., I[n]] + + Args: + indices (tensor): Tensor specifying which indices to gather along the axis. + axis (int): The axis along which to gather values. + + Returns: + tensor: Gluon tensor with the gathered elements (same shape as indices). + """ + indices = _unwrap_if_constexpr(indices) + axis = _unwrap_if_constexpr(axis) + return _semantic.shared_gather(self, indices, axis) + + @builtin + def scatter(self, values, indices, axis, _semantic: GluonSemantic = None): + """ + Scatter elements to shared memory along a specified axis using an indices tensor. + + For each input position I, the operation writes to dst where the coordinate at + the scatter axis is replaced by indices[I]: + dst[I[0], ..., indices[I], ..., I[n]] = values[I] + + Args: + values (tensor): Tensor with values to scatter (same shape as indices). + indices (tensor): Tensor specifying which indices to scatter to along the axis. + axis (int): The axis along which to scatter values. + """ + values = _unwrap_if_constexpr(values) + indices = _unwrap_if_constexpr(indices) + axis = _unwrap_if_constexpr(axis) + return _semantic.shared_scatter(self, values, indices, axis) + + def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + """ + Create a subview of shared memory by slicing along a given dimension. + + Args: + start (int): The starting index of the slice. + length (int): The length of the slice. + dim (int): The dimension to slice (default: 0). + + Returns: + shared_memory_descriptor: Descriptor for the sliced subview. + """ + start = _unwrap_if_constexpr(start) + length = _unwrap_if_constexpr(length) + dim = _unwrap_if_constexpr(dim) + return _semantic.memdesc_slice(self, start, length, dim) + + @builtin + def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + """ + Create a subview of shared memory by indexing along the first dimension. + + Args: + index (int): The index at which to take the subview. + + Returns: + shared_memory_descriptor: Descriptor for the indexed subview. + """ + index = _unwrap_if_constexpr(index) + return _semantic.memdesc_index(self, index) + + @builtin + def permute(self, order, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + """ + Permute the dimensions of the shared memory descriptor. + + Args: + order (List[int]): The new ordering of dimensions. + + Returns: + shared_memory_descriptor: Descriptor with permuted dimensions. + """ + order = [_unwrap_if_constexpr(o) for o in order] + return _semantic.memdesc_trans(self, order) + + @builtin + def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + """ + Reshape the shared memory descriptor to a new shape and layout. + + Args: + shape (List[int]): The target shape. + + Returns: + shared_memory_descriptor: Descriptor with the new shape and layout. + """ + shape = [_unwrap_if_constexpr(s) for s in shape] + + return _semantic.memdesc_reshape(self, shape) + + @builtin + def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + """ + Reinterpret the shared memory descriptor as a different dtype, shape, or layout. + + Args: + dtype (dtype): The new data type. + shape (List[int]): The new shape. + layout (SharedLayout): The new layout. + + Returns: + shared_memory_descriptor: Descriptor with updated type and layout. + """ + dtype = _unwrap_if_constexpr(dtype) + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + + return _semantic.memdesc_reinterpret(self, dtype, shape, layout) + + @builtin + def _keep_alive(self, _semantic: GluonSemantic = None) -> None: + """ + Dummy use to keep the shared memory descriptor alive. + """ + return _semantic.shared_dealloc(self) + + +@builtin +def arange(start, end, layout=None, _semantic=None): + """ + Generate a sequence tensor with values in [start, end) using a specified layout. + + Args: + start (int): Inclusive start of the sequence. + end (int): Exclusive end of the sequence. + layout (DistributedLayout): The layout of the output tensor. Defaults to AutoLayout. + + Returns: + tensor: A 1D tensor containing sequential values. + """ + start = _unwrap_if_constexpr(start) + end = _unwrap_if_constexpr(end) + layout = _unwrap_if_constexpr(layout) + return _semantic.arange(start, end, layout) + + +@builtin +def convert_layout(value, layout, assert_trivial=False, _semantic=None): + """ + Convert a tensor to a different distributed layout. + + Args: + value (tensor): The input tensor. + layout (DistributedLayout): The target layout. + assert_trivial (bool): If True, asserts that the conversion is trivial (no data movement). + + Returns: + tensor: The tensor with the new layout. + """ + layout = _unwrap_if_constexpr(layout) + return _semantic.convert_layout(value, layout, assert_trivial) + + +@builtin +def full(shape, value, dtype, layout=None, _semantic=None): + """ + Create a tensor filled with a scalar value, with specified shape, dtype, and layout. + + Args: + shape (Sequence[int]): The shape of the tensor. + value (int or float): The fill value. + dtype (dtype): The data type for the tensor. + layout (Optional[DistributedLayout]): The layout of the output tensor, defaults to AutoLayout(). + + Returns: + tensor: A tensor where every element equals value. + """ + shape = _unwrap_shape(shape) + value = _unwrap_if_constexpr(value) + dtype = _unwrap_if_constexpr(dtype) + layout = _unwrap_if_constexpr(layout) + return _semantic.full(shape, value, dtype, layout) + + +@builtin +def histogram(input, num_bins, mask=None, layout=None, _semantic=None, _generator=None): + """ + Compute a histogram of a 1D integer tensor. + + Args: + input (tensor): 1D tensor of integer values. + num_bins (int): Number of bins. Bins have width 1 and start at 0. + mask (Optional[tensor]): Boolean mask to exclude elements when False. + layout (DistributedLayout): Destination layout of the output histogram. + + Returns: + tensor: 1D int32 tensor of length `num_bins` with the requested layout. + """ + num_bins = _unwrap_if_constexpr(num_bins) + layout = _unwrap_if_constexpr(layout) + if mask is not None: + mask = _semantic.to_tensor(mask) + return _semantic.histogram(input, num_bins, mask, layout) + + +@builtin +def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None) -> shared_memory_descriptor: + """ + Allocate shared memory for a tensor with the given element type, shape, and layout. + + Args: + element_ty (dtype): The element data type. + shape (Sequence[int]): The dimensions of the shared memory. + layout (SharedLayout): The shared memory layout. + value (tensor, optional): Initial value to copy into shared memory. + + Returns: + shared_memory_descriptor: Descriptor for the allocated memory. + """ + element_ty = _unwrap_if_constexpr(element_ty) + shape = _unwrap_if_constexpr(shape) + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + return _semantic.allocate_shared(element_ty, shape, layout, value) + + +@builtin +def set_auto_layout(value, layout, _semantic=None): + """ + Set a tensor with AutoLayout to a concrete layout + + Args: + value (tensor): The input tensor. + layout (DistribtedLayout): The target layout. + + Returns: + tensor: The tensor with the new layout. + """ + layout = _unwrap_if_constexpr(layout) + return _semantic.set_auto_layout(value, layout) + + +@builtin +def fp4_to_fp(src, elem_type, axis, _semantic=None): + """ + Upcast a tensor from fp4 (e2m1) to another floating point type. + """ + axis = _unwrap_if_constexpr(axis) + elem_type = _unwrap_if_constexpr(elem_type) + return _semantic.fp4_to_fp(src, elem_type, axis) + + +@builtin +def warp_specialize(functions_and_args, worker_num_warps, worker_num_regs=None, _semantic=None, _generator=None): + """ + Create a warp-specialized execution region, partitioning work across warps. + + This forks the current execution into a "default partition" and an arbitrary number of + "worker partitons". The default partition is executed in the same :code:`num_warps` warps as + the parent region, and may accept tensor arguments and return tensors. Worker partitions are + executed in additional warps, which sit idle while executing the parent region. + + Note that calling warp_specialize recursively is not supported. + + Args: + functions_and_args (List[Tuple[Callable, Any]]): List of functions and arguments for each partition. The first of which is the default partition. + worker_num_warps (List[int]): Number of warps used for each worker partition. + worker_num_regs (List[int], optional): Number of registers for each worker partition. + If not None, will be used by backend for dynamic register reallocation. + + Returns: + Tuple[Any, ...]: Results from the default partition. + """ + worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps] + if worker_num_regs is not None: + worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs] + return _semantic.warp_specialize(functions_and_args, worker_num_warps, worker_num_regs, _generator) + + +@builtin +def num_warps(_semantic=None, _generator=None): + """ + Returns the number of warps that execute the current context, including in warp-specialized regions. + """ + return _semantic.num_warps(_generator) + + +@builtin +def num_ctas(_semantic=None): + """ + Returns the number of CTAs in the current kernel + """ + return _semantic.num_ctas() + + +@builtin +def barrier(*, cluster: bool = False, _semantic=None): + """ + Insert a barrier to synchronize threads within a CTA, or across a cluster. + + Args: + cluster (bool): Whether to synchronize across the CTA cluster. + """ + cluster = _unwrap_if_constexpr(cluster) + num_ctas = _unwrap_if_constexpr(_semantic.num_ctas()) + if num_ctas == 1 or not cluster: + return _semantic.debug_barrier() + _semantic.builder.create_cluster_sync() + + +@builtin +def bank_conflicts(distr_ty, shared_ty, _semantic=None) -> int: + """ + Count the bank conflicts per wavefront of each instruction generated when + reading/writing the distributed tensor from/to the shared memory descriptor + using ld.shared/st.shared instructions. + + We define a bank conflict of N to be the excess number of memory accesses that each + wavefront needs to access the shared memory descriptor. When one uses no ld/st + vectorization, this is equal to t he number of excess memory accesses per instruction. + + Args: + distr_ty (distributed_type): The distributed tensor. + shared_ty (shared_memory_descriptor_type): The shared memory descriptor. + + Returns: + int: The number of bank conflicts. + """ + distr_ty = _unwrap_if_constexpr(distr_ty) + shared_ty = _unwrap_if_constexpr(shared_ty) + return _semantic.bank_conflicts(distr_ty, shared_ty) + + +@builtin +def to_linear_layout(layout, shape, _semantic=None): + layout = _unwrap_if_constexpr(layout) + shape = _unwrap_shape(shape) + return _semantic.to_linear_layout(layout, shape) + + +@builtin +def dot_fma(a, b, acc, _semantic=None): + assert isinstance(a, tensor), "a must be a tensor" + assert isinstance(b, tensor), "b must be a tensor" + assert isinstance(acc, tensor), "acc must be a tensor" + + mma_layout = acc.type.layout + assert isinstance(mma_layout, BlockedLayout), "acc must have a BlockedLayout" + assert isinstance(a.type.layout, DotOperandLayout), "a must have a DotOperandLayout" + assert isinstance(b.type.layout, DotOperandLayout), "b must have a DotOperandLayout" + assert a.type.layout.parent == mma_layout, "a's parent layout must be the same as acc's layout" + assert b.type.layout.parent == mma_layout, "b's parent layout must be the same as acc's layout" + assert a.type.layout.operand_index == 0, "a's operand index must be 0" + assert b.type.layout.operand_index == 1, "b's operand index must be 1" + + M, N = acc.shape + K = a.shape[1] + if M * N * K > 2**19: + warnings.warn(f"Large dot FMA instruction size {M}x{N}x{K} may have slow compile times") + + handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle + return tensor(handle, acc.type) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/_layouts.py b/third_party/mthreads/python/triton/experimental/gluon/language/_layouts.py new file mode 100644 index 0000000000..d2a2780c56 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/_layouts.py @@ -0,0 +1,704 @@ +from dataclasses import dataclass, field +import itertools +import math +from typing import List + +from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type +from triton.runtime.jit import constexpr_function +from triton._C.libtriton import gluon_ir + + +class DistributedLayout: + """ + Base class for distributed memory layouts in Gluon IR. + """ + + @property + def type(self): + return constexpr_type(self) + + @property + def rank(self): + raise NotImplementedError("DistributedLayout subclasses must define rank") + + def format_tensor_view(self, shape: list[int]) -> str: + return gluon_ir.get_layout_view(self, [_unwrap_if_constexpr(s) for s in shape], False) + + def format_hardware_view(self, shape: list[int]) -> str: + return gluon_ir.get_layout_view(self, [_unwrap_if_constexpr(s) for s in shape], True) + + +@dataclass(frozen=True) +class AutoLayout(DistributedLayout): + + def _to_ir(self, builder): + return builder.get_auto_layout() + + def mangle(self): + return "AL" + + @property + def rank(self): + raise ValueError("AutoLayout has no rank") + + +@dataclass(frozen=True) +class CoalescedLayout(DistributedLayout): + + def _to_ir(self, builder): + return builder.get_coalesced_layout() + + def mangle(self): + return "CL" + + @property + def rank(self): + raise ValueError("CoalescedLayout has no rank") + + +@dataclass(frozen=True) +class BlockedLayout(DistributedLayout): + """ + Represents a blocked layout, partitioning a tensor across threads, warps, and CTAs. + + Args: + size_per_thread (List[int]): Number of elements per thread per dimension. + threads_per_warp (List[int]): Number of threads per warp per dimension. + warps_per_cta (List[int]): Number of warps per CTA per dimension. + order (List[int]): The ordering of dimensions for partitioning. + cga_layout (Optional[List[List[int]]]): Bases describing how CTAs tile each dimension. + """ + size_per_thread: List[int] + threads_per_warp: List[int] + warps_per_cta: List[int] + order: List[int] + cga_layout: List[List[int]] = field(default_factory=list) + + def __post_init__(self): + super().__setattr__("size_per_thread", _unwrap_if_constexpr(self.size_per_thread)) + super().__setattr__("threads_per_warp", _unwrap_if_constexpr(self.threads_per_warp)) + super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) + super().__setattr__("order", _unwrap_if_constexpr(self.order)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(self.cga_layout)) + + rank = len(self.size_per_thread) + assert len(self.threads_per_warp) == rank + assert len(self.warps_per_cta) == rank + assert len(self.order) == rank + + def _to_ir(self, builder): + return builder.get_blocked_layout( + self.size_per_thread, + self.threads_per_warp, + self.warps_per_cta, + self.order, + self.cga_layout, + ) + + def mangle(self) -> str: + + def stringify(x): + if x is None: + return "" + return "_".join(map(str, x)) + + size_per_thread = stringify(self.size_per_thread) + threads_per_warp = stringify(self.threads_per_warp) + warps_per_cta = stringify(self.warps_per_cta) + order = stringify(self.order) + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"B{size_per_thread}_{threads_per_warp}_{warps_per_cta}_{order}_{cga_layout}B" + + def __hash__(self): + return hash((tuple(self.size_per_thread), tuple(self.threads_per_warp), tuple(self.warps_per_cta), + tuple(self.order), tuple(tuple(vec) for vec in self.cga_layout))) + + @property + def rank(self): + return len(self.order) + + +@dataclass(frozen=True) +class SliceLayout(DistributedLayout): + """ + Represents a layout corresponding to slicing a distributed tensor along one dimension. + + Args: + dim (int): The dimension index to slice. + parent (DistributedLayout): The parent layout before slicing. + """ + dim: int + parent: DistributedLayout + + def __post_init__(self): + super().__setattr__("dim", _unwrap_if_constexpr(self.dim)) + super().__setattr__("parent", _unwrap_if_constexpr(self.parent)) + + def _to_ir(self, builder): + return builder.get_slice_layout( + self.dim, + self.parent._to_ir(builder), + ) + + def mangle(self) -> str: + return f"SL{self.dim}_{self.parent.mangle()}SL" + + def __hash__(self): + return hash((self.dim, self.parent)) + + @property + def rank(self): + return self.parent.rank - 1 + + @property + def cga_layout(self): + parent_cga_layout = self.parent.cga_layout + if not parent_cga_layout: + return [] + + rank = self.parent.rank + assert 0 <= self.dim < rank + return [basis[:self.dim] + basis[self.dim + 1:] for basis in parent_cga_layout] + + +@dataclass(frozen=True) +class DistributedLinearLayout(DistributedLayout): + """ + Represents a linear distributed layout with explicit bases at register, lane, warp, and block levels. + See: https://arxiv.org/abs/2505.23819 for reference. + + Args: + reg_bases (List[List[int]]): Bases for register-level distribution. + lane_bases (List[List[int]]): Bases for lane-level distribution. + warp_bases (List[List[int]]): Bases for warp-level distribution. + block_bases (List[List[int]]): Bases for block-level distribution. + shape (List[int]): The tensor global shape. + """ + reg_bases: List[List[int]] + lane_bases: List[List[int]] + warp_bases: List[List[int]] + block_bases: List[List[int]] + shape: List[int] + + def __post_init__(self): + super().__setattr__("reg_bases", _unwrap_shape(self.reg_bases)) + super().__setattr__("lane_bases", _unwrap_shape(self.lane_bases)) + super().__setattr__("warp_bases", _unwrap_shape(self.warp_bases)) + super().__setattr__("block_bases", _unwrap_shape(self.block_bases)) + super().__setattr__("shape", _unwrap_shape(self.shape)) + + rank = len(self.shape) + + for basis in self.reg_bases: + assert len(basis) == rank + for basis in self.lane_bases: + assert len(basis) == rank + for basis in self.warp_bases: + assert len(basis) == rank + for basis in self.block_bases: + assert len(basis) == rank + + def _to_ir(self, builder): + return builder.get_distributed_linear_layout(self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases, + self.shape) + + def mangle(self): + return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL" + + def __hash__(self): + return hash(( + tuple(map(tuple, self.reg_bases)), + tuple(map(tuple, self.lane_bases)), + tuple(map(tuple, self.warp_bases)), + tuple(map(tuple, self.block_bases)), + tuple(self.shape), + )) + + @property + def rank(self): + return len(self.shape) + + +@dataclass(frozen=True) +class DotOperandLayout(DistributedLayout): + """ + Represents a layout for a dot operand. + + Args: + operand_index (int): 0 for LHS and 1 for RHS of the dot operation. + parent (DistributedLayout): The parent layout, representing the MMA. + k_width (int): Number of elements per 32-bits. + """ + operand_index: int + parent: DistributedLayout + k_width: int + + def __post_init__(self): + super().__setattr__("operand_index", _unwrap_if_constexpr(self.operand_index)) + super().__setattr__("parent", _unwrap_if_constexpr(self.parent)) + super().__setattr__("k_width", _unwrap_if_constexpr(self.k_width)) + + def _to_ir(self, builder): + return builder.get_dot_operand_layout(self.operand_index, self.parent._to_ir(builder), self.k_width) + + def mangle(self) -> str: + return f"DO{self.operand_index}_{self.parent.mangle()}_{self.k_width}DO" + + def __hash__(self): + return hash((self.operand_index, self.parent, self.k_width)) + + @property + def rank(self): + return self.parent.rank + + @property + def cga_layout(self): + parent_cga_layout = _unwrap_if_constexpr(getattr(self.parent, "cga_layout", [])) or [] + if not parent_cga_layout: + return [] + + rank = self.parent.rank + assert all(len(basis) == rank for basis in parent_cga_layout) + + k_dim = rank - 1 if self.operand_index == 0 else rank - 2 + assert 0 <= k_dim < rank + + derived = [] + for basis in parent_cga_layout: + new_basis = list(basis) + new_basis[k_dim] = 0 + derived.append(new_basis) + return derived + + +@dataclass(frozen=True, eq=True) +class NVMMADistributedLayout(DistributedLayout): + """ + Represents a layout for NVIDIA MMA (tensor core) operations. + + Args: + version (List[int]): Version identifier for the MMA instruction. + warps_per_cta (List[int]): Number of warps per CTA. + instr_shape (List[int]): Instruction shape for MMA. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. + """ + version: List[int] + warps_per_cta: List[int] + instr_shape: List[int] + cga_layout: List[List[int]] = field(default_factory=list) + + def __post_init__(self): + super().__setattr__("version", _unwrap_if_constexpr(self.version)) + super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) + super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(self.cga_layout)) + + def _to_ir(self, builder): + return builder.get_mma_layout( + self.version, + self.warps_per_cta, + self.cga_layout, + self.instr_shape, + ) + + def mangle(self) -> str: + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_{cga_layout}_MMA" + + def __hash__(self): + return hash((tuple(self.version), tuple(self.warps_per_cta), tuple(self.instr_shape), + tuple(tuple(vec) for vec in self.cga_layout))) + + @property + def rank(self): + return len(self.warps_per_cta) + + +class SharedLayout: + """ + Base class for shared memory layouts in Gluon IR. + """ + + @property + def type(self): + return constexpr_type(self) + + def format_tensor_view(self, shape: list[int]) -> str: + return gluon_ir.get_layout_view(self, [_unwrap_if_constexpr(s) for s in shape], False) + + def format_hardware_view(self, shape: list[int]) -> str: + return gluon_ir.get_layout_view(self, [_unwrap_if_constexpr(s) for s in shape], True) + + +@constexpr_function +def _get_shape_per_cta(shape, cga_layout): + if not cga_layout: + return shape + shape_per_cta = list(shape) + rank = len(cga_layout[0]) + cga_shape = [0] * rank + for basis in cga_layout: + assert len(basis) == rank + for i in range(rank): + cga_shape[i] = max(cga_shape[i], basis[i]) + # The shape is the largest stride * 2, or 1 if the stride was always zero + for i in range(rank): + if cga_shape[i] == 0: + cga_shape[i] = 1 + else: + cga_shape[i] *= 2 + for dim in range(rank): + assert shape_per_cta[dim] % cga_shape[dim] == 0, f"Shape {shape} is not divisible by CGA layout {cga_layout}" + shape_per_cta[dim] //= cga_shape[dim] + return shape_per_cta + + +@dataclass(frozen=True) +class NVMMASharedLayout(SharedLayout): + """ + Represents a layout for shared memory suitable for NVIDIA MMA operations. + + Args: + swizzle_byte_width (int): Width in bytes for swizzling. + element_bitwidth (int): Bitwidth of element type. + rank (int): Rank of the tensor. + transposed (bool): Whether the layout is transposed. + fp4_padded (bool): Whether FP4 padding is used. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. + """ + swizzle_byte_width: int + element_bitwidth: int + rank: int = 2 + transposed: bool = False + fp4_padded: bool = False + cga_layout: List[List[int]] = field(default_factory=list) + + def __post_init__(self): + super().__setattr__("swizzle_byte_width", _unwrap_if_constexpr(self.swizzle_byte_width)) + super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth)) + super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed)) + super().__setattr__("fp4_padded", _unwrap_if_constexpr(self.fp4_padded)) + + # TODO: Make rank optional and check that (rank or cga_layout) + cga_layout = self.cga_layout or [] + if cga_layout: + assert len(cga_layout[0]) == self.rank + + super().__setattr__("rank", _unwrap_if_constexpr(self.rank)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(cga_layout)) + + assert self.element_bitwidth in [8, 16, 32, 64] + assert self.swizzle_byte_width in [0, 32, 64, 128] + + if self.fp4_padded: + assert self.swizzle_byte_width == 128, "fp4_padded only supports 128 byte swizzling" + assert self.element_bitwidth == 8, "fp4_padded is only supported for element_bitwidth=8" + + def _to_ir(self, builder): + return builder.get_nvmma_shared_layout( + self.swizzle_byte_width, + self.element_bitwidth, + self.transposed, + self.fp4_padded, + self.cga_layout, + self.rank, + ) + + @staticmethod + @constexpr_function + def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, cga_layout=None): + """Returns an NVMMASharedLayout with default swizzling for a given shape. + + This picks the largest swizzle pattern compatible with the shape, which + allows emitting the fewest TMA or MMA messages. + """ + packing_factor = 2 if fp4_padded else 1 + shape_per_cta = block_shape if cga_layout is None else _get_shape_per_cta(block_shape, cga_layout) + rank = len(block_shape) + if transposed: + shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1] + contig_dim_size = shape_per_cta[-1] * packing_factor + contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8 + if contig_dim_bytes >= 128 and contig_dim_bytes % 128 == 0: + swizzle_byte_width = 128 + elif contig_dim_bytes >= 64 and contig_dim_bytes % 64 == 0: + swizzle_byte_width = 64 + elif contig_dim_bytes >= 32 and contig_dim_bytes % 32 == 0: + swizzle_byte_width = 32 + else: + swizzle_byte_width = 0 + + flatten_outer_dim = 1 + for size in shape_per_cta[:-1]: + flatten_outer_dim *= size + if len(block_shape) < 2 or flatten_outer_dim < 8: + swizzle_byte_width = 0 + + return NVMMASharedLayout( + swizzle_byte_width=swizzle_byte_width, + element_bitwidth=dtype.primitive_bitwidth, + rank=rank, + transposed=transposed, + fp4_padded=fp4_padded, + cga_layout=cga_layout, + ) + + def mangle(self) -> str: + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_{cga_layout}_NVMMA" + + def __hash__(self): + return hash((self.swizzle_byte_width, self.element_bitwidth, self.rank, self.transposed, self.fp4_padded, + tuple(tuple(vec) for vec in self.cga_layout) if self.cga_layout else None)) + + +@dataclass(frozen=True, eq=True) +class SwizzledSharedLayout(SharedLayout): + """ + Represents a generic swizzled shared memory layout. + + Args: + vec (int): Vector width for swizzling. + per_phase (int): Elements per swizzle phase. + max_phase (int): Maximum number of swizzle phases. + order (List[int]): Dimension ordering for swizzling. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. + """ + vec: int + per_phase: int + max_phase: int + order: List[int] + cga_layout: List[List[int]] = field(default_factory=list) + + def __post_init__(self): + super().__setattr__("vec", _unwrap_if_constexpr(self.vec)) + super().__setattr__("per_phase", _unwrap_if_constexpr(self.per_phase)) + super().__setattr__("max_phase", _unwrap_if_constexpr(self.max_phase)) + super().__setattr__("order", _unwrap_if_constexpr(self.order)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(self.cga_layout)) + + def _to_ir(self, builder): + return builder.get_swizzled_shared_layout( + self.vec, + self.per_phase, + self.max_phase, + self.order, + self.cga_layout, + ) + + def mangle(self) -> str: + + def stringify(x): + if x is None: + return "" + return "_".join(map(str, x)) + + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{cga_layout}_SSS" + + def __hash__(self): + return hash( + (self.vec, self.per_phase, self.max_phase, tuple(self.order), tuple(tuple(vec) for vec in self.cga_layout))) + + +@dataclass(frozen=True, eq=True) +class PaddedSharedLayout(SharedLayout): + """ + Represents a layout for the access to shared memory. Compared to SwizzledSharedLayout, + it combined padding and element reordering via linear transformation (e.g. row permutation) + to avoid shared memory bank conflicts. After every interval tensor elements, the + corresponding number of padding elements are inserted. If a position corresponds to + multiple intervals, the padding amounts are summed. + + In the following example of a tensor, + `eM` represents original elements in the and `pN` represents padded element. + + Before padding, the shared memory looks like: + [e0, e1, + e2, e3, + e4, e5, + e6, e7, + ...] + + After padding with interval-padding list [[2, 1], [4, 2]] with an identity remapping, + the shared memory will be + [e0, e1, p0, + e2, e3, p1, p2, p3, + e4, e5, p4, + e6, e7, p5, p6, p7, + ...] + + Furthermore this encoding allows for a linear remapping from the 1-D shared + memory offset to logical n-D tensor elements. The remapping is given in the form + of linear bases mapping from offset to [dim0, dim1...dimN-1]. + See LinearLayout.h for more details how linear layouts are applied to remap + elements. + Some concrete examples using `xN` and `yN` to mean the logical n-D tensor elements + and `pN` to mean padding: + + After padding for shape = [8] with interval-padding list [[2, 2]], offset_bases = [[2], [1]] and block_bases = []: + [x0, x2, p0 p1, x1, x3] + + After padding for shape = [8, 4] with interval_padding_pairs = [[8, 1]], offset_bases = [[0, 1], [0, 2], /*gap, stride by 2 rows*/[2, 0], [4, 0], [1, 0]]] and block_bases = []: + [ + x0y0, x0y1, x0y2, x0y3, + x2y0, x2y1, x2y2, x2y3, + p0, + x4y0, x4y1, x4y2, x4y3, + x6y0, x6y1, x6y2, x6y3, + p1, + x1y0, x1y1, x1y2, x1y3, + x3y0, x3y1, x3y2, x3y3, + p2, + x5y0, x5y1, x5y2, x5y3, + x7y0, x7y1, x7y2, x7y3, + ] + + Args: + interval_padding_pairs (List[int]): List of [interval, padding] pair and both interval and padding must be powers of 2. + offset_bases (List[int]): Bases for shared memory offsets + block_bases (List[List[int]]): Bases for block-level shared memory offsets. + shape (List[int]): n-D logical shared memory shape + """ + interval_padding_pairs: List[List[int]] + offset_bases: List[List[int]] + block_bases: List[List[int]] + shape: List[int] + + def __post_init__(self): + super().__setattr__("interval_padding_pairs", _unwrap_shape(self.interval_padding_pairs)) + super().__setattr__("offset_bases", _unwrap_shape(self.offset_bases)) + super().__setattr__("block_bases", _unwrap_shape(self.block_bases)) + super().__setattr__("shape", _unwrap_shape(self.shape)) + + rank = len(self.shape) + + for basis in self.offset_bases: + assert len(basis) == rank + for basis in self.block_bases: + assert len(basis) == rank + + self.verify() + + def _to_ir(self, builder): + intervals, paddings = zip(*self.interval_padding_pairs) + return builder.get_padded_shared_layout(intervals, paddings, self.offset_bases, self.block_bases, self.shape) + + def mangle(self) -> str: + return f"PaddedShared_{self.interval_padding_pairs}_{self.offset_bases}_{self.block_bases}_{self.shape}_PaddedShared" + + def verify(self): + pairs = self.interval_padding_pairs + assert len(pairs) > 0, "PaddedSharedLayout interval_padding_pairs must have at least one interval-padding pair" + assert all(len(pair) == 2 for pair in pairs) + intervals, paddings = zip(*pairs) + + unique_intervals = list(set(intervals)) + assert len(unique_intervals) == len(intervals) + + is_power_of_2 = lambda n: n > 0 and n & (n - 1) == 0 + assert all(is_power_of_2(n) for n in intervals), "PaddedSharedLayout interval values must all be power of two" + assert all(is_power_of_2(n) for n in paddings), "PaddedSharedLayout padding values must all be power of two" + + rank = len(self.shape) + assert rank > 0, "PaddedSharedLayout order must not be empty" + + @staticmethod + @constexpr_function + def with_identity_for(interval_padding_pairs, shape, order): + """Returns a PaddedSharedLayout with the given interval and padding pairs and an identity mapping as the linear component for the given shape and order. + """ + assert len(shape) == len(order) + is_power_of_2 = lambda n: n > 0 and n & (n - 1) == 0 + assert all(is_power_of_2(n) for n in shape) + + rank = len(shape) + # Create a idendity mapping based on shape + order + offset_bases = [] + for dim in order: + for basis in range(int(math.log2(shape[dim]))): + offset_bases.append([1 << basis if i == dim else 0 for i in range(rank)]) + + return PaddedSharedLayout(interval_padding_pairs, offset_bases, [], shape) + + def __hash__(self): + return hash((tuple(map(tuple, self.interval_padding_pairs)), tuple(map(tuple, self.offset_bases)), + tuple(map(tuple, self.block_bases)), tuple(self.shape))) + + +@dataclass(frozen=True) +class SharedLinearLayout(SharedLayout): + """Represents a shared memory layout defined via an explicit LinearLayout.""" + + offset_bases: List[List[int]] + block_bases: List[List[int]] = field(default_factory=list) + alignment: int = 16 + + def __post_init__(self): + super().__setattr__("offset_bases", _unwrap_shape(self.offset_bases)) + super().__setattr__("block_bases", _unwrap_shape(self.block_bases)) + super().__setattr__("alignment", _unwrap_if_constexpr(self.alignment)) + + assert len(self.offset_bases) != 0, "SharedLinearLayout offset_bases must not be empty" + rank = len(self.offset_bases[0]) + assert rank > 0, "SharedLinearLayout offset_bases must not be empty" + for basis in self.offset_bases: + assert len(basis) == rank + for basis in self.block_bases: + assert len(basis) == rank + assert self.alignment > 0 and (self.alignment & (self.alignment - 1)) == 0, \ + "SharedLinearLayout alignment must be a positive power of two" + + def _to_ir(self, builder): + return builder.get_shared_linear_layout(self.offset_bases, self.block_bases, self.alignment) + + def mangle(self) -> str: + return f"SharedLinear_{self.offset_bases}_{self.block_bases}_{self.alignment}_SharedLinear" + + @property + def shape(self): + rank = len(self.offset_bases[0]) + max_stride = [1] * rank + for b in itertools.chain(self.offset_bases, self.block_bases): + for i, bi in enumerate(b): + max_stride[i] = max(max_stride[i], bi) + return [2 * s for s in max_stride] + + def __hash__(self): + return hash(( + tuple(map(tuple, self.offset_bases)), + tuple(map(tuple, self.block_bases)), + self.alignment, + )) + + +# Python impl of LinearEncodingAttr::basesPerDim +def bases_per_dim(bases, rank, skip_broadcast=True): + result = [1] * rank + + if not bases: + return result + + non_zero_idx = None + + for basis in bases: + # Find the first non-zero index in the current basis + idx = next((i for i, v in enumerate(basis) if v != 0), None) + if idx is not None: + non_zero_idx = idx + result[idx] *= 2 + elif not skip_broadcast: + # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index + assert non_zero_idx is not None + result[non_zero_idx] *= 2 + + return result + + +def warps_per_cta(layout, shape): + if isinstance(layout, DistributedLinearLayout): + return bases_per_dim(layout.warp_bases, len(shape)) + elif isinstance(layout, (SliceLayout, DotOperandLayout)): + return warps_per_cta(layout.parent, shape) + else: + return layout.warps_per_cta diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/_math.py b/third_party/mthreads/python/triton/experimental/gluon/language/_math.py new file mode 100644 index 0000000000..b9c8d7605e --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/_math.py @@ -0,0 +1,20 @@ +import triton.language.math as tl_math +from ._core import builtin + +umulhi = builtin(tl_math.umulhi) +exp = builtin(tl_math.exp) +exp2 = builtin(tl_math.exp2) +fma = builtin(tl_math.fma) +log = builtin(tl_math.log) +log2 = builtin(tl_math.log2) +cos = builtin(tl_math.cos) +rsqrt = builtin(tl_math.rsqrt) +sin = builtin(tl_math.sin) +sqrt = builtin(tl_math.sqrt) +sqrt_rn = builtin(tl_math.sqrt_rn) +abs = builtin(tl_math.abs) +fdiv = builtin(tl_math.fdiv) +div_rn = builtin(tl_math.div_rn) +erf = builtin(tl_math.erf) +floor = builtin(tl_math.floor) +ceil = builtin(tl_math.ceil) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/_semantic.py b/third_party/mthreads/python/triton/experimental/gluon/language/_semantic.py new file mode 100644 index 0000000000..13162a71f4 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/_semantic.py @@ -0,0 +1,607 @@ +from typing import Sequence, List, TypeVar, Tuple, Callable +import math +from triton.language.semantic import TritonSemantic +from . import _core as ttgl +from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout, SharedLinearLayout +from triton._C.libtriton.gluon_ir import GluonOpBuilder, compute_tmem_reg_layout +from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values + +TensorTy = TypeVar("TensorTy") + + +def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError): + if not cond: + raise category(msg_fn()) + + +def _is_int_list(value): + return isinstance(value, Sequence) and all(isinstance(i, int) for i in value) + + +def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant, cga_layout=None): + _check(isinstance(instr_variant, str), lambda: "instr_variant must be a string") + _check(instr_variant in ("32x32b", "16x64b", "16x128b", "16x256b", "16x32bx2", "32x32b_splitn"), + lambda: f"unknown instr_variant: {instr_variant}") + _check(isinstance(num_warps, int), lambda: f"num_warps must be an int but got {type(num_warps)!r}") + _check(num_warps >= 4 and (num_warps & (num_warps - 1)) == 0, lambda: "num_warps must be a power of two and >= 4") + + shape = list(shape) + _check(all(isinstance(dim, int) for dim in shape), lambda: f"shape entries must be ints but got {shape}") + rank = len(shape) + _check(rank == 2, lambda: "expected a 2D tensor") + + if cga_layout is None: + cga_layout = [] + splitn = instr_variant == "32x32b_splitn" + atom_variant = "32x32b" if splitn else instr_variant + + if cga_layout: + for basis in cga_layout: + _check(len(basis) == rank, lambda: "cga_layout basis rank mismatch") + + layout_obj = compute_tmem_reg_layout( + element_ty, + shape, + layout, + num_warps, + atom_variant, + cga_layout, + ) + _check(layout_obj is not None, + lambda: f"TMEM layout '{atom_variant}' unsupported for shape {shape} and num_warps {num_warps}") + + if splitn: + N = shape[1] + if not layout_obj.reg_bases: + # We cannot use this layout in a load or a store ATM due to a PTX bug! + # You can work around this by loading to 32x32b and follow by a convert_layout to this layout. + _check(layout_obj.lane_bases[-1] == [0, N // 2], + lambda: f"splitn with 1 register requires the last lane basis to be [0, N / 2]. Got {layout_obj}") + layout_obj.reg_bases.append([0, N // 2]) + layout_obj.lane_bases[-1] = [0, 0] + elif layout_obj.reg_bases[-1] != [0, N // 2]: + bitwidth = element_ty.primitive_bitwidth + num_reg = 2**len(layout_obj.reg_bases) + _check( + num_reg > 32 // bitwidth, lambda: "To be able to `tmem.load` into `tl.split` you need to have more " + f"than {32 // bitwidth} {bitwidth}-bit registers, as you need to use " + "the instruction 32x32b.x1 twice. You can always load into " + "instr_variant=\"32x32b\" and then convert_layout to this layout otherwise.") + + reg_bases = layout_obj.reg_bases + for bases_str in ("lane_bases", "warp_bases"): + bases = getattr(layout_obj, bases_str) + for i, basis in enumerate(bases): + if basis == [0, N // 2]: + reg_bases[-1], bases[i] = bases[i], reg_bases[-1] + return layout_obj + assert False, f"splitn requires at least one basis of [0, N / 2]. Got {layout}" + return layout_obj + + +_compute_tmem_reg_layout.__triton_builtin__ = True + + +class GluonCallerContext: + + def __init__(self, num_warps: int): + self.num_warps = num_warps + + def mangle(self): + return f"_NW{self.num_warps}" + + def initialize_callee(self, fn, builder): + fn.set_attr("ttg.num-warps", builder.get_int32_attr(self.num_warps)) + + +class GluonSemantic(TritonSemantic[TensorTy]): + tensor = ttgl.tensor + lang = ttgl + + builder: GluonOpBuilder + + def __init__(self, builder: GluonOpBuilder): + self.builder = builder + + def _wrap_handle_infer_layout(self, handle, scalar_ty, shape): + if shape == []: + ty = scalar_ty + else: + ty = ttgl.distributed_type(scalar_ty, shape, self.builder.get_gluon_layout_from_tensor(handle)) + return self.tensor(handle, ty) + + def _wrap_tensor_infer_layout(self, tensor): + return self._wrap_handle_infer_layout(tensor.handle, tensor.type.scalar, tensor.shape) + + def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]): + if len(lhs_shape) != len(rhs_shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}") + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + return ret_shape + + def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: + dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if axis < 0: + axis += len(input.shape) + + _check(isinstance(input.type, ttgl.distributed_type), + lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}") + layout = input.type.layout + _check(isinstance(layout, (SliceLayout, AutoLayout, CoalescedLayout)), + lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}") + _check( + isinstance(layout, (AutoLayout, CoalescedLayout)) or layout.dim == axis, + lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}") + + handle = self.builder.create_expand_dims(input.handle, axis) + return self._wrap_handle_infer_layout(handle, input.type.scalar, dst_shape) + + def join(self, a: TensorTy, b: TensorTy) -> TensorTy: + a, b = self.broadcast_impl_value(a, b) + _check(a.shape != [], lambda: "Cannot join scalars in gluon") + value = super().join(a, b) + return self._wrap_tensor_infer_layout(value) + + def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]: + lhs, rhs = super().split(a) + return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs) + + def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy: + value = super().permute(input, dims) + return self._wrap_tensor_infer_layout(value) + + def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy: + _check(isinstance(input.type, ttgl.distributed_type), + lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}") + src_shape = input.type.get_block_shapes() + _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout) + handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder)) + return self.tensor(handle, ret_ty) + + def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy: + lhs_ty = lhs.type + rhs_ty = rhs.type + + if not lhs_ty.is_block() or not rhs_ty.is_block(): + return super().broadcast_impl_value(lhs, rhs) + + _check(isinstance(lhs_ty, ttgl.distributed_type), + lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}") + _check(isinstance(rhs_ty, ttgl.distributed_type), + lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}") + + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape) + + is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout) + is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout) + if is_lhs_auto and not is_rhs_auto: + lhs = self.set_auto_layout(lhs, rhs_ty.layout) + elif is_rhs_auto and not is_lhs_auto: + rhs = self.set_auto_layout(rhs, lhs_ty.layout) + elif lhs_ty.layout != rhs_ty.layout: + raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}") + + lhs = self.broadcast_impl_shape(lhs, ret_shape) + rhs = self.broadcast_impl_shape(rhs, ret_shape) + return lhs, rhs + + def arange(self, start, end, layout): + shape = [end - start] + if layout is None: + layout = AutoLayout() + ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout) + return super().arange(start, end, ret_ty=ret_ty) + + def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool): + _check(not can_reorder, lambda: "can_reorder is not supported in gluon") + value = super().reshape(input, dst_shape, can_reorder) + return self._wrap_tensor_infer_layout(value) + + def splat(self, value, shape, layout): + if len(shape) == 0: + return value + ret_ty = ttgl.distributed_type(value.dtype, shape, layout) + handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle) + return ttgl.tensor(handle, ret_ty) + + def full(self, shape, value, dtype, layout): + scalar = self.make_scalar(value, dtype) + if layout is None: + layout = AutoLayout() + return self.splat(scalar, shape, layout) + + def convert_layout(self, value, layout, assert_trivial=False): + ty = value.type + _check(isinstance(ty, ttgl.distributed_type), + lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}") + _check(isinstance(layout, ttgl.DistributedLayout), + lambda: f"expected 'layout' to be a DistributedLayout but got {layout}") + ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout) + ret_ty_ir = ret_ty.to_ir(self.builder) + if assert_trivial and not self.builder.is_convert_layout_trivial(ret_ty_ir, value.handle): + raise TypeError(f"layout conversion from {ty.layout} to {layout} is not trivial.\n" + f"The linear layouts are:\n{self.to_linear_layout(ty.layout, ty.shape)}\n" + f"{self.to_linear_layout(layout, ty.shape)}") + handle = self.builder.create_convert_layout(ret_ty_ir, value.handle) + return ttgl.tensor(handle, ret_ty) + + def allocate_shared(self, element_ty, shape, layout, value): + _check(isinstance(element_ty, ttgl.dtype), lambda: f"expected 'element_ty' to be a dtype but got {element_ty}") + _check(_is_int_list(shape), lambda: f"all elements of 'shape' must be integers but got {shape}") + _check(isinstance(layout, ttgl.SharedLayout), + lambda: f"expected 'layout' to be a SharedLayout but got {layout}") + ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape) + if value is not None: + handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle) + else: + handle = self.builder.create_local_alloc(ty.to_ir(self.builder)) + return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape) + + def shared_load(self, mem_desc, layout): + _check(isinstance(layout, ttgl.DistributedLayout), + lambda: f"expected 'layout' to be a DistributedLayout but got {layout}") + ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout) + handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle) + return ttgl.tensor(handle, ret_ty) + + def shared_store(self, mem_desc, value): + _check(isinstance(value, ttgl.tensor), lambda: f"expected 'value' to be a tensor, but got a {type(value)}") + _check(value.shape == mem_desc.shape, + lambda: f"source shape {value.shape} and destination shape {mem_desc.shape} must match") + _check(value.dtype == mem_desc.dtype, + lambda: f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match") + self.builder.create_local_store(mem_desc.handle, value.handle) + + def shared_gather(self, mem_desc, indices, axis): + _check(isinstance(indices, ttgl.tensor), + lambda: f"expected 'indices' to be a tensor, but got a {type(indices)}") + _check(isinstance(axis, int), lambda: f"expected 'axis' to be an int, but got a {type(axis)}") + _check( + len(indices.shape) == mem_desc.rank, + lambda: f"indices rank must match memdesc rank: got {len(indices.shape)} and {mem_desc.rank}") + _check(0 <= axis < mem_desc.rank, lambda: f"axis {axis} is out of bounds for memdesc rank {mem_desc.rank}") + _check(indices.dtype.is_int(), lambda: f"indices must have integer dtype, got {indices.dtype}") + + ret_ty = ttgl.distributed_type(mem_desc.dtype, indices.shape, indices.type.layout) + handle = self.builder.create_local_gather(ret_ty.to_ir(self.builder), mem_desc.handle, indices.handle, axis) + return ttgl.tensor(handle, ret_ty) + + def shared_scatter(self, mem_desc, values, indices, axis): + _check(isinstance(indices, ttgl.tensor), + lambda: f"expected 'indices' to be a tensor, but got a {type(indices)}") + _check(isinstance(axis, int), lambda: f"expected 'axis' to be an int, but got a {type(axis)}") + _check(isinstance(values, ttgl.tensor), lambda: f"expected 'values' to be a tensor, but got a {type(values)}") + _check( + len(indices.shape) == mem_desc.rank, + lambda: f"indices rank must match memdesc rank: got {len(indices.shape)} and {mem_desc.rank}") + _check(0 <= axis < mem_desc.rank, lambda: f"axis {axis} is out of bounds for memdesc rank {mem_desc.rank}") + _check(indices.dtype.is_int(), lambda: f"indices must have integer dtype, got {indices.dtype}") + _check(values.shape == indices.shape, + lambda: f"values must have the same shape as indices: got {values.shape} and {indices.shape}") + _check(values.type.layout == indices.type.layout, lambda: "values must have the same layout as indices") + _check( + values.dtype == mem_desc.dtype, + lambda: f"values element type must match destination element type: got {values.dtype} and {mem_desc.dtype}") + + self.builder.create_local_scatter(mem_desc.handle, values.handle, indices.handle, axis) + + def bank_conflicts(self, distr_ty, shared_ty): + if not isinstance(distr_ty, ttgl.distributed_type): + raise TypeError( + f"bank_conflicts expects the register layout to be a distributed_type, got {type(distr_ty)}") + + if not isinstance(shared_ty, ttgl.shared_memory_descriptor_type): + raise TypeError( + f"bank_conflicts expects the shared layout to be a shared_memory_descriptor_type, got {type(shared_ty)}" + ) + + if distr_ty.shape != shared_ty.shape: + raise ValueError(f"register shape {distr_ty.shape} and shared shape {shared_ty.shape} must match") + if shared_ty.element_ty != distr_ty.element_ty: + raise ValueError( + f"mismatched dtypes between register ({distr_ty.element_ty}) and shared ({shared_ty.element_ty}) layouts" + ) + if shared_ty.shape != shared_ty.alloc_shape[-len(shared_ty.shape):]: + raise ValueError( + f"bank_conflicts NYI for subslices. Got shape {shared_ty.shape} and alloc_shape {shared_ty.alloc_shape}" + ) + + reg_attr = distr_ty.layout._to_ir(self.builder) + shared_attr = shared_ty.layout._to_ir(self.builder) + return self.builder.get_shared_bank_conflicts(reg_attr, shared_attr, list(distr_ty.shape), + distr_ty.element_ty.primitive_bitwidth) + + def to_linear_layout(self, layout, shape): + from triton.experimental.gluon.language.nvidia.blackwell import ( + TensorMemoryLayout, + TensorMemoryScalesLayout, + ) + _check( + isinstance(layout, (DistributedLayout, SharedLayout, TensorMemoryLayout, TensorMemoryScalesLayout)), lambda: + f"Expected a DistributedLayout, SharedLayout, or TensorMemoryLayout or TensorMemoryScalesLayout, got {type(layout)}" + ) + + if isinstance(layout, (AutoLayout, DistributedLinearLayout, SharedLinearLayout)): + return ttgl.constexpr(layout) + + return ttgl.constexpr(self.builder.to_linear_layout(layout._to_ir(self.builder), shape)) + + def shared_dealloc(self, mem_desc): + self.builder.create_local_dealloc(mem_desc.handle) + + def set_auto_layout(self, value, layout): + src_ty = value.type + _check(isinstance(layout, DistributedLayout), + lambda: f"set_auto_layout must set to a distributed layout but got {layout}") + _check(isinstance(src_ty.layout, AutoLayout), + lambda: f"set_auto_layout input must have auto layout but got {value.type.layout}") + handle = self.builder.create_set_auto_layout(layout._to_ir(self.builder), value.handle) + res_ty = ttgl.distributed_type(src_ty.element_ty, src_ty.shape, layout) + return self.tensor(handle, res_ty) + + def memdesc_slice(self, mem_desc, start, length, dim): + _check(isinstance(start, int), lambda: f"expected 'start' to be an int but got {start}") + _check(isinstance(length, int), lambda: f"expected 'length' to be an int but got {length}") + _check(isinstance(dim, int), lambda: f"expected 'dim' to be an int but got {dim}") + offsets = [0] * mem_desc.rank + offsets[dim] = start + shape = list(mem_desc.shape) + shape[dim] = length + layout = mem_desc.layout + ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape) + builder = self.builder + handle = builder.create_memdesc_subslice(ty.to_ir(builder), mem_desc.handle, offsets) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) + + def memdesc_index(self, mem_desc, index): + index = self.to_tensor(index) + _check(index.type == ttgl.int32, lambda: f"expected 'index' to be int32 but got {index.type}") + shape = mem_desc.shape[1:] + index = self.to_tensor(index).handle + layout = mem_desc.layout + ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, shape) + builder = self.builder + handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) + + def memdesc_trans(self, mem_desc, order): + _check(_is_int_list(order), lambda: f"all elements of 'order' must be integers but got {order}") + _check( + len(order) == len(mem_desc.shape), + lambda: f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match") + + shape = [mem_desc.shape[i] for i in order] + alloc_shape = mem_desc.type.alloc_shape + new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank] + new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order] + + handle = self.builder.create_memdesc_trans(mem_desc.handle, order) + layout = self.builder.get_gluon_layout_from_memdesc(handle) + return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape, + alloc_shape=new_alloc_shape, layout=layout) + + def memdesc_reshape(self, mem_desc, shape): + _check(_is_int_list(shape), lambda: f"all elements of 'shape' must be integers but got {shape}") + _check( + math.prod(shape) == math.prod(mem_desc.shape), + lambda: (f"memdesc_reshape total elements mismatch: " + f"{mem_desc.shape} -> {shape}"), + ) + + handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape) + layout = self.builder.get_gluon_layout_from_memdesc(handle) + alloc_shape = mem_desc.type.alloc_shape + prefix_len = len(alloc_shape) - mem_desc.rank + new_alloc_shape = alloc_shape[:prefix_len] + list(shape) + + return ttgl.shared_memory_descriptor( + handle, + element_ty=mem_desc.dtype, + shape=shape, + alloc_shape=new_alloc_shape, + layout=layout, + ) + + def memdesc_reinterpret(self, mem_desc, dtype, shape, layout): + _check(isinstance(dtype, ttgl.dtype), lambda: f"expected 'dtype' to be a dtype but got {dtype}") + _check(_is_int_list(shape), lambda: f"all elements of 'shape' must be integers but got {shape}") + _check(isinstance(layout, ttgl.SharedLayout), + lambda: f"expected 'layout' to be a SharedLayout but got {layout}") + ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape) + handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) + + def wrap_tensor(self, x, scalar_ty, ret_shape, layout): + if ret_shape: + res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout) + else: + res_ty = scalar_ty + return self.tensor(x, res_ty) + + @staticmethod + def _check_same_layout(xs): + for x in xs: + _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}") + layouts = [x.type.layout for x in xs] + l0 = layouts[0] + _check(all(l == l0 for l in layouts[1:]), + lambda: f"Expected inputs to have matching layouts, but got: {layouts}") + + def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn, + reverse: bool) -> Tuple[TensorTy, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + assert scan_op.verify() + + return tuple( + self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape) + for i in range(len(inputs))) + + def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]: + if axis is None: + inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=False) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}") + self._check_same_layout(inputs) + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + assert reduce_op.verify() + + return tuple( + self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) + for i in range(len(inputs))) + + def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy: + _check(len(input.shape) == 1, lambda: "histogram only supports 1D input") + _check(input.dtype.is_int(), lambda: "histogram only supports integer input") + _check(layout is not None, lambda: "histogram requires a destination layout") + if mask is not None: + mask, input = self.broadcast_impl_value(mask, input) + _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type") + mask = mask.handle + layout_attr = layout._to_ir(self.builder) + handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr) + return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout) + + def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool, layout) -> TensorTy: + _check(layout is not None, lambda: "cat requires a destination layout") + _check(can_reorder, lambda: "current implementation of `cat` always may reorder elements") + _check(len(lhs.shape) == 1, lambda: "cat requires a rank-1 input") + ret_type = ttgl.distributed_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]], layout) + return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle, ret_type.to_ir(self.builder)), ret_type) + + def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: + _check(isinstance(src.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {src.type!r}") + _check(isinstance(index.type, ttgl.distributed_type), + lambda: f"expected distributed_type but got: {index.type!r}") + _check(index.type.scalar.is_int(), lambda: f"expected integer scalar type but got: {index.type.scalar!r}") + + rank = len(src.type.shape) + _check(len(index.type.shape) == rank, lambda: "source and index tensors must have the same rank") + _check(-rank <= axis < rank, lambda: f"gather axis {axis} must be < source rank ({rank})") + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + _check( + index.type.shape[d] == src.type.shape[d], + lambda: f"index dim {axis} must match the corresponding source dim", + ) + gather = self.builder.create_gather(src.handle, index.handle, axis) + return self.wrap_tensor(gather, src.type.scalar, index.type.shape, index.type.layout) + + def fp4_to_fp(self, src: TensorTy, elem_type, axis) -> TensorTy: + result = self.builder.create_fp4_to_fp(src.handle, elem_type.to_ir(self.builder), axis) + shape = list(src.type.shape) + shape[axis] *= 2 + return self._wrap_handle_infer_layout(result, elem_type, shape) + + def warp_specialize(self, functions_and_args, worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], + generator): + for _, args in functions_and_args: + _check(isinstance(args, (tuple, ttgl.tuple)), + lambda: f"function arguments must be a tuple of arguments, but got {type(args)}") + + assert len(functions_and_args) >= 1, "expected at least one function for the default partition" + default_partition, default_args = functions_and_args[0] + num_partitions = len(functions_and_args) - 1 + workers = functions_and_args[1:] + + assert num_partitions == len( + worker_num_warps + ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts" + + if worker_num_regs is not None: + assert num_partitions == len( + worker_num_regs + ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts" + + builder = self.builder + insert_pt = builder.get_insertion_point() + + # Emit the default partition to get the result types. + default_block = builder.new_block() + builder.set_insertion_point_to_start(default_block) + default_result = generator.call_JitFunction(default_partition, default_args, kwargs={}) + mlir_results = flatten_values_to_ir([default_result]) + builder.create_warp_yield(mlir_results) + result_types = [r.get_type() for r in mlir_results] + + # Create the warp specialize op. + worker_args = [flatten_values_to_ir(args) for _, args in workers] + mlir_args = sum(worker_args, []) + builder.restore_insertion_point(insert_pt) + ws_op = builder.create_warp_specialize(result_types, worker_num_warps) + ws_op.get_default_region().push_back(default_block) + + if worker_num_regs is not None: + ws_op.set_requested_registers(worker_num_regs) + + # Emit the partition regions. + builder.create_block_with_parent(ws_op.get_partition_op_holder(), []) + partitions_op = builder.create_warp_specialize_partitions(mlir_args, num_partitions) + arg_types = [arg.get_type() for arg in mlir_args] + arg_it = 0 + for i, (func, args) in enumerate(workers): + caller_context = GluonCallerContext(num_warps=worker_num_warps[i]) + block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types) + mlir_args = worker_args[i] + block_args = [block.get_argument(arg_it + j) for j in range(len(mlir_args))] + block_args = unflatten_ir_values(block_args, [arg.type for arg in args]) + generator.call_JitFunction(func, block_args, kwargs={}, caller_context=caller_context) + builder.create_warp_return() + arg_it += len(mlir_args) + + builder.set_insertion_point_after(ws_op.get_operation()) + mlir_results = [ws_op.get_result(i) for i in range(len(result_types))] + return next(unflatten_ir_values(mlir_results, [default_result.type])) + + def num_ctas(self): + return ttgl.constexpr(self.builder.options.num_ctas) + + def num_warps(self, generator): + if generator.caller_context is not None: + assert isinstance(generator.caller_context, GluonCallerContext) + return ttgl.constexpr(generator.caller_context.num_warps) + return ttgl.constexpr(self.builder.options.num_warps) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/_standard.py b/third_party/mthreads/python/triton/experimental/gluon/language/_standard.py new file mode 100644 index 0000000000..caa0e6fb0f --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/_standard.py @@ -0,0 +1,81 @@ +from typing import TypeVar +from triton.runtime.jit import JITFunction +import triton.language.standard as tl_standard +from .._runtime import GluonJITFunction, jit +from triton import knobs +from . import _core as ttgl + +T = TypeVar("T") + + +def _import_from_triton(fn: JITFunction[T]) -> GluonJITFunction[T]: + assert knobs.runtime.interpret or isinstance(fn, JITFunction) + # Wrap the function and preserve its original docstring + gluon_fn = jit(fn.fn) + gluon_fn.__doc__ = fn.__doc__ + return gluon_fn + + +cdiv = _import_from_triton(tl_standard.cdiv) +sum = _import_from_triton(tl_standard.sum) +max = _import_from_triton(tl_standard.max) +min = _import_from_triton(tl_standard.min) +ravel = _import_from_triton(tl_standard.ravel) +reduce_or = _import_from_triton(tl_standard.reduce_or) +xor_sum = _import_from_triton(tl_standard.xor_sum) + + +@jit +def zeros(shape, dtype, layout=None): + """ + Create a tensor filled with zeros. + + Args: + shape (Sequence[int]): The shape of the tensor. + dtype (dtype): The data type for the tensor. + layout (Optional[DistributedLayout]): The distributed layout of the tensor, defaults to AutoLayout(). + + Returns: + tensor: A tensor where every element is zero. + """ + return ttgl.full(shape, 0, dtype, layout) + + +@jit +def full_like(input, value, shape=None, dtype=None, layout=None): + """ + Create a tensor with the same properties as a given tensor, filled with a specified value. + + Args: + input (tensor): Reference tensor to infer default shape, dtype, and layout. + value (int or float): The fill value. + shape (Sequence[int], optional): Target shape. Defaults to input.shape. + dtype (dtype, optional): Target data type. Defaults to input.dtype. + layout (DistributedLayout, optional): Target layout. Defaults to input.layout. + + Returns: + tensor: A tensor where every element equals value. + """ + return ttgl.full( + input.shape if shape is None else shape, + value, + input.dtype if dtype is None else dtype, + input.type.layout if layout is None else layout, + ) + + +@jit +def zeros_like(input, shape=None, dtype=None, layout=None): + """ + Create a tensor with the same properties as a given tensor, filled with zeros. + + Args: + input (tensor): Reference tensor to infer default shape, dtype, and layout. + shape (Sequence[int], optional): Target shape. Defaults to input.shape. + dtype (dtype, optional): Target data type. Defaults to input.dtype. + layout (DistributedLayout, optional): Target layout. Defaults to input.layout. + + Returns: + tensor: A tensor where every element is zero. + """ + return full_like(input, 0, shape=shape, dtype=dtype, layout=layout) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/__init__.py new file mode 100644 index 0000000000..eeb2512a8d --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/__init__.py @@ -0,0 +1,8 @@ +from .._core import builtin +from ._layouts import AMDMFMALayout, AMDWMMALayout +from . import cdna3, cdna4 +from . import rdna3, rdna4 +from . import gfx1250 +from .warp_pipeline import warp_pipeline_stage + +__all__ = ["AMDMFMALayout", "AMDWMMALayout", "cdna3", "cdna4", "rdna3", "rdna4", "gfx1250", "warp_pipeline_stage"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/_layouts.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/_layouts.py new file mode 100644 index 0000000000..1c796de6bc --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/_layouts.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Optional +from triton.language.core import _unwrap_if_constexpr + +from triton.experimental.gluon.language._layouts import DistributedLayout + +__all__ = [ + "AMDMFMALayout", + "AMDWMMALayout", +] + + +@dataclass(frozen=True) +class AMDMFMALayout(DistributedLayout): + """ + Represents a layout for AMD MFMA (matrix core) operations. + + Args: + version (int): The GPU architecture. + instr_shape (List[int]): The shape in the form of (M, N, K) of the matrix. + transposed (bool): Indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write. + warps_per_cta (List[int]): The warp layout in the block. + element_bitwidth Optional(int): Bit width of the output element type. Supported values are 32 and 64. Defaults to 32. + tiles_per_warp Optional(List[int]): The tile layout within a warp. Defaults to unit tile layout, i.e., single tile on all dimensions. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. + + Current supported versions: + + - 1: gfx908 + - 2: gfx90a + - 3: gfx942 + - 4: gfx950 + """ + version: int + instr_shape: List[int] + transposed: bool + warps_per_cta: List[int] + element_bitwidth: Optional[int] = None + tiles_per_warp: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) + + def __post_init__(self): + super().__setattr__("version", _unwrap_if_constexpr(self.version)) + super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape)) + super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed)) + super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) + super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth)) + super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(self.cga_layout)) + + if self.element_bitwidth is None: + super().__setattr__("element_bitwidth", 32) + if self.tiles_per_warp is None: + super().__setattr__("tiles_per_warp", [1] * len(self.warps_per_cta)) + + self.verify() + + def _to_ir(self, builder): + return builder.get_amd_mfma_layout( + self.version, + self.warps_per_cta, + self.instr_shape, + self.transposed, + self.cga_layout, + self.tiles_per_warp, + self.element_bitwidth, + ) + + def mangle(self) -> str: + + def stringify(x): + if x is None: + return "" + return "_".join(map(str, x)) + + cga_layout = stringify(["~".join(map(str, vec)) for vec in self.cga_layout] if self.cga_layout else None) + return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{self.element_bitwidth}_{stringify(self.tiles_per_warp)}_{cga_layout}_MFMA" + + def verify(self): + assert self.version >= 1 and self.version <= 4, "version must be in the [1, 4] range" + assert len(self.instr_shape) == 3, "instr_shape must follow the (M, N, K) format" + valid_shapes = [[32, 32], [16, 16], [64, 4], [4, 64]] + assert self.instr_shape[0:2] in valid_shapes, f"invalid intrinsic shape {self.instr_shape}" + assert self.element_bitwidth in [32, 64], "element bitwidth must be 32 or 64" + + rank = len(self.warps_per_cta) + assert all(len(vec) == rank for vec in self.cga_layout), "cga_layout basis rank mismatch" + + def __hash__(self): + return hash(( + self.version, + tuple(self.instr_shape), + self.transposed, + tuple(self.warps_per_cta), + self.element_bitwidth if self.element_bitwidth else None, + tuple(self.tiles_per_warp) if self.tiles_per_warp else None, + tuple(tuple(vec) for vec in self.cga_layout), + )) + + @property + def rank(self): + return len(self.warps_per_cta) + + +@dataclass(frozen=True) +class AMDWMMALayout(DistributedLayout): + """ + Represents a layout for AMD WMMA (matrix core) operations. + + Args: + version (int): Indicates the GPU architecture. + transposed (bool): Indicates the result tensor is transposed. + warp_bases (List[List[int]]): Warp bases for CTA layout. + reg_bases (Optional[List[List[int]]]): Repetition (register) bases for CTA layout. + instr_shape (Optional[List[int]]): Instruction shape (M, N, K). Defaults to (16, 16, 16). + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. + rank (Optional[int]): rank of warp and register bases. Default to 2 if missing. + + Current supported versions: + + - 1: RDNA3; e.g., gfx1100, gfx1101 + - 2: RDNA4; e.g., gfx1200, gfx1201 + - 3: gfx1250 + """ + version: int + transposed: bool + warp_bases: List[List[int]] + reg_bases: Optional[List[List[int]]] = None + instr_shape: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) + rank: Optional[int] = None + + def __post_init__(self): + super().__setattr__("version", _unwrap_if_constexpr(self.version)) + super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed)) + super().__setattr__("warp_bases", [list(inner) for inner in _unwrap_if_constexpr(self.warp_bases)]) + super().__setattr__("reg_bases", + [list(inner) + for inner in _unwrap_if_constexpr(self.reg_bases)] if self.reg_bases is not None else []) + instr_shape = _unwrap_if_constexpr(self.instr_shape) if self.instr_shape is not None else [16, 16, 16] + super().__setattr__("instr_shape", _unwrap_if_constexpr(instr_shape)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(self.cga_layout)) + rank = _unwrap_if_constexpr(self.rank) if self.rank is not None else 2 + super().__setattr__("rank", rank) + self.verify() + + def _to_ir(self, builder): + return builder.get_amd_wmma_layout( + self.version, + self.transposed, + self.warp_bases, + self.reg_bases, + self.cga_layout, + self.instr_shape, + self.rank, + ) + + def mangle(self) -> str: + + def stringify(x): + if x is None: + return "" + return "_".join(map(str, x)) + + def nested_stringify(x): + return stringify(["~".join(map(str, vec)) for vec in x] if x else None) + + warp_bases = nested_stringify(self.warp_bases) + reg_bases = nested_stringify(self.reg_bases) + cga_layout = nested_stringify(self.cga_layout) + return f"WMMA_{self.version}_{self.transposed}_{warp_bases}_{reg_bases}_{stringify(self.instr_shape)}_{cga_layout}_{self.rank}_WMMA" + + def verify(self): + assert self.version >= 1 and self.version <= 3, "version must be in the [1, 3] range" + if len(self.warp_bases) > 0: + assert len(self.warp_bases[0]) == self.rank, "warp_bases basis rank mismatch" + assert all(len(vec) == self.rank for vec in self.cga_layout), "cga_layout basis rank mismatch" + + def __hash__(self): + return hash(( + self.version, + self.transposed, + tuple(tuple(vec) for vec in self.warp_bases), + tuple(tuple(vec) for vec in self.reg_bases), + tuple(self.instr_shape) if self.instr_shape else None, + tuple(tuple(vec) for vec in self.cga_layout), + self.rank, + )) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/_ops.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/_ops.py new file mode 100644 index 0000000000..547761307d --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/_ops.py @@ -0,0 +1,77 @@ +import math + +from triton import knobs +from triton.experimental.gluon.language import _core as ttgl +from triton.experimental.gluon.language._semantic import _check + +from .._core import _unwrap_if_constexpr +from .._layouts import DotOperandLayout +from ._layouts import AMDWMMALayout + + +def _verify_wmma(version, a, b, acc): + _check(acc is not None, lambda: "acc is required") + + layout = acc.type.layout + _check( + isinstance(layout, AMDWMMALayout) and layout.version == version, + lambda: f"Expected layout to be an instance of AMDWMMALayout with version {version}") + + a_layout = a.type.layout + _check( + isinstance(a_layout, DotOperandLayout) and isinstance(a_layout.parent, AMDWMMALayout) + and a_layout.parent.version == version, + lambda: "Expected a's layout to be a DotOperandLayout with parent matching AMDWMMALayout") + + b_layout = b.type.layout + _check( + isinstance(b_layout, DotOperandLayout) and isinstance(b_layout.parent, AMDWMMALayout) + and b_layout.parent.version == version, + lambda: "Expected b's layout to be a DotOperandLayout with parent matching AMDWMMALayout") + + +def _wmma(version, a, b, acc, semantic): + """ Shared implementation for AMD WMMA operations for Gluon builtins """ + _verify_wmma(version, a, b, acc) + + handle = semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None, + out_dtype=acc.dtype).handle + return ttgl.tensor(handle, acc.type) + + +def _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, scale_fn, semantic): + """ Shared implementation for AMD WMMA scaled and MFMA scaled operation. """ + + def _get_scale_shape(op_idx, operand, format): + operand_shape = [s for s in operand.type.shape] + scale_shape = operand_shape + unpack_factor = 2 if format.value == "e2m1" else 1 + if op_idx == 0: + k = scale_shape[-1] * unpack_factor + scale_shape[-1] = k // 32 + else: + k = scale_shape[-2] * unpack_factor + scale_shape[-2] = k // 32 + scale_shape[-2], scale_shape[-1] = scale_shape[-1], scale_shape[-2] + return scale_shape + + def _create_and_broadcast_default_scale(op_idx, scale, format): + operand = a if op_idx == 0 else b + + scale_shape = _get_scale_shape(op_idx, operand, format) + if isinstance(scale, ttgl.tensor) and scale.numel.value != 1: + # In the case of scale pre-shuffling, the input shape is different from the default shape. We only check + # the number of elements here. + assert math.prod(scale_shape) == scale.numel.value, "Incompatible scale shape" + return scale + + scale_layout = scale_fn(operand.type.layout, scale_shape) + scale_value = _unwrap_if_constexpr(scale) + scale_value = 0x7F if scale_value is None else scale_value + return semantic.full(scale_shape, scale_value, ttgl.uint8, scale_layout) + + a_scale = _create_and_broadcast_default_scale(0, a_scale, a_format) + b_scale = _create_and_broadcast_default_scale(1, b_scale, b_format) + output = semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, fast_math=False, lhs_k_pack=True, + rhs_k_pack=True, out_dtype=ttgl.float32) + return ttgl.tensor(output.handle, acc.type) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna3/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna3/__init__.py new file mode 100644 index 0000000000..7d88a62b84 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna3/__init__.py @@ -0,0 +1,238 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from triton import knobs +from triton.experimental.gluon.language import _core as ttgl +from triton._C.libtriton import ir +from ..._core import builtin, _unwrap_if_constexpr + +if TYPE_CHECKING: + from ..._semantic import GluonSemantic + +__all__ = [ + "buffer_atomic_add", "buffer_atomic_and", "buffer_atomic_min", "buffer_atomic_max", "buffer_atomic_or", + "buffer_atomic_xor", "buffer_atomic_xor", "buffer_load", "buffer_store", "mfma" +] + +_atomic_op_str_to_op = { + "smax": ir.ATOMIC_OP.MAX, "smin": ir.ATOMIC_OP.MIN, "umax": ir.ATOMIC_OP.UMAX, "umin": ir.ATOMIC_OP.UMIN, "fadd": + ir.ATOMIC_OP.FADD, "iadd": ir.ATOMIC_OP.ADD, "and": ir.ATOMIC_OP.AND, "or": ir.ATOMIC_OP.OR, "xor": + ir.ATOMIC_OP.XOR, "xchg": ir.ATOMIC_OP.XCHG +} + + +def _verify_buffer_ops(ptr, offsets, mask=None, other=None): + assert ptr.type.is_ptr(), "ptr must be a scalar pointer type" + + assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type" + assert offsets.dtype.is_int32() or offsets.dtype.is_uint32(), "offsets element type must be int32 or uint32" + + if other is not None: + assert mask is not None, "when other is not None, mask should not be None" + + +def _verify_element_type_and_dispatch_op(op, elem_type, arch): + supported_types = [ + ttgl.float16, ttgl.float32, ttgl.bfloat16, ttgl.float64, ttgl.int32, ttgl.int64, ttgl.uint32, ttgl.uint64 + ] + assert elem_type in supported_types, f"{elem_type} is not supported in buffer atomic on {arch}." + + if op in ['and', 'or', 'xor', 'xchg']: + assert elem_type in [ttgl.int32, ttgl.int64], f"{op} with {elem_type} is not supported on CDNA3 or CDNA4" + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + + if op in ['max', 'min']: + if elem_type in [ttgl.int32, ttgl.int64, ttgl.float64]: + op = 's' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + elif elem_type in [ttgl.uint32, ttgl.uint64]: + op = 'u' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + else: + raise ValueError(f"{op} with {elem_type} is not supported on CDNA3 and CDNA4") + + if op == 'add': + if elem_type in [ttgl.uint32, ttgl.uint64]: + op = 'i' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + elif elem_type in [ttgl.float16, ttgl.float32, ttgl.float64]: + op = 'f' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + elif elem_type is ttgl.bfloat16: + assert arch == "cdna4", "Buffer atomic fadd with bf16 is only supported on CDNA4 for now." + op = 'f' + op + return _atomic_op_str_to_op[_unwrap_if_constexpr(op)] + else: + raise ValueError(f"{op} with {elem_type} is not supported on CDNA3 and CDNA4") + + raise ValueError(f"Unknown {op} on CDNA3 or CDNA4") + + +def _buffer_atomic_rmw_impl(op, ptr, offsets, value, arch, mask, sem, scope, _semantic): + _verify_buffer_ops(ptr, offsets, mask) + + op = _verify_element_type_and_dispatch_op(op, ptr.type.scalar.element_ty, arch) + + mask = _unwrap_if_constexpr(mask) + if mask is not None: + mask = _semantic.to_tensor(mask) + mask = _semantic.cast(mask, ttgl.int1) + _, mask = _semantic.broadcast_impl_value(offsets, mask) + mask = mask.handle if mask is not None else ir.value() + + value = _unwrap_if_constexpr(value) + value = _semantic.to_tensor(value) + _, value = _semantic.broadcast_impl_value(offsets, value) + + sem = _semantic._str_to_sem(sem) + scope = _semantic._str_to_scope(scope) + return _semantic.tensor( + _semantic.builder.create_buffer_atomic_rmw(op, ptr.handle, offsets.handle, value.handle, sem, scope, mask), + value.type) + + +@builtin +def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None): + """ + AMD buffer load from global memory via a scalar base pointer and a tensor of + offsets instead of a tensor of pointers. This operation will load data + directly into registers. + + Args: + ptr (pointer to scalar): Global memory scalar base pointer to load from. + offsets (tensor): Offsets tensor for the load operation. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + other (tensor or scalar, optional): Tensor or scalar providing default values for masked elements. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _verify_buffer_ops(ptr, offsets, mask, other) + + mask = _unwrap_if_constexpr(mask) + if mask is not None: + offsets, mask = _semantic.broadcast_impl_value(offsets, mask) + + other = _unwrap_if_constexpr(other) + if other is not None: + other = _semantic.to_tensor(other) + other = _semantic.cast(other, ptr.dtype.element_ty) + offsets, other = _semantic.broadcast_impl_value(offsets, other) + + other = other.handle if other is not None else ir.value() + mask = mask.handle if mask is not None else ir.value() + cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE + + ret_ty = offsets.type.with_element_ty(ptr.type.scalar.element_ty) + builder = _semantic.builder + handle = builder.create_buffer_load(ret_ty.to_ir(builder), ptr.handle, offsets.handle, mask, other, cache_modifier) + return ttgl.tensor(handle, ret_ty) + + +@builtin +def buffer_store(stored_value, ptr, offsets, mask=None, cache=None, _semantic: GluonSemantic = None): + """ + AMD buffer store a tensor directly to global memory via a scalar base pointer and a tensor of + offsets instead of a tensor of pointers. + Args: + stored_value (tensor to be stored): The tensor to be stored to global memory. + ptr (pointer to scalar): Global memory scalar base pointer to store to. + offsets (tensor): Offsets tensor for the store operation. + mask (tensor, optional): Mask tensor for predicated store. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _verify_buffer_ops(ptr, offsets, mask) + + if mask is not None: + offsets, mask = _semantic.broadcast_impl_value(offsets, mask) + + mask = mask.handle if mask is not None else ir.value() + cache_modifier = _semantic._str_to_store_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE + + _semantic.builder.create_buffer_store(stored_value.handle, ptr.handle, offsets.handle, mask, cache_modifier) + + +@builtin +def mfma(a, b, acc, _semantic: GluonSemantic = None): + """ + Computes matrix-multiplication of a * b + acc using AMD native matrix core units. + Args: + a (tensor): The first operand of mfma. + b (tensor): The second operand of mfma. + acc (tensor): The accumulator tensor. + """ + assert acc is not None, "acc is required" + ret_type = acc.type + acc = ttgl._unwrap_if_constexpr(acc) + + handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None, + out_dtype=acc.dtype).handle + return ttgl.tensor(handle, ret_type) + + +""" +AMD Buffer Atomic RMW operations. +The supported operatios are max, min, add, and, or, xor, xchg. +Similar to normal atomic ops: it loads data at ptr plus offsets, do `op` with `value`, and store result to `ptr` plus `offsets` with +the specified memory semantics and scope. + +Buffer atomics access global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. +Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with +the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed). + +Buffer Atomic RMW ops return the pre-op value in the global memory. + +Args: + ptr (pointer to scalar): Global memory scalar base pointer to load from. + offsets (tensor): Offsets tensor for the load operation. + value (tensor): Another operand of `op`. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + sem (str, optional): Memory Semantic Descriptor. Default is None which means acq_rel memory semantic. + scope (str, optional): Memory Sync Scope for atomic accesses. Default is None and it will be mapped to `gpu`, which is called `agent` for AMDGPU. Please ref https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942 for details. +""" + + +@builtin +def buffer_atomic_max(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + return _buffer_atomic_rmw_impl('max', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_min(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('min', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_add(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('add', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_and(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('and', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_or(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('or', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xor(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xor', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xchg(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xchg', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna4/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna4/__init__.py new file mode 100644 index 0000000000..4ba53d2ed0 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna4/__init__.py @@ -0,0 +1,130 @@ +from triton.runtime.jit import constexpr_function +from triton._C.libtriton.gluon_ir import get_amd_mfma_scale_layout as _get_mfma_scale_layout + +from ..._core import builtin +from ..._layouts import DotOperandLayout +from .._layouts import AMDMFMALayout +from .._ops import _mma_scaled +from ..cdna3 import _buffer_atomic_rmw_impl +from ..cdna3 import * # NOQA: F403 +from ..cdna3 import __all__ as __cdna3_all +from . import async_copy + +__all__ = [*__cdna3_all, "async_copy", "mfma_scaled", "get_mfma_scale_layout"] + + +@builtin +def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None): + """ + AMD Scaled MFMA operation. + + ``` + c = a * a_scale @ b * b_scale + acc + ``` + + `a` and `b` use microscaling formats described in + "OCP Microscaling Formats (MX) Specification": + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf. + Currently supported only on CDNA4 hardware. + + Args: + a (tensor): The operand A to be multiplied. + a_scale (Optional[tensor]): Scale factor for operand A. + a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`. + b (tensor): The operand B to be multiplied. + b_scale (Optional[tensor]): Scale factor for operand B. + b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2`. + acc (tensor): Accumulator tensor. + """ + layout = acc.type.layout + assert isinstance(layout, AMDMFMALayout), "Expected layout to be an instance of AMDMFMALayout" + assert (isinstance(a.type.layout, DotOperandLayout) and a.type.layout.parent== layout), \ + "Expected lhs layout to be a DotOperandLayout with parent matching MFMA layout" + assert (isinstance(b.type.layout, DotOperandLayout) and b.type.layout.parent == layout), \ + "Expected rhs layout to be a DotOperandLayout with parent matching MFMA layout" + + assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}" + assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}" + + return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, get_mfma_scale_layout, _semantic) + + +def _get_mfma_scale_layout_impl(*args, **kwargs): + return _get_mfma_scale_layout(*args, **kwargs) + + +_get_mfma_scale_layout_impl.__triton_builtin__ = True + + +@constexpr_function +def get_mfma_scale_layout(dot_operand_layout, shape): + """ Get the scale layout for MFMA scaled operands. + + Args: + dot_operand_layout (DotOperandLayout): The dot operand layout. + shape (List[int]): The shape of the scale tensor. + + Return: + layout (DistributedLinearLayout): The scale layout. + """ + op_idx = dot_operand_layout.operand_index + parent = dot_operand_layout.parent + assert isinstance(parent, AMDMFMALayout), "Expected parent to be an instance of AMDMFMALayout" + mdim = parent.instr_shape[0] + tiles_per_warp = parent.tiles_per_warp + warps_per_cta = parent.warps_per_cta + return _get_mfma_scale_layout_impl(op_idx, shape, mdim, tiles_per_warp, warps_per_cta) + + +""" +buffer_atomic_rmw of cnda4 shares the same signature and functionalities as cdna3.buffer_atomic_rmw. +The cdna4 version additionally supports `fadd` with `bf16`. +""" + + +@builtin +def buffer_atomic_max(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + return _buffer_atomic_rmw_impl('max', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_min(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('min', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_add(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('add', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_and(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('and', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_or(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('or', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xor(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xor', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) + + +@builtin +def buffer_atomic_xchg(ptr, offsets, value, mask=None, sem=None, scope=None, _semantic=None): + + return _buffer_atomic_rmw_impl('xchg', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope, + _semantic=_semantic) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py new file mode 100644 index 0000000000..009707c779 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py @@ -0,0 +1,170 @@ +from ..._core import ir, builtin, _unwrap_if_constexpr +from ..._semantic import _check +from ..._layouts import BlockedLayout, SliceLayout +from ..cdna3 import _verify_buffer_ops + +__all__ = [ + "global_load_to_shared", + "buffer_load_to_shared", + "commit_group", + "wait_group", + "load_shared_relaxed", +] + + +@builtin +def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _semantic=None): + """ + AMD global load to shared operation. This operation loads data directly + from global memory to shared memory without going through registers. It + happens asynchronously and requires a subsequent `async_wait` to ensure the + data is available in shared memory. Note that this operation does still + complete in order with ttgl.loads/stores or buffer_loads/stores on CDNA4, + so interleaving with them will hurt performance. + + Compared to `buffer_load_to_shared`, it requires a tensor pointer which + supports 64-bit indexing range for each thread in a block, which gives more + flexibility, but at the cost of higher register pressure and no hardware + out-of-bound masking support. Prefer to use `buffer_load_to_shared` when + possible for better performance. + + The underlying hardware instruction uses separate registers for global + memory address for each thread but the same register for local memory + address for the whole warp. Therefore, while using this operation + the following conditions must be met or lowering to LLVM will fail: + + - For the `ptr` layout, size per thread * bits per element must be 128 or 32. + To get ideal performance, it is recommended to use 128 bits per element. + - Writes to `dest` must be coalesced. + - If `dest` is swizzled, it only can be swizzled within warp boundary. + + Args: + dest (shared_memory_descriptor): Destination shared memory descriptor. + ptr (pointer tensor): Tensor of pointers to global memory to load from. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + other (tensor or scalar, optional): Tensor or scalar providing default values for masked elements. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _check(ptr.type.is_block(), lambda: "expected ptr to be a tensor") + _check(isinstance(ptr.type.layout, (BlockedLayout, SliceLayout)), + lambda: "expected ptr type layout to be BlockedLayout or SliceLayout") + _check( + dest.shape == ptr.shape, lambda: + f"expected dest shape to match pointer shape but got dest.shape = {dest.shape}, pointer.shape = {ptr.shape}") + + mask = _unwrap_if_constexpr(mask) + if mask is not None: + ptr, mask = _semantic.broadcast_impl_value(ptr, mask) + other = _unwrap_if_constexpr(other) + if other is not None: + other = _semantic.to_tensor(other) + other = _semantic.cast(other, ptr.dtype.element_ty) + ptr, other = _semantic.broadcast_impl_value(ptr, other) + + cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) + mask_handle = mask.handle if mask is not None else ir.value() + other_handle = other.handle if other is not None else ir.value() + _semantic.builder.create_async_copy_global_to_local(dest.handle, ptr.handle, mask_handle, other_handle, + cache_modifier, ir.EVICTION_POLICY.NORMAL, False) + + +@builtin +def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modifier="", _semantic=None): + """ + AMD buffer load to shared operation. Buffer load is similar to global load + but it accesses global memory via a scalar base pointer and a tensor of + 32-bit offsets instead of a tensor of pointers. This operation loads data + directly from global memory to shared memory without going through + registers. It happens asynchronously and requires a subsequent `async_wait` + to ensure thedata is available in shared memory. Note that this operation + does still complete in order with ttgl.loads/stores or buffer_loads/stores + on CDNA4, so interleaving with them will hurt performance. + + Compared to `global_load_to_shared`, it has better performance and also + supports hardware out-of-bound masking. But it strictly requires a + 32-bit offset instead of a 64-bit tensor pointer. + + The underlying hardware instruction uses separate registers for global + memory address for each thread but the same register for local memory + address for the whole warp. Therefore, while using this operation + the following conditions must be met or lowering to LLVM will fail: + + - For the `offsets` layout, size per thread * bits per element must be 128 or 32. + To get ideal performance, it is recommended to use 128 bits per element. + - Writes to `dest` must be coalesced. + - If `dest` is swizzled, it only can be swizzled within warp boundary. + + Args: + dest (shared_memory_descriptor): Destination shared memory descriptor. + ptr (pointer to scalar): Global memory scalar base pointer to load from. + offsets (tensor): Offsets tensor for the load operation. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + other (tensor or scalar, optional): Tensor or scalar providing default values for masked elements. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _check(isinstance(offsets.type.layout, (BlockedLayout, SliceLayout)), + lambda: "expected offsets type layout to be BlockedLayout or SliceLayout") + _verify_buffer_ops(ptr, offsets, mask, other) + + mask = _unwrap_if_constexpr(mask) + if mask is not None: + offsets, mask = _semantic.broadcast_impl_value(offsets, mask) + other = _unwrap_if_constexpr(other) + if other is not None: + other = _semantic.to_tensor(other) + other = _semantic.cast(other, ptr.type.scalar.element_ty) + offsets, other = _semantic.broadcast_impl_value(offsets, other) + + mask = mask.handle if mask is not None else ir.value() + other = other.handle if other is not None else ir.value() + stride = ir.value() + cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) + + _semantic.builder.create_buffer_load_to_local(dest.handle, ptr.handle, offsets.handle, mask, other, stride, + cache_modifier) + + +@builtin +def commit_group(_semantic=None): + """ + Commit oustanding async operations. + + This finalizes a set of async copy operations which can be waited upon via `wait_group`. + """ + _semantic.builder.create_async_commit_group() + + +@builtin +def wait_group(num_outstanding=0, _semantic=None): + """ + Wait for outstanding commit groups. It will block until the number of + outstanding commit groups is less than or equal to `num_outstanding`. Note that uncommited + async operations will be waited upon even if `num_outstanding` is 0. + + Args: + num_outstanding (int): The number of outstanding commit groups to wait for. Defaults to 0. + """ + num_outstanding = _unwrap_if_constexpr(num_outstanding) + _semantic.builder.create_async_wait_group(num_outstanding) + + +@builtin +def load_shared_relaxed(smem, layout, _semantic=None): + """ + Load a tensor from shared memory with extra hints for the underlying + compiler to avoid emitting unnecessary waits before loading from the target + shared memory. + + Args: + smem (shared_memory_descriptor): Shared memory descriptor to load from. + layout (DistributedLayout): The destination layout of the tensor. + + Returns: + tensor: A Gluon tensor containing the loaded data. + """ + SYNCED_VIA_WAIT_ATTR_NAME = "ttg.amdg.syncedViaAsyncWait" + + layout = _unwrap_if_constexpr(layout) + ret = _semantic.shared_load(smem, layout) + ret.handle.set_attr(SYNCED_VIA_WAIT_ATTR_NAME, _semantic.builder.get_bool_attr(True)) + return ret diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py new file mode 100644 index 0000000000..1c98502a89 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py @@ -0,0 +1,98 @@ +from triton.runtime.jit import constexpr_function +from triton._C.libtriton.gluon_ir import get_amd_wmma_scale_layout as _get_wmma_scale_layout + +from ..._core import builtin +from .._ops import _wmma, _verify_wmma, _mma_scaled +from .._layouts import AMDWMMALayout +from ..cdna3 import buffer_load, buffer_store +from . import tdm +from . import async_copy +from . import mbarrier +from . import cluster + +__all__ = [ + "async_copy", "tdm", "mbarrier", "cluster", "wmma", "wmma_scaled", "buffer_load", "buffer_store", + "get_wmma_scale_layout" +] + + +@builtin +def wmma(a, b, acc, _semantic=None): + """ + Computes matrix-multiplication of a * b + acc using AMD WMMA instruction. + + Args: + a (tensor): The operand a to be multiplied. + b (tensor): The operand b to be multiplied. + acc (tensor): The accumulator tensor. + """ + return _wmma(3, a, b, acc, _semantic) + + +@builtin +def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None): + """ + AMD Scaled WMMA operation. + + ``` + c = a * a_scale @ b * b_scale + acc + ``` + + `a` and `b` use microscaling formats described in + "OCP Microscaling Formats (MX) Specification": + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf. + + Args: + a (tensor): The operand A to be multiplied. + a_scale (Optional[tensor]): Scale factor for operand A. + a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`. + b (tensor): The operand B to be multiplied. + b_scale (Optional[tensor]): Scale factor for operand B. + b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2`. + acc (tensor): Accumulator tensor. + """ + _verify_wmma(3, a, b, acc) + if a_format.value == "e2m1": + wmma_layout = a.type.layout.parent + assert isinstance(wmma_layout, AMDWMMALayout) and wmma_layout.instr_shape == [16, 16, 64], \ + "e2m1 format expects instr_shape to be [16, 16, 64]" + if b_format.value == "e2m1": + wmma_layout = b.type.layout.parent + assert isinstance(wmma_layout, AMDWMMALayout) and wmma_layout.instr_shape == [16, 16, 64], \ + "e2m1 format expects instr_shape to be [16, 16, 64]" + + acc_layout = acc.type.layout + assert isinstance(acc_layout, AMDWMMALayout) and acc_layout.instr_shape == [16, 16, 128], \ + "accumulator tensor's layout must be [16, 16, 128]" + + assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}" + assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}" + + return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, get_wmma_scale_layout, _semantic) + + +def _get_wmma_scale_layout_impl(*args, **kwargs): + return _get_wmma_scale_layout(*args, **kwargs) + + +_get_wmma_scale_layout_impl.__triton_builtin__ = True + + +@constexpr_function +def get_wmma_scale_layout(dot_operand_layout, shape): + """ Get the scale layout for WMMA scaled operands. + + Args: + dot_operand_layout (DotOperandLayout): The dot operand layout. + shape (List[int]): The shape of the scale tensor. + + Return: + layout (DistributedLinearLayout): The scale layout. + """ + op_idx = dot_operand_layout.operand_index + parent = dot_operand_layout.parent + assert isinstance(parent, AMDWMMALayout), "Expected parent to be an instance of AMDMFMALayout" + mdim = parent.instr_shape[0] + reg_bases = parent.reg_bases + warp_bases = parent.warp_bases + return _get_wmma_scale_layout_impl(op_idx, shape, mdim, reg_bases, warp_bases) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py new file mode 100644 index 0000000000..f22d45191b --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py @@ -0,0 +1,78 @@ +from ..._core import ir, builtin, _unwrap_if_constexpr +from ..._semantic import _check +from triton.experimental.gluon.language._layouts import DistributedLayout +from ..cdna4.async_copy import commit_group, wait_group + +__all__ = ["global_to_shared", "shared_to_global", "commit_group", "wait_group", "mbarrier_arrive"] + + +@builtin +def global_to_shared(smem, pointer, mask=None, other=None, cache_modifier="", _semantic=None): + """ + Asynchronously copy elements from global memory to shared memory. Requires manual syncronization via `wait_group` before accessing the loaded data. + + Args: + smem (shared_memory_descriptor): Destination shared memory descriptor. + pointer (tensor): Source pointer tensor. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + other (tensor or scalar, optional): Tensor or scalar providing default values for masked elements. Defaults to None(0). + cache_modifier (str): Cache modifier specifier. Defaults to "". + eviction_policy (str): Eviction policy specifier. Defaults to "". + """ + _check(pointer.type.is_block(), lambda: "expected ptr to be a tensor") + _check(isinstance(pointer.type.layout, DistributedLayout), + lambda: "expected ptr type layout to be BlockedLayout or SliceLayout") + _check( + smem.shape == pointer.shape, lambda: + f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}" + ) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + pointer, mask = _semantic.broadcast_impl_value(pointer, mask) + other = _unwrap_if_constexpr(other) + if other is not None: + other = _semantic.to_tensor(other) + other = _semantic.cast(other, pointer.dtype.element_ty) + pointer, other = _semantic.broadcast_impl_value(pointer, other) + cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) + mask_handle = mask.handle if mask is not None else ir.value() + other_handle = other.handle if other is not None else ir.value() + _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, other_handle, + cache_modifier, ir.EVICTION_POLICY.NORMAL, False) + + +@builtin +def shared_to_global(pointer, smem, mask=None, cache_modifier="", _semantic=None): + """ + Asynchronously copy elements from shared memory to global memory. Requires manual syncronization via `wait_group` before accessing the stored data. + + Args: + pointer (tensor): Destination pointer tensor. + smem (shared_memory_descriptor): Source shared memory descriptor. + mask (tensor, optional): Mask tensor for predicated stores. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _check(pointer.type.is_block(), lambda: "expected ptr to be a tensor") + _check(isinstance(pointer.type.layout, DistributedLayout), + lambda: "expected ptr type layout to be BlockedLayout or SliceLayout") + _check( + smem.shape == pointer.shape, lambda: + f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}" + ) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + pointer, mask = _semantic.broadcast_impl_value(pointer, mask) + cache_modifier = _semantic._str_to_store_cache_modifier(cache_modifier) + mask_handle = mask.handle if mask is not None else ir.value() + _semantic.builder.create_async_copy_local_to_global(smem.handle, pointer.handle, mask_handle, cache_modifier, + ir.EVICTION_POLICY.NORMAL) + + +@builtin +def mbarrier_arrive(mbarrier, _semantic=None): + """ + Arrive on the mbarrier once all outstanding async copies are complete. + Args: + mbarrier (shared_memory_descriptor): Barrier object to arrive on. + """ + _semantic.builder.create_async_copy_lds_barrier_arrive(mbarrier.handle) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/cluster.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/cluster.py new file mode 100644 index 0000000000..5bef4dd8b5 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/cluster.py @@ -0,0 +1,21 @@ +from triton.experimental.gluon.language._core import builtin + +__all__ = ["arrive", "wait"] + + +@builtin +def arrive(_semantic=None): + """ + Signals that the cluster has arrived at a cluster barrier, used to synchronize execution of CTAs within the same cluster. + """ + _semantic.builder.create_amd_cluster_arrive() + + +@builtin +def wait(_semantic=None): + """ + Wait on a cluster barrier to be arrived by all CTAs within the same cluster. + Arrive and wait operations must come in pairs. Waiting before arriving or arriving more than once + without a corresponding wait will result in undefined behavior. + """ + _semantic.builder.create_amd_cluster_wait() diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py new file mode 100644 index 0000000000..9e4a8dc8bf --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py @@ -0,0 +1,67 @@ +import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon.language._layouts import SwizzledSharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +__all__ = ["MBarrierLayout", "init", "wait", "arrive"] + + +class MBarrierLayout(SwizzledSharedLayout): + """ + Layout for mbarrier synchronization. + + Args: + cga_layout (List[List[int]]): CGA layout bases. Defaults to []. + """ + + def __init__(self, cga_layout=None): + super().__init__(vec=1, per_phase=1, max_phase=1, order=[0], cga_layout=cga_layout or []) + + +@builtin +def init(mbarrier, count, _semantic=None): + """ + Initialize an mbarrier with a specified count. An mbarrier consists of an init count, a pending count and a phase. + At initialization, the init count and pending count are initialized with the given 'count' and the phase is initialized to 0. + + Args: + mbarrier (shared_memory_descriptor): The barrier object to initialize. + count (int): The initial count for the barrier. Must be a positive integer. + """ + count = _unwrap_if_constexpr(count) + _semantic.builder.create_lds_barrier_init(mbarrier.handle, count) + + +@builtin +def wait(mbarrier, phase, _semantic=None): + """ + Wait until the mbarrier's phase differs from the provided phase value. + This means that the given 'phase' has completed. + + Args: + mbarrier (shared_memory_descriptor): The barrier object to wait on. + phase (int): The phase value to compare against. The wait completes when + the barrier's phase becomes different from this value. + """ + phase = _semantic.to_tensor(phase) + + _semantic.builder.create_lds_barrier_wait(mbarrier.handle, phase.handle) + + +@builtin +def arrive(mbarrier, *, count=1, _semantic=None): + """ + Arrive at an mbarrier with a specified count. The operation requires a `count` attribute + of at least 1, and decreases the pending arrival count of the mbarrier by the specific count. + If the pending count reaches zero, the phase changes (is decremented in a wraparound manner) and the + pending count is reloaded with the init count value. Returns the mbarrier's phase parity (0 for even, 1 for odd) prior to the "arrive" operation. + + Args: + mbarrier (shared_memory_descriptor): Barrier to be signalled. + count (int): Count to arrive with. Defaults to 1. + + Returns: + prior phase (int): phase of mbarrier, prior to "arrive" operation. + """ + count = _unwrap_if_constexpr(count) + handle = _semantic.builder.create_lds_barrier_arrive(mbarrier.handle, count) + return ttgl.tensor(handle, ttgl.int32) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py new file mode 100644 index 0000000000..f5a0749f65 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py @@ -0,0 +1,250 @@ +from __future__ import annotations +from typing import List, Tuple, TYPE_CHECKING +from dataclasses import dataclass + +import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +if TYPE_CHECKING: + from triton._C import ir + from triton.experimental.gluon.language._core import shared_memory_descriptor + +__all__ = [ + "async_load", "async_wait", "make_tensor_descriptor", "tensor_descriptor", "tensor_descriptor_type", "prefetch", + "async_scatter" +] + + +@dataclass(eq=True) +class tensor_descriptor_type(ttgl.base_type): + """The type for a tensor descriptor.""" + + block_type: ttgl.block_type + shape_type: ttgl.tuple_type + strides_type: ttgl.tuple_type + layout: PaddedSharedLayout | SwizzledSharedLayout + + def __str__(self) -> str: + return f"tensor_descriptor<{self.block_type}, {self.layout}>" + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + value = tensor_descriptor(handle, shape, strides, self) + return value, cursor + + def _to_ir(self, builder: ir.builder) -> ir.type: + is_signed = self.block_type.element_ty.is_int_signed() + return builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self._to_ir(builder)) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}_{self.shape_type.mangle()}_{self.strides_type.mangle()}_{self.layout.mangle()}TD" + + +@dataclass +class tensor_descriptor(ttgl.base_value): + """A descriptor representing a tensor in global memory.""" + + handle: ir.value + shape: ttgl.tuple + strides: ttgl.tuple + type: tensor_descriptor_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + @property + def layout(self): + return self.type.layout + + +@builtin +def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.tensor], + strides: List[ttgl.constexpr | ttgl.tensor], block_shape: List[ttgl.constexpr], + layout: PaddedSharedLayout | SwizzledSharedLayout, _semantic=None) -> tensor_descriptor: + """Make a tensor descriptor object. + + Args: + base (tensor): base pointer of the tensor in global memory. + shape (List[int]): shape of the tensor. + strides (List[int]): strides of the tensor. + block_shape (List[int]): block shape of the tensor. + layout (PaddedSharedLayout | SwizzledSharedLayout): the layout of the tensor in shared memory. + + Returns: + tensor_descriptor: the created tensor descriptor object + """ + ndim = len(shape) + assert 1 <= ndim <= 5, f"Expected 1 <= ndim <= 5 but got {ndim} dimensions" + assert len(strides) == ndim, f"Expected {ndim} strides but got {len(strides)}" + assert len(block_shape) == ndim, f"Expected block_shape to have {ndim} dimensions but got {len(strides)}" + assert isinstance(base.dtype, ttgl.pointer_type), "Expected base to be a pointer" + + layout = _unwrap_if_constexpr(layout) + assert isinstance(layout, (PaddedSharedLayout, SwizzledSharedLayout)), \ + "Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout" + if isinstance(layout, SwizzledSharedLayout): + assert layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout" + + base_handle = base.handle + shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False) # i32 shape + stride_handles = _semantic._convert_to_ir_values(strides, require_i64=True) # i64 stride + + shape = ttgl.tuple(shape) + strides = ttgl.tuple(strides) + block_type = ttgl.block_type(base.type.element_ty, block_shape) + type = tensor_descriptor_type(block_type, shape.type, strides.type, layout) + + padding = _semantic._str_to_padding_option("zero") + handle = _semantic.builder.create_make_tensor_descriptor(type._to_ir(_semantic.builder), base_handle, shape_handles, + stride_handles, padding) + + return tensor_descriptor(handle, shape, strides, type) + + +@builtin +def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], dest: shared_memory_descriptor, + pred=1, mbarrier: shared_memory_descriptor = None, _semantic=None) -> None: + """Load a block of tensor specified in tensor descriptor from global memory to shared memory asynchronously. + + Args: + src (tensor_descriptor): the source tensor descriptor. + offsets (List[int]): the offsets from the base pointer in the tensor descriptor. + dest (shared_memory_descriptor): the shared memory destination to store the loaded data. + pred (int, optional): Predicate to enable or disable the load. Defaults to 1. + mbarrier (shared_memory_descriptor, optional): The barrier object to signal "arrive" on. + """ + offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False) + pred = _semantic.to_tensor(pred) + pred_handle = pred.handle + mbarrier = _unwrap_if_constexpr(mbarrier) + mbarrier_handle = mbarrier.handle if mbarrier is not None else ttgl.ir.value() + _semantic.builder.create_async_tdm_copy_global_to_local(src.handle, offset_handles, dest.handle, pred_handle, + mbarrier_handle) + + +@builtin +def async_store(dest: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], src: shared_memory_descriptor, + mbarrier: shared_memory_descriptor = None, _semantic=None) -> None: + """Store a block of tensor specified in tensor descriptor from shared memory to global memory asynchronously. + + Args: + dest (tensor_descriptor): the destination tensor descriptor. + offsets (List[int]): the offsets from the base pointer in the tensor descriptor. + src (shared_memory_descriptor): the shared memory source to load the data. + mbarrier (shared_memory_descriptor, optional): The barrier object to signal "arrive" on. + """ + offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False) + mbarrier = _unwrap_if_constexpr(mbarrier) + mbarrier_handle = mbarrier.handle if mbarrier is not None else ttgl.ir.value() + _semantic.builder.create_async_tdm_copy_local_to_global(dest.handle, offset_handles, src.handle, mbarrier_handle) + + +@builtin +def async_wait(num_outstanding=0, _semantic=None) -> None: + """Wait for the outstanding asynchronous tensor operations to complete. + + Args: + num_outstanding (int): number of outstanding async tensor operations to wait for. + """ + num_outstanding = _unwrap_if_constexpr(num_outstanding) + _semantic.builder.create_async_tdm_wait(num_outstanding) + + +@builtin +def async_scatter(desc: tensor_descriptor, dst_row_indices: ttgl.tensor, dst_col_offset, src: shared_memory_descriptor, + mbarrier: shared_memory_descriptor = None, _semantic=None) -> None: + """Scatter data from shared memory to non-contiguous rows in global memory asynchronously. + + This operation uses TDM scatter mode to write data to non-contiguous rows in global memory. + Unlike async_store which writes to contiguous rows, scatter allows writing to arbitrary + rows specified by the dst_row_indices tensor. + + The dtype of dst_row_indices determines the index size: + - int16: up to 16 rows can be scattered per TDM instruction + - int32: up to 8 rows can be scattered per TDM instruction + If more rows are needed, multiple TDM instructions will be automatically issued. + + Args: + desc (tensor_descriptor): the destination tensor descriptor. Must be 2D. + dst_row_indices (tensor): 1D tensor of row indices (int16 or int32) in the destination tensor. + dst_col_offset (int or tensor): the starting column offset in the destination tensor + for all scattered rows. + src (shared_memory_descriptor): the shared memory source containing data to scatter. Must be 2D. + mbarrier (shared_memory_descriptor, optional): The barrier object to signal "arrive" on. + """ + ndim = len(desc.block_shape) + assert ndim == 2, f"TDM scatter only supports 2D tensors, got {ndim}D" + + src_ndim = len(src.shape) + assert src_ndim == 2, f"TDM scatter src must be 2D, got {src_ndim}D" + + # Convert dst_col_offset to i32 + dst_col_offset_handle = _semantic._convert_to_ir_values([dst_col_offset], require_i64=False)[0] + + mbarrier = _unwrap_if_constexpr(mbarrier) + mbarrier_handle = mbarrier.handle if mbarrier is not None else ttgl.ir.value() + + _semantic.builder.create_async_tdm_scatter(desc.handle, dst_row_indices.handle, dst_col_offset_handle, src.handle, + mbarrier_handle) + + +@builtin +def prefetch(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], pred: bool = True, + speculative: bool = False, _semantic=None) -> None: + """Prefetches a block of tensor specified in tensor descriptor from global memory into L2. Speculative prefetches can generate more + efficient assembly because they do not require out of bounds checks. However, they are dropped by the hardware if their virtual address translation is not cached. + So speculative should only be set if previous iterations have accessed the same virtual page (e.g. column major) + Args: + src (tensor_descriptor): the source tensor descriptor. + offsets (List[int]): the offsets from the base pointer in the tensor descriptor. + pred (bool, optional): Predicate to enable or disable the prefetch. Defaults to True. + speculative (bool, optional): Whether the prefetch is speculative. Defaults to False. + """ + offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False) + pred = _semantic.to_tensor(pred) + pred_handle = pred.handle + speculative = _unwrap_if_constexpr(speculative) + _semantic.builder.create_tdm_prefetch(src.handle, offset_handles, pred_handle, speculative, False) + + +@builtin +def _test_prefetch_with_offsets(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], pred: bool = True, + speculative: bool = False, _semantic=None) -> ttgl.tensor: + """Test-only prefetch variant that returns offsets for validation.""" + offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False) + pred = _semantic.to_tensor(pred) + pred_handle = pred.handle + speculative = _unwrap_if_constexpr(speculative) + handle = _semantic.builder.create_tdm_prefetch(src.handle, offset_handles, pred_handle, speculative, True) + shape = _semantic.builder.get_shape_from_tensor(handle) + layout = _semantic.builder.get_gluon_layout_from_tensor(handle) + ret_ty = ttgl.distributed_type(ttgl.int64, shape, layout) + tensor = ttgl.tensor(handle, ret_ty) + return tensor diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/rdna3/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/rdna3/__init__.py new file mode 100644 index 0000000000..d435944216 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/rdna3/__init__.py @@ -0,0 +1,17 @@ +from ..._core import builtin +from .._ops import _wmma + +__all__ = ["wmma"] + + +@builtin +def wmma(a, b, acc, _semantic=None): + """ + Computes matrix-multiplication of a * b + acc using AMD WMMA instruction. + + Args: + a (tensor): The operand a to be multiplied. + b (tensor): The operand b to be multiplied. + acc (tensor): The accumulator tensor. + """ + return _wmma(1, a, b, acc, _semantic) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/rdna4/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/rdna4/__init__.py new file mode 100644 index 0000000000..59e3e169bd --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/rdna4/__init__.py @@ -0,0 +1,17 @@ +from ..._core import builtin +from .._ops import _wmma + +__all__ = ["wmma"] + + +@builtin +def wmma(a, b, acc, _semantic=None): + """ + Computes matrix-multiplication of a * b + acc using AMD WMMA instruction. + + Args: + a (tensor): The operand a to be multiplied. + b (tensor): The operand b to be multiplied. + acc (tensor): The accumulator tensor. + """ + return _wmma(2, a, b, acc, _semantic) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/amd/warp_pipeline.py b/third_party/mthreads/python/triton/experimental/gluon/language/amd/warp_pipeline.py new file mode 100644 index 0000000000..6e9dd1f2e0 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/amd/warp_pipeline.py @@ -0,0 +1,62 @@ +from __future__ import annotations + + +class warp_pipeline_stage: + """ + Marks the end of a warp-pipeline stage inside a Gluon kernel. + + When used inside @gl.kernel, exiting the `with` block inserts a + warp-pipeline border in the semantic IR. During lowering, these borders + define pipeline clusters (scf.execute_region), drive dependency analysis, + and determine where conditional and cluster-scope barriers are required. + + The optional string label (e.g., "load", "compute") is attached to the + border op and may be used by downstream passes for diagnostics. + + Example: + @gl.kernel + def gemm(K: gl.i32): + one = gl.const_i32(1) + offs_a = ... + + for k in gl.range(0, K, one): + + # Stage 0: prefetch tiles + with amd.warp_pipeline_stage("load"): + a = gl.amd.buffer_load(a_ptr, offs_a) + b = gl.amd.buffer_load(b_ptr, offs_b) + + # Stage 1: prepare MFMA operands + with amd.warp_pipeline_stage("prep"): + a_tile = a.load(layout=...) + b_tile = b.load(layout=...) + + # Stage 2: compute + with amd.warp_pipeline_stage("compute"): + acc = gl.amd.mfma(a_tile, b_tile, acc) + offs_a += strideA + offs_b += strideB + + """ + + __slots__ = ("label", "_semantic", "str_attr") + + def __init__(self, label=None, **_internal): + self.label = getattr(label, "value", None) + self._semantic = _internal.get("_semantic", None) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + if exc_type is not None: + return False + if self._semantic is None: + return False + if self.label is None: + attr = "cluster" + else: + attr = self.label + self._semantic.builder.create_warp_pipeline_border(attr) + + return False diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/extra/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/extra/__init__.py new file mode 100644 index 0000000000..2091e0b7e2 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/extra/__init__.py @@ -0,0 +1,3 @@ +from triton.language.extra import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/__init__.py new file mode 100644 index 0000000000..3ecf36d3b9 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/__init__.py @@ -0,0 +1,4 @@ +from . import blackwell +from . import hopper + +__all__ = ["blackwell", "hopper"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/__init__.py new file mode 100644 index 0000000000..38b012f017 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/__init__.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from triton import knobs +from triton.experimental.gluon.language import _core as ttgl +from triton.experimental.gluon.language._layouts import DotOperandLayout, NVMMADistributedLayout +from ..._core import builtin, _unwrap_if_constexpr +from . import async_copy, mbarrier + +__all__ = ["async_copy", "mbarrier", "mma_v2"] + + +@builtin +def mma_v2(a, b, acc, input_precision=None, _semantic=None): + input_precision = _unwrap_if_constexpr(input_precision) + assert isinstance(a, ttgl.tensor), "a must be a tensor" + assert isinstance(b, ttgl.tensor), "b must be a tensor" + assert isinstance(acc, ttgl.tensor), "acc must be a tensor" + + mma_layout = acc.type.layout + assert isinstance(mma_layout, NVMMADistributedLayout), "acc must have an NVMMADistributedLayout" + assert mma_layout.version == [2, 0], "MMA layout must have version 2.0" + + assert isinstance(a.type.layout, DotOperandLayout), "a must have a DotOperandLayout" + assert isinstance(b.type.layout, DotOperandLayout), "b must have a DotOperandLayout" + assert a.type.layout.parent == mma_layout, "a's parent layout must be the same as acc's layout" + assert b.type.layout.parent == mma_layout, "b's parent layout must be the same as acc's layout" + assert a.type.layout.operand_index == 0, "a's operand index must be 0" + assert b.type.layout.operand_index == 1, "b's operand index must be 1" + + handle = _semantic.dot(a, b, acc, input_precision=input_precision, max_num_imprecise_acc=None, + out_dtype=acc.dtype).handle + return ttgl.tensor(handle, acc.type) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py new file mode 100644 index 0000000000..b6752402bf --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py @@ -0,0 +1,74 @@ +from ..._semantic import _check +from ..._core import _unwrap_if_constexpr, builtin +from triton._C.libtriton import ir + +__all__ = [ + "async_copy_global_to_shared", + "mbarrier_arrive", + "commit_group", + "wait_group", +] + + +@builtin +def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", eviction_policy="", volatile=False, + _semantic=None): + """ + Asynchronously copy elements from global memory to shared memory. + + Args: + smem (shared_memory_descriptor): Destination shared memory descriptor. + pointer (tensor): Source pointer tensor. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + eviction_policy (str): Eviction policy specifier. Defaults to "". + volatile (bool): Whether the load is volatile. Defaults to False. + """ + mask = _unwrap_if_constexpr(mask) + cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) + eviction_policy = _semantic._str_to_eviction_policy(eviction_policy) + volatile = _unwrap_if_constexpr(volatile) + if mask is not None: + pointer, mask = _semantic.broadcast_impl_value(pointer, mask) + _check( + smem.shape == pointer.shape, lambda: + f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}" + ) + mask_handle = mask.handle if mask is not None else ir.value() + _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, ir.value(), + cache_modifier, eviction_policy, volatile) + + +@builtin +def mbarrier_arrive(mbarrier, increment_count=True, _semantic=None): + """ + Arrive on the mbarrier once all outstanding async copies are complete. + + Args: + mbarrier (shared_memory_descriptor): Barrier object to arrive on. + increment_count (bool): Whether to increment the arrival count. Defaults to True. + """ + increment_count = _unwrap_if_constexpr(increment_count) + _semantic.builder.create_async_copy_mbarrier_arrive(mbarrier.handle, increment_count) + + +@builtin +def commit_group(_semantic=None): + """ + Commit the current asynchronous copy group. + + This finalizes a set of asynchronous copy operations. + """ + _semantic.builder.create_async_commit_group() + + +@builtin +def wait_group(num_outstanding=0, _semantic=None): + """ + Wait for outstanding asynchronous copy group operations. + + Args: + num_outstanding (int): Wait until `num_outstanding` or less async copy groups in-flight. Defaults to 0. + """ + num_outstanding = _unwrap_if_constexpr(num_outstanding) + _semantic.builder.create_async_wait_group(num_outstanding) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py new file mode 100644 index 0000000000..0ee39057ae --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py @@ -0,0 +1,121 @@ +import math + +import triton.experimental.gluon.language as ttgl +from triton.experimental.gluon._runtime import constexpr_function, jit +from triton.experimental.gluon.language._layouts import SwizzledSharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +__all__ = ["allocate_mbarrier", "arrive", "init", "invalidate", "MBarrierLayout", "wait"] + + +class MBarrierLayout(SwizzledSharedLayout): + """ + Layout for mbarrier synchronization in Ampere and later architectures. + + Args: + cga_layout (List[List[int]]): CGA layout bases. Defaults to []. + """ + + def __init__(self, cga_layout=None): + super().__init__(vec=1, per_phase=1, max_phase=1, order=[0], cga_layout=cga_layout or []) + + @staticmethod + @constexpr_function + def multicta(num_ctas: int, two_cta: bool = False): + """ + Create a multi-CTA mbarrier layout. + + Args: + num_ctas (int): Number of CTAs. + two_cta (bool): Whether the barrier should synchronize every other CTA + """ + num_ctas = ttgl._unwrap_if_constexpr(num_ctas) + two_cta = ttgl._unwrap_if_constexpr(two_cta) + if two_cta: + assert num_ctas % 2 == 0, "num_ctas must be even for two-CTA mode" + assert num_ctas > 0, "num_ctas must be positive" + assert (num_ctas & (num_ctas - 1)) == 0, "num_ctas must be a power of two" + + bases = [] + if two_cta: + bases.append([0]) + num_ctas //= 2 + + for i in range(int(math.log2(num_ctas))): + bases.append([2**i]) + return MBarrierLayout(bases) + + +@jit +def allocate_mbarrier(batch: ttgl.constexpr = None, two_ctas: ttgl.constexpr = False): + """ + Helper function to allocate an mbarrier + + Args: + two_ctas (bool): Whether the barrier should synchronize every other CTA + """ + num_ctas: ttgl.constexpr = ttgl.num_ctas() + num_elems: ttgl.constexpr = num_ctas if not two_ctas else num_ctas // 2 + ttgl.static_assert(batch is None or isinstance(batch.value, int)) + shape: ttgl.constexpr = [num_elems] if batch is None else [batch, num_elems] + bar = ttgl.allocate_shared_memory( + ttgl.int64, + shape, + MBarrierLayout.multicta(num_ctas=num_ctas, two_cta=two_ctas), + ) + return bar + + +@builtin +def init(mbarrier, count, _semantic=None): + """ + Initialize an mbarrier with a specified count. + + Args: + mbarrier (shared_memory_descriptor): The barrier object to initialize. + count (int): The initial count for the barrier. + """ + count = _unwrap_if_constexpr(count) + _semantic.builder.create_mbarrier_init(mbarrier.handle, count) + + +@builtin +def invalidate(mbarrier, _semantic=None): + """ + Invalidate an mbarrier, resetting its state. + + Args: + mbarrier (shared_memory_descriptor): The barrier object to invalidate. + """ + _semantic.builder.create_mbarrier_inval(mbarrier.handle) + + +@builtin +def wait(mbarrier, phase, pred=True, deps=(), _semantic=None): + """ + Wait until the mbarrier object completes its current phase. + + Args: + mbarrier (shared_memory_descriptor): The barrier object to wait on. + phase (int): The phase index to wait for. + pred (bool): Predicate. Operation is skipped if predicate is False. Defaults to True. + deps (Sequence[shared_memory_descriptor]): Dependent allocations barrier is waiting on. Used to track liveness of dependent allocations. Defaults to (). + """ + phase = _semantic.to_tensor(phase) + pred = _semantic.to_tensor(pred) + deps = [x.handle for x in deps] + _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps) + + +@builtin +def arrive(mbarrier, *, pred=True, _semantic=None): + """ + Arrive on an mbarrier, signaling that a thread has reached the barrier. + + Args: + mbarrier (shared_memory_descriptor): The barrier object to arrive on. + pred (bool): Predicate. Operation is skipped if predicate is False. Defaults to True. + """ + count = 1 + pred = _semantic.to_tensor(pred) + _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py new file mode 100644 index 0000000000..031dba0fd0 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -0,0 +1,571 @@ +from __future__ import annotations +from typing import Optional, Tuple, List, TYPE_CHECKING + +from dataclasses import dataclass +from triton.runtime.jit import constexpr_function +from triton.experimental.gluon.language import _core as ttgl +from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr +from triton.experimental.gluon.language._semantic import _check, _compute_tmem_reg_layout + +from . import tma +from ..hopper import fence_async_shared, mbarrier +from ..ampere import async_copy, mma_v2 + +from triton._C.libtriton import ir +import triton._C.libtriton.gluon_ir as gluon_ir +if TYPE_CHECKING: + from triton._C.libtriton.gluon_ir import GluonOpBuilder + from ..._semantic import GluonSemantic + +__all__ = [ + "allocate_tensor_memory", + "async_copy", + "fence_async_shared", + "get_tmem_reg_layout", + "mbarrier", + "mma_v2", + "tensor_memory_descriptor", + "TensorMemoryLayout", + "TensorMemoryScalesLayout", + "tma", + "_TensorMemoryLinearLayout", +] + + +@dataclass(frozen=True, eq=True) +class TensorMemoryLayout: + """ + Describes the layout for tensor memory in Blackwell architecture. + + Args: + block (Tuple[int, int]): Number of contiguous elements per row / column in a CTA. + col_stride (int): Number of 32-bit columns to advance between logically + adjacent columns. Packed layouts use a stride of 1. Unpacked + layouts use ``32 / bitwidth``. + cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None. + two_ctas (bool): Whether the layout is for two-CTA mode. Defaults to False. + """ + block: Tuple[int, int] + col_stride: int + cta_split_num: Optional[Tuple[int, int]] = None + two_ctas: bool = False + + def __post_init__(self): + super().__setattr__("block", _unwrap_if_constexpr(self.block)) + super().__setattr__("col_stride", _unwrap_if_constexpr(self.col_stride)) + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) + super().__setattr__("two_ctas", _unwrap_if_constexpr(self.two_ctas)) + assert len(self.block) == 2 + assert self.cta_split_num is None or len(self.cta_split_num) == 2 + assert self.col_stride >= 1 and (self.col_stride & + (self.col_stride - 1)) == 0, "tensor memory col_stride must be a power of two" + + def _to_ir(self, builder): + cta_split_num = list(self.cta_split_num) if self.cta_split_num else [1, 1] + return builder.get_tensor_memory_layout( + self.block, + self.col_stride, + cta_split_num, + self.two_ctas, + ) + + def mangle(self) -> str: + block_str = f"{self.block[0]}x{self.block[1]}" + stride_str = f"C{self.col_stride}" + cta_split_str = (f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else "") + two_ctas_str = "2CT" if self.two_ctas else "" + return f"TL{block_str}{stride_str}{cta_split_str}{two_ctas_str}TL" + + def __hash__(self): + return hash((self.block, self.col_stride, self.cta_split_num, self.two_ctas)) + + +@dataclass(frozen=True, eq=True) +class TensorMemoryScalesLayout: + """ + Describes the layout for tensor memory scales in Blackwell architecture. + + Args: + cta_split_num (Optional[Tuple[int, int]]): CTA split factors. Defaults to None. + """ + cta_split_num: Optional[Tuple[int, int]] = None + + def __post_init__(self): + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) + assert self.cta_split_num is None or len(self.cta_split_num) == 2 + + def _to_ir(self, builder): + cta_split_num = list(self.cta_split_num) if self.cta_split_num else [1, 1] + return builder.get_tensor_memory_scales_layout(cta_split_num) + + def mangle(self) -> str: + cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else "" + return f"TLS{cta_split_str}TLS" + + def __hash__(self): + return hash(self.cta_split_num) + + +@dataclass(frozen=True) +class _TensorMemoryLinearLayout: + """ + Print-only linear layout for TMEM (row/col -> dim0/dim1). + """ + rows: List[List[int]] + cols: List[List[int]] + shape: List[int] + + def _to_ir(self, builder): + raise RuntimeError("TensorMemoryLinearLayout is print-only; IR materialization is unsupported") + + def mangle(self): + return f"TMLL_{self.shape}_TMLL" + + def __hash__(self): + return hash((tuple(map(tuple, self.rows)), tuple(map(tuple, self.cols)), tuple(self.shape))) + + +@constexpr_function +def get_tmem_reg_layout( + element_ty, + shape, + layout, + num_warps, + instr_variant="32x32b", + cga_layout=(), +): + """ + Returns a DistributedLinearLayout compatible with TMEM load/store instructions. + + Args: + element_ty (dtype): Element type stored in tensor memory. + shape (Sequence[int]): Global tensor shape addressed by the TMEM descriptor. + layout (TensorMemoryLayout): Tensor memory layout descriptor. + num_warps (int): Number of warps participating in the operation. + instr_variant (str): TMEM instruction variant (e.g. ``\"32x32b\"``). + cga_layout (Sequence[Sequence[int]]): CGA layout bases describing CTA distribution. + """ + + def _unwrap(x): + if isinstance(x, ttgl.constexpr): + return _unwrap(x.value) + if isinstance(x, list): + return [_unwrap(i) for i in x] + if isinstance(x, tuple): + return tuple(_unwrap(i) for i in x) + return x + + return _compute_tmem_reg_layout( + _unwrap(element_ty), + _unwrap(shape), + _unwrap(layout), + _unwrap(num_warps), + _unwrap(instr_variant), + _unwrap(cga_layout), + ) + + +class tensor_memory_descriptor_type(base_type): + + def __init__(self, element_ty, shape, layout, alloc_shape): + self.element_ty = element_ty + self.shape = shape + self.layout = layout + self.alloc_shape = alloc_shape + assert isinstance(layout, TensorMemoryLayout) or isinstance(layout, TensorMemoryScalesLayout) + + def to_ir(self, builder: GluonOpBuilder) -> None: + return builder.get_tensor_mem_desc_ty( + self.element_ty.to_ir(builder), + self.shape, + self.layout._to_ir(builder), + self.alloc_shape, + ) + + def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]: + value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def __str__(self) -> str: + return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>" + + def __eq__(self, other) -> bool: + return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout + and self.alloc_shape == other.alloc_shape) + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + shape_str = "_".join([str(s) for s in self.shape]) + return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD" + + +class tensor_memory_descriptor(base_value): + """ + Represents a tensor memory descriptor handle for Tensor Core Gen5 operations. + """ + + def __init__(self, handle, element_ty, shape, layout, alloc_shape): + self.handle = handle + self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def dtype(self): + return self.type.element_ty + + @property + def shape(self): + return self.type.shape + + @property + def rank(self): + return len(self.shape) + + @property + def layout(self): + return self.type.layout + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, layout, _semantic: GluonSemantic = None) -> ttgl.tensor: + """ + Load a tensor from tensor memory. + + Args: + layout (DistributedLayout): Destination layout of the tensor. + + Returns: + tensor: A distributed tensor containing the loaded data. + """ + layout = _unwrap_if_constexpr(layout) + ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout) + builder = _semantic.builder + handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle) + return ttgl.tensor(handle, ret_ty) + + def _load_red(self, layout, red_op, abs, propagate_nan, _semantic: GluonSemantic): + # red_op: MIN/MAX reduction operation + # abs (bool): If True, reduce absolute values. + # propagate_nan (NONE): If ALL, propagate NaN in specified reduction operation. + layout = _unwrap_if_constexpr(layout) + abs_flag = _unwrap_if_constexpr(abs) + propagate_nan = _unwrap_if_constexpr(propagate_nan) + + ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout) + builder = _semantic.builder + + result, reduced, red_layout = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle, red_op, abs_flag, + propagate_nan) + + red_shape = [self.shape[0]] # [M] for [M,N] input + red_ty = ttgl.distributed_type(self.dtype, red_shape, red_layout) + + return (ttgl.tensor(result, ret_ty), ttgl.tensor(reduced, red_ty)) + + @builtin + def load_min(self, layout, abs=False, propagate_nan=ir.PROPAGATE_NAN.NONE, _semantic: GluonSemantic = None): + """ + Load a tensor from tensor memory with MIN reduction along the N-dimension. + + Args: + layout (DistributedLayout): Destination layout of the tensor. + abs (bool): If True, reduce absolute values. Defaults to False. + propagate_nan (PROPAGATE_NAN): If ALL, propagate NaN in the reduction operation. Defaults to NONE. + + Returns: + tuple: A tuple containing (tensor, reduced_tensor) where tensor is the loaded data + and reduced_tensor is the result of MIN reduction along the N-dimension of loaded data + """ + return self._load_red(layout, gluon_ir.TMEM_LOAD_REDUCE_MODIFIER.MIN, abs, propagate_nan, _semantic) + + @builtin + def load_max(self, layout, abs=False, propagate_nan=ir.PROPAGATE_NAN.NONE, _semantic: GluonSemantic = None): + """ + Load a tensor from tensor memory with MAX reduction along the N-dimension. + + Args: + layout (DistributedLayout): Destination layout of the tensor. + abs (bool): If True, reduce absolute values. Defaults to False. + propagate_nan (PROPAGATE_NAN): If ALL, propagate NaN in the reduction operation. Defaults to NONE. + + Returns: + tuple: A tuple containing (tensor, reduced_tensor) where tensor is the loaded data + and reduced_tensor is the result of MAX reduction along the N-dimension of loaded data. + """ + return self._load_red(layout, gluon_ir.TMEM_LOAD_REDUCE_MODIFIER.MAX, abs, propagate_nan, _semantic) + + @builtin + def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None: + """ + Store a tensor into tensor memory. + + Args: + value (tensor): The tensor to store. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + """ + pred = _unwrap_if_constexpr(pred) + pred = _semantic.to_tensor(pred) + assert value.shape == self.shape, f"source shape {value.shape} does not match destination shape {self.shape}" + assert value.dtype == self.dtype, f"source dtype {value.dtype} does not match destination dtype {self.dtype}" + _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle) + + @builtin + def slice(self, start, length, _semantic: GluonSemantic = None) -> None: + """ + Create a slice of the tensor memory descriptor along the last dimension. + + Args: + start (int): The starting index for subslice. + length (int): The length of the subslice. + + Returns: + tensor_memory_descriptor: Descriptor for the subslice. + """ + start = _unwrap_if_constexpr(start) + length = _unwrap_if_constexpr(length) + _check(isinstance(start, int), lambda: "start must be a constant int") + _check(isinstance(length, int), lambda: "length must be a constant int") + shape = self.shape[:-1] + [length] + layout = self.type.layout + layout = TensorMemoryLayout( + (layout.block[0], min(layout.block[1], length)), + layout.col_stride, + layout.cta_split_num, + layout.two_ctas, + ) + ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape) + builder = _semantic.builder + ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start) + return ret + + @builtin + def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor: + """ + Create a subview of tensor memory by indexing the first dimension. + + Args: + index (tensor): The index tensor for the subview. + + Returns: + tensor_memory_descriptor: Descriptor for the indexed subview. + """ + index = _semantic.to_tensor(index) + builder = _semantic.builder + shape = self.shape[1:] + layout = self.layout + ret = tensor_memory_descriptor(None, self.dtype, shape, layout, shape) + ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle) + return ret + + @builtin + def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor: + """ + Reinterpret tensor memory descriptor with a new dtype, shape, and layout. + + Args: + dtype (dtype): The new data type. + shape (Sequence[int]): The new shape. + layout (TensorMemoryLayout): The new layout. + + Returns: + tensor_memory_descriptor: Descriptor with updated type and layout. + """ + dtype = _unwrap_if_constexpr(dtype) + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + + ty = tensor_memory_descriptor_type(dtype, shape, layout, shape) + handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle) + return tensor_memory_descriptor(handle, **ty.__dict__) + + +@builtin +def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None): + """ + Allocate tensor memory. + + Args: + element_ty (dtype): The element data type. + shape (Sequence[int]): The descriptor shape. + layout (TensorMemoryLayout): The layout of the tensor memory. + value (tensor, optional): Initial tensor to copy. Defaults to None. + + Returns: + tensor_memory_descriptor: Descriptor for the allocated memory. + """ + element_ty = _unwrap_if_constexpr(element_ty) + shape = _unwrap_if_constexpr(shape) + layout = _unwrap_if_constexpr(layout) + value = value.handle if value is not None else None + + ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape) + builder = _semantic.builder + handle = builder.create_tmem_alloc(ty.to_ir(builder), value) + return tensor_memory_descriptor(handle, element_ty, shape, layout, shape) + + +@builtin +def tcgen05_copy(src, dst, _semantic=None): + """ + Start an asynchronous copy from shared memory to tensor memory. + + Args: + src (shared_memory_descriptor): Shared memory to copy from. + dst (tensor_memory_descriptor): Tensor memory to copy to. + """ + assert isinstance(src, ttgl.shared_memory_descriptor), "source must be a shared memory descriptor" + assert isinstance(dst, tensor_memory_descriptor), "destination must be a tensor memory descriptor" + _semantic.builder.create_tmem_copy(src.handle, dst.handle) + + +@builtin +def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, multicast=False, mbarriers=None, mbarrier_preds=None, + _semantic=None): + """ + Emit a 5th generation TensorCore MMA instruction. + acc = a * b + (acc if use_acc else 0) + + Args: + a (shared_memory_descriptor): Left hand side operand in shared memory. + b (shared_memory_descriptor or tensor_memory_descriptor): Right hand side operand in shared or tensor memory. + acc (tensor_memory_descriptor): Accumulator value in tensor memory (mutated). + use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + multicast (bool): Whether tcgen05 commit should multicast across a CTA cluster. Defaults to False. + mbarriers (Sequence[shared_memory_descriptor], optional): Barriers to signal when the operation is complete. If None, mma is synchronous. Defaults to None. + mbarrier_preds (Sequence[bool], optional): Predicates for barriers. Defaults to None. + """ + use_acc = _semantic.to_tensor(use_acc) + pred = _semantic.to_tensor(pred) + + if mbarriers is None: + assert mbarrier_preds is None + mbarriers = [] + mbarrier_preds = [] + else: + mbarriers = [bar.handle for bar in mbarriers] + if mbarrier_preds is None: + true = _semantic.to_tensor(True) + mbarrier_preds = [true.handle] * len(mbarriers) + else: + mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False) + + multicast = _unwrap_if_constexpr(multicast) + _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers, + mbarrier_preds, acc.layout.two_ctas, multicast) + + +@builtin +def tcgen05_mma_scaled(a, b, acc, a_scale, b_scale, a_type, b_type, *, use_acc=True, pred=True, mbarriers=None, + mbarrier_preds=None, _semantic=None): + """ + Emit a 5th generation TensorCore MMA scaled instruction. + acc = (a * a_scale) * (b * b_scale) + (acc if use_acc else 0) + + Args: + a (shared_memory_descriptor): Left hand side operand in shared memory. + b (shared_memory_descriptor or tensor_memory_descriptor): Right hand side operand in shared or tensor memory. + acc (tensor_memory_descriptor): Accumulator value in tensor memory (mutated). + a_scale (tensor): Scale factor for operand A. + b_scale (tensor): Scale factor for operand B. + a_type (str): Type of operand A. One of {"e2m1", "e4m3", "e5m2"}. + b_type (str): Type of operand B. One of {"e2m1", "e4m3", "e5m2"}. + use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + mbarriers (Sequence[mbarrier], optional): Barriers to signal when the operation is complete. If None, mma is synchronous. Defaults to None. + mbarrier_preds (Sequence[bool], optional): Predicates for barriers. Defaults to None. + """ + use_acc = _semantic.to_tensor(use_acc) + pred = _semantic.to_tensor(pred) + assert acc.type.layout.block[0] != 64, "tcgen05_mma_scaled does not support blockM=64" + + if mbarriers is None: + assert mbarrier_preds is None + mbarriers = [] + mbarrier_preds = [] + else: + mbarriers = [bar.handle for bar in mbarriers] + if mbarrier_preds is None: + true = _semantic.to_tensor(True) + mbarrier_preds = [true.handle] * len(mbarriers) + else: + mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False) + + allowed_formats = {"e2m1", "e4m3", "e5m2"} + assert a_type.value in allowed_formats, f"Unsupported lhs_format: {a_type.value}" + assert b_type.value in allowed_formats, f"Unsupported rhs_format: {b_type.value}" + a_type = _semantic._str_to_fp_type(a_type.value) + b_type = _semantic._str_to_fp_type(b_type.value) + _semantic.builder.create_tcgen05_mma_scaled(a.handle, b.handle, acc.handle, a_scale.handle, b_scale.handle, a_type, + b_type, use_acc.handle, pred.handle, mbarriers, mbarrier_preds) + + +@constexpr_function +def tcgen05_mma_barrier_count(smems, multicast): + """ + Calculate the number of CTAs that will commit the tcgen05 MMA instruction. + + Args: + smems (Sequence[shared_memory_descriptor]): Shared memory descriptors used in the tcgen05 instruction. + multicast (bool): Whether the tcgen05 instruction is multicast. + + Returns: + int: The number of CTAs that will commit the tcgen05 MMA instruction. + """ + assert 0 <= len(smems) <= 2, "tcgen05_mma_barrier_count supports 0, 1, or 2 smem descriptors" + if not smems or not multicast: + return 1 + + def basis_is_zero(basis): + return all(b == 0 for b in basis) + + def num_broadcast_bits(smem): + return sum(basis_is_zero(basis) for basis in smem.layout.cga_layout) + + if len(smems) == 1: + return 2**num_broadcast_bits(smems[0]) + + assert len(smems) == 2 + num_broadcast_bits_a = num_broadcast_bits(smems[0]) + num_broadcast_bits_b = num_broadcast_bits(smems[1]) + # Asser that for every basis, at least one of them is non-zero + # so that the inclusion-exclusion principle below works + # This can be generalised if needed by substracting below 2**size_intersection + for i in range(len(smems[0].layout.cga_layout)): + assert not basis_is_zero(smems[0].layout.cga_layout[i]) or not basis_is_zero(smems[1].layout.cga_layout[i]) + + # Inclusion-exclusion + num_cta_commits = 2**num_broadcast_bits_a + 2**num_broadcast_bits_b - 1 + return num_cta_commits + + +@builtin +def tcgen05_commit(barrier, pred=True, descs=(), _semantic=None): + """ + This instruction causes the provided mbarrier to be arrived-on with a count + of 1 when all async tcgen05 MMA and copy instructions previously issued by + the thread are complete. + + If `descs` are provided, the commit will be multicast across the CTA cluster + based on the shared layouts of those descriptors. This should be used when + the inputs to the tcgen5 MMA come from TMA descriptors using multicast. + + Args: + barrier (shared_memory_descriptor): The barrier to track completion of tcgen05 MMA and copy instructions. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + descs (Sequence[shared_memory_descriptor]): Shared memory descriptors for + the preceding multiplication inputs. Defaults to (). + """ + pred = _semantic.to_tensor(pred) + descs = _unwrap_if_constexpr(descs) + descs = [d.handle for d in descs] + _semantic.builder.create_tcgen05_commit(barrier.handle, pred.handle, descs) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py new file mode 100644 index 0000000000..c06b103f36 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py @@ -0,0 +1,172 @@ +from triton.language.core import _aggregate as aggregate +from triton.experimental.gluon.language import _core as ttgl, _standard as stdlib +from triton.experimental.gluon._runtime import constexpr_function, jit + +__all__ = [ + "pack2", + "unpack2", + "pack", + "unpack", + "fma", + "Float2Tensor", +] + + +@jit +def _add_f32x2(a, b): + return ttgl.inline_asm_elementwise( + """ + add.f32x2 $0, $1, $2; + """, + "=l,l,l", + [a, b], + dtype=ttgl.int64, + is_pure=True, + pack=1, + ) + + +@jit +def _sub_f32x2(a, b): + return ttgl.inline_asm_elementwise( + """ + sub.f32x2 $0, $1, $2; + """, + "=l,l,l", + [a, b], + dtype=ttgl.int64, + is_pure=True, + pack=1, + ) + + +@jit +def _mul_f32x2(a, b): + return ttgl.inline_asm_elementwise( + """ + mul.f32x2 $0, $1, $2; + """, + "=l,l,l", + [a, b], + dtype=ttgl.int64, + is_pure=True, + pack=1, + ) + + +@jit +def _fma_f32x2(a, b, c): + return ttgl.inline_asm_elementwise( + """ + fma.rn.f32x2 $0, $1, $2, $3; + """, + "=l,l,l,l", + [a, b, c], + dtype=ttgl.int64, + is_pure=True, + pack=1, + ) + + +@aggregate +class Float2Tensor: + value: ttgl.tensor + + @constexpr_function + def __init__(self, value: ttgl.tensor): + self.value = value + + @jit + def __add__(self, rhs): + ttgl.static_assert(isinstance(rhs, Float2Tensor), "rhs must be a Float2Tensor") + return Float2Tensor(_add_f32x2(self.value, rhs.value)) + + @jit + def __sub__(self, rhs): + ttgl.static_assert(isinstance(rhs, Float2Tensor), "rhs must be a Float2Tensor") + return Float2Tensor(_sub_f32x2(self.value, rhs.value)) + + @jit + def __mul__(self, rhs): + ttgl.static_assert(isinstance(rhs, Float2Tensor), "rhs must be a Float2Tensor") + return Float2Tensor(_mul_f32x2(self.value, rhs.value)) + + @jit + def sum(self, axis: ttgl.constexpr): + return Float2Tensor(ttgl.reduce(self.value, axis=axis, combine_fn=_add_f32x2)) + + +@jit +def pack2(x0, x1): + value = ttgl.inline_asm_elementwise( + """ + mov.b64 $0, { $1, $2 }; + """, + "=l,r,r", + [x0, x1], + dtype=ttgl.int64, + is_pure=True, + pack=1, + ) + return Float2Tensor(value) + + +@jit +def unpack2(x): + return ttgl.inline_asm_elementwise( + """ + mov.b64 { $0, $1 }, $2; + """, + "=r,=r,l", + [x.value], + dtype=[ttgl.float32, ttgl.float32], + is_pure=True, + pack=1, + ) + + +@constexpr_function +def _get_split_shape(shape, axis): + shape = [d for d in shape] + assert shape[axis] >= 2, f"not enough elements to pack along axis {axis}" + shape[axis] //= 2 + shape.insert(axis + 1, 2) + permute = list(range(len(shape))) + permute[axis + 1], permute[len(permute) - 1] = permute[len(permute) - 1], permute[axis + 1] + return ttgl.tuple(shape), ttgl.tuple(permute) + + +@constexpr_function +def _get_join_shape(shape, axis): + shape = [d for d in shape] + shape[axis] *= 2 + permute = list(range(len(shape))) + permute.insert(axis + 1, len(permute)) + return ttgl.tuple(shape), ttgl.tuple(permute) + + +@jit +def pack(x, axis): + sp: ttgl.constexpr = _get_split_shape(x.shape, axis) + x0, x1 = x.reshape(*sp[0]).permute(*sp[1]).split() + return pack2(x0, x1) + + +@jit +def unpack(x, axis): + shape: ttgl.constexpr = x.value.shape + sp: ttgl.constexpr = _get_join_shape(shape, axis) + x0, x1 = unpack2(x) + return ttgl.join(x0, x1).permute(*sp[1]).reshape(*sp[0]) + + +@jit +def full_like(x, fill_value): + ttgl.static_assert(fill_value.dtype == ttgl.float32, "fill_value must be a float32") + fill = stdlib.full_like(x.value, fill_value, dtype=ttgl.float32) + return pack2(fill, fill) + + +@jit +def fma(a, b, c): + return Float2Tensor(_fma_f32x2(a.value, b.value, c.value)) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py new file mode 100644 index 0000000000..01adc77200 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -0,0 +1,74 @@ +import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon.language._core import builtin +from triton.experimental.gluon.language.nvidia.hopper.tma import ( + async_copy_global_to_shared, + async_copy_shared_to_global, + store_wait, + tensor_descriptor, + tensor_descriptor_type, + make_tensor_descriptor, + _emit_alignment_check, +) + +__all__ = [ + "async_gather", + "async_scatter", + "async_copy_global_to_shared", + "async_copy_shared_to_global", + "store_wait", + "tensor_descriptor", + "tensor_descriptor_type", + "make_tensor_descriptor", +] + + +@builtin +def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None): + """ + Asynchronously gather elements from global memory to shared memory using TMA. + + Args: + tensor_desc (tensor_descriptor): The tensor descriptor. + x_offsets (tensor): 1D tensor of X offsets. + y_offset (int): Scalar Y offset. + barrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete. + result (tensor_memory_descriptor): Result shared memory, must have NVMMASharedLayout. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + """ + if _semantic.builder.options.enable_iisan: + _emit_alignment_check(tensor_desc, (y_offset, ), "async_gather", "y_offset", _semantic=_semantic) + + pred = _semantic.to_tensor(pred) + y_offset = _semantic.to_tensor(y_offset) + _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle, + result.handle, pred.handle) + + +def _emit_scatter_nonnegative_check(x_offsets, y_offset, _semantic=None): + y_offset = ttgl.to_tensor(y_offset, _semantic=_semantic) + zero = ttgl.to_tensor(0, _semantic=_semantic) + + is_nonnegative = y_offset.__ge__(zero, _semantic=_semantic) + ttgl.device_assert(is_nonnegative, "async_scatter y_offset cannot be negative", _semantic=_semantic) + + is_nonnegative = x_offsets.__ge__(zero, _semantic=_semantic) + ttgl.device_assert(is_nonnegative, "async_scatter x_offsets cannot have any negative elements", _semantic=_semantic) + + +@builtin +def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None): + """ + Asynchronously scatter elements from shared memory to global memory using TMA. + + Args: + tensor_desc (tensor_descriptor): The tensor descriptor. + x_offsets (tensor): 1D tensor of X offsets. + y_offset (int): Scalar Y offset. + src (tensor_memory_descriptor): The source data, must be in NVMMASharedLayout. + """ + if _semantic.builder.options.enable_iisan: + _emit_alignment_check(tensor_desc, (y_offset, ), "async_scatter", "y_offset", _semantic=_semantic) + _emit_scatter_nonnegative_check(x_offsets, y_offset, _semantic=_semantic) + + y_offset = _semantic.to_tensor(y_offset) + _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py new file mode 100644 index 0000000000..8285b82fbe --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py @@ -0,0 +1,141 @@ +from __future__ import annotations +from triton.compiler.code_generator import unflatten_ir_values +from ..ampere import async_copy, mma_v2 +from . import cluster, mbarrier, tma +from ... import _core + +from typing import List, Tuple, TYPE_CHECKING +if TYPE_CHECKING: + from triton._C.libtriton import ir + +__all__ = [ + "async_copy", + "cluster", + "fence_async_shared", + "mbarrier", + "mma_v2", + "tma", + "warpgroup_mma", + "warpgroup_mma_wait", +] + + +@_core.builtin +def fence_async_shared(cluster=False, _semantic=None): + """ + Issue a fence to complete asynchronous shared memory operations. + + Args: + cluster (bool): Whether to fence across cluster. Defaults to False. + """ + cluster = _core._unwrap_if_constexpr(cluster) + _semantic.builder.create_fence_async_shared(cluster) + + +class warpgroup_mma_accumulator_type(_core.base_type): + tensor_type: _core.dtype + + def __init__(self, tensor_type: _core.dtype): + self.tensor_type = tensor_type + + def __str__(self) -> str: + return f"warpgroup_mma_accumulator<{self.tensor_type}>" + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[warpgroup_mma_accumulator, int]: + return warpgroup_mma_accumulator(handles[cursor], self.tensor_type), cursor + 1 + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + self.tensor_type._flatten_ir_types(builder, out) + + def __eq__(self, other) -> bool: + return type(self) is type(other) and self.tensor_type == other.tensor_type + + def mangle(self) -> str: + return f"FT{self.tensor_type.mangle()}FT" + + +class warpgroup_mma_accumulator(_core.base_value): + handle: ir.value + type: warpgroup_mma_accumulator_type + + def __init__(self, handle, tensor_type: _core.dtype): + self.handle = handle + self.type = warpgroup_mma_accumulator_type(tensor_type) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + +@_core.builtin +def warpgroup_mma_init(value, _semantic=None): + assert isinstance(value, _core.tensor) + return warpgroup_mma_accumulator(value.handle, value.type) + + +@_core.builtin +def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False, + _semantic=None): + """ + Perform warpgroup MMA (Tensor Core) operations. + acc = a * b + (acc if use_acc else 0) + + Args: + a (tensor or shared_memory_descriptor): Left hand side operand. + b (shared_memory_descriptor): Right hand side operand. + acc (tensor): Accumulator tensor. + use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True. + precision (str, optional): Dot input precision. Defaults to builder default. + max_num_imprecise_acc (int): Max imprecise accumulations. Used for fp8 -> fp32 dot. Determines how many accumulation are done in limited precision. Defaults to None, which means no upcasting is done. + is_async (bool): Whether operation is asynchronous. Defaults to False. + + Returns: + tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous. + """ + use_acc = _semantic.to_tensor(use_acc) + + if precision is None: + precision = _semantic.builder.options.default_dot_input_precision + + precision = _semantic._str_to_dot_input_precision(precision) + + K = a.type.shape[-1] + if max_num_imprecise_acc is None: + if a.dtype.is_fp8() and b.dtype.is_fp8(): + max_num_imprecise_acc = _semantic.builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + else: + if a.dtype.is_fp8() and b.dtype.is_fp8() and max_num_imprecise_acc > K: + raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") + + max_num_imprecise_acc = _core._unwrap_if_constexpr(max_num_imprecise_acc) + is_async = _core._unwrap_if_constexpr(is_async) + + handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision, + max_num_imprecise_acc, is_async) + tensor_ty = acc.type.tensor_type if isinstance(acc, warpgroup_mma_accumulator) else acc.type + if is_async: + return warpgroup_mma_accumulator(handle, tensor_ty) + else: + return _core.tensor(handle, tensor_ty) + + +@_core.builtin +def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None): + """ + Wait until `num_outstanding` or less warpgroup MMA operations are in-flight. + + Args: + num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0. + deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished. + """ + if deps is None: + raise ValueError("warpgroup_mma_wait deps must be given") + deps_handles = [x.handle for x in deps] if deps is not None else [] + num_outstanding = _core._unwrap_if_constexpr(num_outstanding) + results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding) + result_types = [dep.type.tensor_type if isinstance(dep, warpgroup_mma_accumulator) else dep.type for dep in deps] + results = unflatten_ir_values(results, result_types) + if len(deps) == 1: + return next(results) + return tuple(results) diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/cluster.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/cluster.py new file mode 100644 index 0000000000..13c9a96d43 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/cluster.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +__all__ = ["arrive", "wait"] + + +@builtin +def arrive(relaxed: bool = False, _semantic=None): + """ + Arrive at a barrier that synchronizes across the CTA cluster. + + Args: + relaxed (bool): Whether to use relaxed semantics. Defaults to False. + """ + relaxed = _unwrap_if_constexpr(relaxed) + _semantic.builder.create_cluster_arrive(relaxed) + + +@builtin +def wait(_semantic=None): + """ + Wait for all CTAs in the cluster to arrive at the cluster barrier. + """ + _semantic.builder.create_cluster_wait() diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py new file mode 100644 index 0000000000..e37b4c4ed7 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py @@ -0,0 +1,66 @@ +from ..ampere.mbarrier import MBarrierLayout, allocate_mbarrier, init, invalidate, wait +from triton.experimental.gluon._runtime import jit +from ..._core import _unwrap_if_constexpr, builtin +from . import cluster + +__all__ = [ + "allocate_mbarrier", + "arrive", + "expect", + "sync_cluster_init", + "fence_init_release_cluster", + "init", + "invalidate", + "MBarrierLayout", + "wait", +] + + +@builtin +def expect(mbarrier, bytes_per_cta=None, pred=True, _semantic=None): + """ + Expect a specific number of bytes being copied. When they are copied, the barrier is signaled. + + Args: + mbarrier (shared_memory_descriptor): Barrier that will be signaled when the operation is complete. + bytes_per_cta (int): Expected byte count per CTA. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + """ + pred = _semantic.to_tensor(pred) + bytes_per_cta = _unwrap_if_constexpr(bytes_per_cta) + _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes_per_cta, pred.handle) + + +@builtin +def arrive(mbarrier, *, count=1, pred=True, _semantic=None): + """ + Arrive at an mbarrier with a specified count. + + Args: + mbarrier (shared_memory_descriptor): Barrier to be signalled. + count (int): Count to arrive with. Defaults to 1. + pred (bool): Scalar predicate. Operation is skipped if predicate is False. Defaults to True. + """ + count = _unwrap_if_constexpr(count) + pred = _semantic.to_tensor(pred) + _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle) + + +@builtin +def fence_init_release_cluster(_semantic=None): + """ + Fence that makes prior mbarrier initialization visible across the CTA cluster. + + Needs to be called together with cluster.arrive(relaxed=True) and cluster.wait. + """ + _semantic.builder.create_fence_mbarrier_init_release_cluster() + + +@jit +def sync_cluster_init(): + """ + Ensure mbarrier initialization is visible across the CTA cluster. + """ + fence_init_release_cluster() + cluster.arrive(relaxed=True) + cluster.wait() diff --git a/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/tma.py new file mode 100644 index 0000000000..d21d9bf922 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -0,0 +1,218 @@ +from __future__ import annotations +from typing import List, Tuple, TYPE_CHECKING +from dataclasses import dataclass +from triton.language.core import base_type, base_value +import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon.language._layouts import NVMMASharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +if TYPE_CHECKING: + from triton._C import ir + +__all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"] + + +@dataclass(eq=True) +class tensor_descriptor_type(base_type): + block_type: ttgl.block_type + shape_type: ttgl.tuple_type + strides_type: ttgl.tuple_type + layout: NVMMASharedLayout + + def __str__(self) -> str: + return f"tensor_descriptor<{self.block_type}, {self.layout}>" + + @property + def nbytes_per_cta(self) -> int: + cga_layout = self.layout.cga_layout + if len(cga_layout) == 0: + return self.block_type.nbytes + num_cta_splits = 2**sum(any(x != 0 for x in basis) for basis in cga_layout) + return self.block_type.nbytes // num_cta_splits + + def _to_ir(self, builder: ir.builder) -> ir.type: + is_signed = self.block_type.element_ty.is_int_signed() + return builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + ty = builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + out.append(ty) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD" + + +class tensor_descriptor(base_value): + + def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type, + layout: NVMMASharedLayout): + self.handle = handle + self.shape = ttgl.tuple(shape) + self.strides = ttgl.tuple(strides) + self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type, + layout=layout) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) + + @property + def nbytes_per_cta(self): + return self.type.nbytes_per_cta + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + @property + def layout(self): + return self.type.layout + + +def _emit_alignment_check(desc, coord, fn_name: str, arg_name: str, _semantic=None): + coord = list(coord)[-1] + align_bytes = 16 + if desc.layout.fp4_padded: + align_bytes = 64 + dtype = desc.dtype + assert dtype.primitive_bitwidth % 8 == 0, f"unexpected sub-byte dtype {dtype}" + elem_bytes = dtype.primitive_bitwidth // 8 + align = align_bytes // elem_bytes + + align_val = ttgl.to_tensor(align, _semantic=_semantic) + zero = ttgl.to_tensor(0, _semantic=_semantic) + + coord = ttgl.to_tensor(coord, _semantic=_semantic) + rem = coord.__mod__(align_val, _semantic=_semantic) + is_zero = rem.__eq__(zero, _semantic=_semantic) + + fp4_padded = "with fp4_padded=True " if desc.layout.fp4_padded else "" + ttgl.device_assert( + is_zero, f"{fn_name} {fp4_padded}{arg_name} must be {align_bytes}-byte aligned, " + f"i.e. a multiple of {align} for dtype={dtype.codegen_name()}", _semantic=_semantic) + + +@builtin +def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, multicast=False, _semantic=None): + if _semantic.builder.options.enable_iisan: + _emit_alignment_check(tensor_desc, coord, "async_copy_global_to_shared", "innermost coordinate", + _semantic=_semantic) + + coord = _semantic._convert_to_ir_values(coord, require_i64=False) + pred = _semantic.to_tensor(pred) + multicast = _unwrap_if_constexpr(multicast) + _semantic.builder.create_async_tma_copy_global_to_local( + tensor_desc.handle, + coord, + barrier.handle, + result.handle, + pred.handle, + multicast, + ) + + +@builtin +def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None): + if _semantic.builder.options.enable_iisan: + _emit_alignment_check(tensor_desc, coord, "async_copy_shared_to_global", "innermost coordinate", + _semantic=_semantic) + coord = _semantic._convert_to_ir_values(coord, require_i64=False) + _semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle) + + +@builtin +def store_wait(pendings, _semantic=None): + pendings = _unwrap_if_constexpr(pendings) + _semantic.builder.create_async_tma_store_wait(pendings) + + +@builtin +def make_tensor_descriptor( + base: ttgl.tensor, + shape: List[ttgl.tensor], + strides: List[ttgl.tensor], + block_shape: List[ttgl.constexpr], + layout: NVMMASharedLayout, + padding_option="zero", + _semantic=None, +) -> tensor_descriptor: + padding_option = _unwrap_if_constexpr(padding_option) + block_shape = _unwrap_if_constexpr(block_shape) + + ndim = len(shape) + if not (1 <= ndim <= 5): + raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(block_shape)}") + assert isinstance(base.dtype, ttgl.pointer_type) + elem_size = base.dtype.element_ty.primitive_bitwidth // 8 + contig_dim_size = ttgl._unwrap_if_constexpr(block_shape[-1]) + if contig_dim_size * elem_size < 16: + raise ValueError( + f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" + ) + + last_stride = ttgl._unwrap_if_constexpr(strides[-1]) + if last_stride != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") + + shape = [_semantic.make_scalar(x, ttgl.int32) for x in shape] + strides = [_semantic.make_scalar(ttgl._unwrap_if_constexpr(x), ttgl.int64) for x in strides] + + # Check whether `block_shape` is static + block_shape = ttgl._unwrap_shape(block_shape) + + assert isinstance(base.type, ttgl.pointer_type) + block_type = ttgl.block_type(base.type.element_ty, block_shape) + base_handle = base.handle + + padding = _semantic._str_to_padding_option(padding_option) + + layout = _unwrap_if_constexpr(layout) + assert isinstance(layout, NVMMASharedLayout), \ + "Expected layout to be a NVMMASharedLayout" + + shape_type = ttgl.tuple(shape).type + strides_type = ttgl.tuple(strides).type + ty = tensor_descriptor_type(block_type, shape_type, strides_type, layout) + + if base.type.element_ty.is_int() and padding == ttgl.ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer blocks") + handle = _semantic.builder.create_make_tensor_descriptor( + ty._to_ir(_semantic.builder), + base_handle, + [s.handle for s in shape], + [s.handle for s in strides], + padding, + ) + return tensor_descriptor(handle, shape, strides, block_type, layout) diff --git a/third_party/mthreads/python/triton/experimental/gluon/nvidia/__init__.py b/third_party/mthreads/python/triton/experimental/gluon/nvidia/__init__.py new file mode 100644 index 0000000000..8184c7388e --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/nvidia/__init__.py @@ -0,0 +1,4 @@ +from . import hopper +from . import blackwell + +__all__ = ["hopper", "blackwell"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/nvidia/blackwell.py b/third_party/mthreads/python/triton/experimental/gluon/nvidia/blackwell.py new file mode 100644 index 0000000000..abf9198051 --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/nvidia/blackwell.py @@ -0,0 +1,3 @@ +from .hopper import TensorDescriptor + +__all__ = ["TensorDescriptor"] diff --git a/third_party/mthreads/python/triton/experimental/gluon/nvidia/hopper.py b/third_party/mthreads/python/triton/experimental/gluon/nvidia/hopper.py new file mode 100644 index 0000000000..28863722ba --- /dev/null +++ b/third_party/mthreads/python/triton/experimental/gluon/nvidia/hopper.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass +from typing import List, Any +from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth +from triton.experimental.gluon.language._layouts import NVMMASharedLayout +import triton.language as tl + +__all__ = ["TensorDescriptor"] + + +@dataclass +class TensorDescriptor: + base: Any + shape: List[int] + strides: List[int] + block_shape: List[int] + layout: NVMMASharedLayout + padding: str = "zero" + + def __post_init__(self): + rank = len(self.shape) + assert len(self.strides) == rank, f"rank mismatch: {self}" + assert len(self.block_shape) == rank, f"rank mismatch: {self}" + assert rank > 0, "rank must not be zero" + assert rank <= 5, "rank cannot be more than 5" + assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned" + validate_block_shape(self.block_shape) + dtype_str = canonicalize_dtype(self.base.dtype) + elem_bytes = get_primitive_bitwidth(dtype_str) // 8 + for stride in self.strides[:-1]: + assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned" + for shape_dim in self.shape: + assert shape_dim > 0, "shape must be positive" + assert self.strides[-1] == 1, "Last dimension must be contiguous" + assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout" + assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding" + if self.padding == "nan": + assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors" + assert elem_bytes * 8 == self.layout.element_bitwidth + padding_factor = 2 if self.layout.fp4_padded else 1 + min_block = self.layout.swizzle_byte_width // (elem_bytes * padding_factor) + assert self.block_shape[-1] >= min_block, \ + f"Expected block_shape[-1] to be at least {min_block} but got {self.block_shape[-1]}" + if self.layout.fp4_padded: + assert self.base.data_ptr() % 32 == 0, "For fp4_padded, base must 32-byte aligned" + for stride in self.strides[:-1]: + assert (stride * elem_bytes) % 32 == 0, "For fp4_padded, tensor strides must be 32-byte aligned" + assert tl.target_info.cuda_capability_geq(10, 0), "fp4_padded requires blackwell or newer" + assert not self.layout.fp4_padded or self.layout.swizzle_byte_width == 128, f"FP4 padded operands must be swizzled with 128-byte width, but got {self.layout.swizzle_byte_width}" + assert self.layout.element_bitwidth in [ + 8, 16, 32 + ], f"tensor descriptor dtype must be 8, 16, or 32 bits, but got {self.layout.element_bitwidth}" + + @staticmethod + def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout, padding="zero"): + return TensorDescriptor( + tensor, + tensor.shape, + tensor.stride(), + block_shape, + layout, + padding, + ) diff --git a/third_party/mthreads/python/triton/knobs.py b/third_party/mthreads/python/triton/knobs.py new file mode 100644 index 0000000000..35acb984e9 --- /dev/null +++ b/third_party/mthreads/python/triton/knobs.py @@ -0,0 +1,652 @@ +from __future__ import annotations + +import functools +import importlib +import os +import re +import shutil +import subprocess +import sysconfig +import pathlib + +from dataclasses import dataclass +from contextlib import contextmanager +from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union + +from triton._C.libtriton import getenv, getenv_bool # type: ignore + +if TYPE_CHECKING: + from .runtime.cache import CacheManager, RemoteCacheBackend + from .runtime.jit import JitFunctionInfo, KernelParam + from .compiler.compiler import ASTSource, LazyDict, IRSource + + +class Env: + pass + + +env = Env() + +propagate_env: bool = True + + +def setenv(key: str, value: Optional[str]) -> None: + if not propagate_env: + return + + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + + +def toenv(val: Any) -> Union[None, tuple[Optional[str]]]: + if val is None: + return (None, ) + + t = type(val) + if t is bool: + return ("1" if val else "0", ) + + if t is str: + return (val, ) + + if t is int: + return (str(val), ) + + return None + + +# There's an asymmetry here so that e.g. env_nvidia_tool can be specified with a +# a string but return an NvidiaTool. +SetType = TypeVar("SetType") +GetType = TypeVar("GetType") + +_NOTHING = object() + + +class env_base(Generic[SetType, GetType]): + + def __init__(self, key: str) -> None: + self.key = key + + def __set_name__(self, objclass: Type[object], name: str) -> None: + self.name = name + + def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType: + py_val = obj.__dict__.get(self.name, _NOTHING) + if py_val is _NOTHING: + return self.get() + return self.transform(py_val) + + def get(self) -> GetType: + raise NotImplementedError() + + def __set__(self, obj: object, value: Union[SetType, Env]) -> None: + if isinstance(value, Env): + obj.__dict__.pop(self.name, None) + else: + obj.__dict__[self.name] = value + if env_val := toenv(value): + setenv(self.key, env_val[0]) + + def __delete__(self, obj: object) -> None: + obj.__dict__.pop(self.name, None) + + def transform(self, val: SetType) -> GetType: + # See comment about GetType/SetType in their definition above. Only needed + # if GetType != SetType. + return cast(GetType, val) + + +class env_str(env_base[str, str]): + + def __init__(self, key: str, default: str): + super().__init__(key) + self.default = default + + def get(self) -> str: + return getenv(self.key, self.default) + + +class env_str_callable_default(env_base[str, str]): + + def __init__(self, key: str, default_factory: Callable[[], str]): + super().__init__(key) + self.default_factory = default_factory + + def get(self) -> str: + env_val = getenv(self.key) + if env_val is None: + return self.default_factory() + return env_val + + +class env_bool(env_base[bool, bool]): + + def __init__(self, key: str, default: bool = False) -> None: + super().__init__(key) + self.default = default + + def get(self) -> bool: + return getenv_bool(self.key, self.default) + + +class env_int(env_base[int, int]): + + def __init__(self, key: str, default: int = 0) -> None: + super().__init__(key) + self.default = default + + def get(self) -> int: + val = getenv(self.key) + if val is None: + return self.default + try: + return int(val) + except ValueError as exc: + raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc + + +ClassType = TypeVar("ClassType") + + +class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]): + + def __init__(self, key: str, type: str) -> None: + super().__init__(key) + # We can't pass the type directly to avoid import cycles + self.type = type + + def get(self) -> Optional[Type[ClassType]]: + val = getenv(self.key) + if val is None: + return None + comps = val.split(":", 1) + if len(comps) != 2: + raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS") + cls = getattr(importlib.import_module(comps[0]), comps[1]) + + if not any((c.__name__ == self.type for c in cls.mro())): + raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'") + + return cast(Type[ClassType], cls) + + +@dataclass +class NvidiaTool: + path: str + version: str + + @staticmethod + @functools.lru_cache + def from_path(path: str) -> Optional[NvidiaTool]: + try: + result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT) + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is None: + return None + return NvidiaTool(path, version.group(1)) + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + +class env_nvidia_tool(env_base[str, NvidiaTool]): + + def __init__(self, binary: str) -> None: + binary += sysconfig.get_config_var("EXE") + self.binary = binary + self.default_path = os.path.join(os.path.dirname(__file__), "backends", "nvidia", "bin", binary) + # Convert ptxas-blackwell to PTXAS_BLACKWELL, not PTXAS-BLACKWELL + super().__init__(f"TRITON_{binary.upper().replace('-', '_')}_PATH") + + def get(self) -> NvidiaTool: + return self.transform(getenv(self.key)) + + def transform(self, path: str) -> NvidiaTool: + # We still add default as fallback in case the pointed binary isn't + # accessible. + if path is not None: + paths = [path, self.default_path] + else: + paths = [self.default_path] + + for path in paths: + if tool := NvidiaTool.from_path(path): + return tool + + raise RuntimeError(f"Cannot find {self.binary}") + + +@dataclass +class MUSATool: + path: str + version: str + + @staticmethod + @functools.lru_cache + def from_path(path: str) -> Optional["MUSATool"]: + if not path: + return None + resolved = pathlib.Path(path).expanduser() + if not resolved.is_file(): + which = shutil.which(str(resolved)) + if which is None: + return None + resolved = pathlib.Path(which) + try: + result = subprocess.check_output([str(resolved), "--version"], stderr=subprocess.STDOUT) + except (subprocess.CalledProcessError, FileNotFoundError, PermissionError, OSError): + return None + version_lines = result.decode("utf-8", errors="replace").splitlines() + version = next((line.strip() for line in version_lines if line.strip()), "") + return MUSATool(str(resolved), version) + + +class env_musa_tool(env_base[str, MUSATool]): + + def __init__(self, key: str, binary: str) -> None: + self.binary = binary + sysconfig.get_config_var("EXE") + super().__init__(key) + + def _candidate_paths(self, path: Optional[str]) -> list[str]: + candidates = [] + if path: + candidates.append(path) + + toolchain_path = getenv("TRITON_MUSA_TOOLCHAIN_PATH") + if toolchain_path: + candidates.append(os.path.join(toolchain_path, self.binary)) + + mtcc_bin_path = getenv("MTCC_BIN_PATH") + if mtcc_bin_path: + candidates.append(os.path.join(mtcc_bin_path, self.binary)) + + musa_home = getenv("MUSA_HOME") or getenv("MUSA_ROOT") + if musa_home: + candidates.append(os.path.join(musa_home, "bin", self.binary)) + + if which := shutil.which(self.binary): + candidates.append(which) + + return candidates + + def get(self) -> MUSATool: + return self.transform(getenv(self.key)) + + def transform(self, path: Optional[str]) -> MUSATool: + for candidate in self._candidate_paths(path): + if tool := MUSATool.from_path(candidate): + return tool + raise RuntimeError(f"Cannot find {self.binary}") + + +# Separate classes so that types are correct +class env_opt_str(env_base[Optional[str], Optional[str]]): + + def get(self) -> Optional[str]: + return getenv(self.key) + + +class env_opt_bool(env_base): + + def get(self) -> Optional[str]: + return getenv_bool(self.key, None) + + +@dataclass(frozen=True) +class CompileTimes: + """ + Model holding timing information for an invocation of the compiler. + + All times in microseconds. + """ + + # Duration of make_ir + ir_initialization: int + + # Ordered mapping from lowering stage to duration spent in that stage. + # Keyed by stage extension, e.g. ttir, ttgir + lowering_stages: list[tuple[str, int]] + + # Duration of saving artifacts/metadata to cache + store_results: int + + @property + def total_lowering(self) -> int: + return sum((stage[1] for stage in self.lowering_stages)) + + @property + def total(self) -> int: + return self.ir_initialization + self.total_lowering + self.store_results + + +class CompilationListener(Protocol): + + def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str], + times: CompileTimes, cache_hit: bool) -> None: + ... + + +knobs_type = TypeVar("knobs_type", bound='base_knobs') + + +class base_knobs: + + @property + def knob_descriptors(self) -> dict[str, env_base]: + return { + k: v + # data descriptors live on the class object + for k, v in type(self).__dict__.items() + if isinstance(v, env_base) + } + + @property + def knobs(self) -> dict[str, Any]: + return {k: getattr(self, k) for k in self.knob_descriptors.keys()} + + def copy(self: knobs_type) -> knobs_type: + res = type(self)() + res.__dict__.update(self.__dict__) + return res + + def reset(self: knobs_type) -> knobs_type: + for knob in self.knob_descriptors.keys(): + delattr(self, knob) + return self + + @contextmanager + def scope(self) -> Generator[None, None, None]: + try: + initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()} + orig = dict(self.__dict__) + yield + finally: + self.__dict__.clear() + self.__dict__.update(orig) + + for k, v in initial_env.items(): + if v is not None: + os.environ[k] = v + elif k in os.environ: + del os.environ[k] + + +class BuildImpl(Protocol): + + def __call__(self, name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], + libraries: list[str], /) -> str: + ... + + +class build_knobs(base_knobs): + """Configuration controlling how the native compiler is invoked""" + cc: env_opt_str = env_opt_str("CC") + + cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH") + cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH") + + impl: Optional[BuildImpl] = None + + @property + def backend_dirs(self) -> set[str]: + return {path for path in (self.cudacrt_path, self.cudart_path) if path is not None} + + +class redis_knobs(base_knobs): + key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + host: env_str = env_str("TRITON_REDIS_HOST", "localhost") + port: env_int = env_int("TRITON_REDIS_PORT", 6379) + + +cache: cache_knobs + + +class cache_knobs(base_knobs): + home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/")) + + dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump")) + override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override")) + dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache")) + + manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager") + remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend") + + def get_triton_dir(self, dirname: str) -> str: + return os.path.join(self.home_dir, ".triton", dirname) + + +class compilation_knobs(base_knobs): + override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE") + dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP") + dump_ir_extract_di_local_variables: env_bool = env_bool("LLVM_EXTRACT_DI_LOCAL_VARIABLES") + store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY") + always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE") + # TODO: Use enum to constrain / 'typecheck' the values + use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC") + enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN") + disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO") + front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING") + allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS") + # Instrumentation mode is checked on every run, which is expensive. + # We cache the value here to avoid the expensive check on every run. + instrumentation_mode: str = env_str("TRITON_INSTRUMENTATION_MODE", "").get() + listener: Union[CompilationListener, None] = None + + +class autotuning_knobs(base_knobs): + cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING") + print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING") + + +class LaunchHook(Protocol): + """Hook invoked before and after kernel launching + """ + + def __call__(self, metadata: LazyDict) -> None: + ... + + +class InitHandleHook(Protocol): + """Hook invoked around kernel binary/module loading. + module/function can be None for the *start* hook (before loading). + """ + + def __call__( + self, + module: Optional[object], + function: Optional[Callable], + name: str, + metadata_group: dict[str, str], + hash: str, + ) -> None: + ... + + +F = TypeVar("F", bound=Callable) + + +class HookChain(Generic[F]): + """A chain of hooks of the same type F to be called in order. + """ + + def __init__(self, reversed: bool = False): + self.calls: list[F] = [] + self.reversed = reversed + + def add(self, func: F) -> None: + if func not in self.calls: + self.calls.append(func) + + def remove(self, func: F) -> None: + if func in self.calls: + self.calls.remove(func) + + def __call__(self, *args, **kwargs): + for call in self.calls if not self.reversed else reversed(self.calls): + call(*args, **kwargs) + + +# This is of the form [attr_name, attr_val] +# TODO: Use tuple instead of list for better typing. +KernelAttr = list[Union[str, int]] + + +class JITHookCompileInfo(TypedDict): + key: str + signature: dict[KernelParam, str] + device: int + constants: None + num_warps: int + num_ctas: int + num_stages: int + enable_fp_fusion: bool + launch_cooperative_grid: bool + extern_libs: tuple[tuple[str, str], ...] + configs: list[dict[tuple[int, ...], list[KernelAttr]]] + specialization_data: str + is_warmup: bool + + +class JITHook(Protocol): + + def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHookCompileInfo, is_manual_warmup: bool, + already_compiled: bool) -> Optional[bool]: + ... + + +class PipelineStagesHook(Protocol): + + def __call__(self, stages, options, language, capability): + ... + + +class runtime_knobs(base_knobs): + interpret: env_bool = env_bool("TRITON_INTERPRET") + # debug is on critical path for kernel launches + # avoid repeated reads from env-var by calling get directly + debug: bool = env_bool("TRITON_DEBUG").get() + override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH") + + launch_enter_hook: HookChain[LaunchHook] = HookChain() + launch_exit_hook: HookChain[LaunchHook] = HookChain(reversed=True) + kernel_load_start_hook: HookChain[InitHandleHook] = HookChain() + kernel_load_end_hook: HookChain[InitHandleHook] = HookChain(reversed=True) + + # Hook for inspecting compiled functions and modules + jit_cache_hook: Optional[JITHook] = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # jit_cache_hook will always be called before compilation and jit_post_compile_hook after. + jit_post_compile_hook: Optional[JITHook] = None + + # Hook for inspecting compiler pipeline stages + add_stages_inspection_hook: Optional[PipelineStagesHook] = None + + +class language_knobs(base_knobs): + fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT") + default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True) + + +class nvidia_knobs(base_knobs): + cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump") + nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm") + ptxas: env_nvidia_tool = env_nvidia_tool("ptxas") + ptxas_blackwell: env_nvidia_tool = env_nvidia_tool("ptxas-blackwell") + + dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP") + disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT") + ptxas_options: env_opt_str = env_opt_str("PTXAS_OPTIONS") + mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION") + dump_ptxas_log: env_bool = env_bool("TRITON_DUMP_PTXAS_LOG") + + libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH") + libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH") + + +class amd_knobs(base_knobs): + use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS", True) + # Note: This requires use_buffer_ops be true to have any effect + use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True) + # Note: This requires use_buffer_ops be true to have any effect + buffer_ops_analyze_small_tensor_range: env_bool = env_bool("AMDGCN_ANALYZE_SMALL_TENSOR_RANGE", False) + dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP") + libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH") + + # We use strs so that we can have a default value based on other runtime info + use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG") + use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE") + use_async_copy: env_opt_bool = env_opt_bool("TRITON_HIP_USE_ASYNC_COPY") + + scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS") + + # Path to dump MIR files for debugging/analysis + dump_mir: env_opt_str = env_opt_str("TRITON_DUMP_MIR") + # Path to externally-provided MIR files to use instead of generated ones + swap_mir: env_opt_str = env_opt_str("TRITON_SWAP_MIR") + + +class musa_knobs(base_knobs): + toolchain_path: env_opt_str = env_opt_str("TRITON_MUSA_TOOLCHAIN_PATH") + llc_path: env_opt_str = env_opt_str("TRITON_MUSA_LLC_PATH") + lld_path: env_opt_str = env_opt_str("TRITON_MUSA_LLD_PATH") + llc_asm_path: env_opt_str = env_opt_str("TRITON_MUSA_LLC_ASM_PATH") + llc: env_musa_tool = env_musa_tool("TRITON_MUSA_LLC_PATH", "llc") + lld: env_musa_tool = env_musa_tool("TRITON_MUSA_LLD_PATH", "ld.lld") + llc_asm: env_musa_tool = env_musa_tool("TRITON_MUSA_LLC_ASM_PATH", "llc") + llc_options: env_opt_str = env_opt_str("TRITON_MUSA_LLC_OPTIONS") + enable_llc_opt: env_bool = env_bool("TRITON_MUSA_ENABLE_LLC_OPT") + enable_fp8_burst2: env_bool = env_bool("TRITON_MUSA_ENABLE_FP8_BURST2") + enable_llvm_compat: env_bool = env_bool("TRITON_MUSA_ENABLE_LLVM_COMPAT", True) + dump_llir: env_bool = env_bool("TRITON_MUSA_DUMP_LLIR") + dump_muasm: env_bool = env_bool("TRITON_MUSA_DUMP_MUASM") + dump_toolchain_log: env_bool = env_bool("TRITON_MUSA_DUMP_TOOLCHAIN_LOG") + replace_llir: env_opt_str = env_opt_str("TRITON_MUSA_REPLACE_LLIR") + replace_mubin: env_opt_str = env_opt_str("TRITON_MUSA_REPLACE_MUBIN") + libdevice_path: env_opt_str = env_opt_str("TRITON_MUSA_LIBDEVICE_PATH") + + +class proton_knobs(base_knobs): + disable: env_bool = env_bool("TRITON_PROTON_DISABLE", False) + cupti_lib_dir: env_str = env_str( + "TRITON_CUPTI_LIB_PATH", + str(pathlib.Path(__file__).parent.absolute() / "backends" / "nvidia" / "lib" / "cupti")) + profile_buffer_size: env_int = env_int("TRITON_PROFILE_BUFFER_SIZE", 64 * 1024 * 1024) + enable_nvtx: env_bool = env_bool("TRITON_ENABLE_NVTX", True) + # This knob is effective only on Blackwell+ GPUs. + # + # When enabled, the profiling session must start after CUDA driver + # initialization but before the CUDA context is created. + # + # You can ensure this in one of the following ways: + # + # 1) Use the `proton` CLI tool to launch the Python script, e.g.: + # `TRITON_ENABLE_HW_TRACE=1 proton python my_script.py` + # + # 2) Call `proton.start()` immediately after importing Proton, e.g.: + # ```python + # import triton + # import triton.profiler as proton + # triton.knobs.proton.enable_hw_trace = True + # proton.start(hook="triton") + # ``` + enable_hw_trace: env_bool = env_bool("TRITON_ENABLE_HW_TRACE", False) + + +build = build_knobs() +redis = redis_knobs() +cache = cache_knobs() +compilation = compilation_knobs() +autotuning = autotuning_knobs() +runtime = runtime_knobs() +language = language_knobs() +nvidia = nvidia_knobs() +amd = amd_knobs() +musa = musa_knobs() +proton = proton_knobs() + + +def refresh_knobs(): + runtime.debug = env_bool("TRITON_DEBUG").get() + compilation.instrumentation_mode = env_str("TRITON_INSTRUMENTATION_MODE", "").get() diff --git a/third_party/mthreads/python/triton/language/__init__.py b/third_party/mthreads/python/triton/language/__init__.py new file mode 100644 index 0000000000..ee38c67d79 --- /dev/null +++ b/third_party/mthreads/python/triton/language/__init__.py @@ -0,0 +1,360 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + bitonic_merge, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + reduce_or, + sigmoid, + softmax, + sort, + squeeze, + sum, + swizzle2d, + topk, + unsqueeze, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + load_tensor_descriptor, + store_tensor_descriptor, + _experimental_descriptor_load, + _experimental_descriptor_store, + make_tensor_descriptor, + tensor_descriptor, + tensor_descriptor_type, + add, + advance, + arange, + associative_scan, + assume, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + condition, + const, + constexpr, + constexpr_type, + debug_barrier, + device_assert, + device_print, + dot, + dot_scaled, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + gather, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + map_elementwise, + max_constancy, + max_contiguous, + maximum, + minimum, + mul, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + slice, + split, + static_assert, + static_print, + static_range, + store, + sub, + tensor, + to_tensor, + trans, + tuple, + tuple_type, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) +from . import target_info + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "load_tensor_descriptor", + "store_tensor_descriptor", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "make_tensor_descriptor", + "tensor_descriptor", + "abs", + "add", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "assume", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "bitonic_merge", + "block_type", + "broadcast", + "broadcast_to", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "condition", + "const", + "constexpr", + "constexpr_type", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dot_scaled", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "gather", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "map_elementwise", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "mul", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reduce_or", + "reshape", + "rsqrt", + "slice", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "squeeze", + "static_assert", + "static_print", + "static_range", + "store", + "sub", + "sum", + "swizzle2d", + "target_info", + "tensor", + "topk", + "to_tensor", + "trans", + "tuple", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "unsqueeze", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name, c): + from builtins import tuple + + if isinstance(name, tuple): + fields = type(name).__dict__.get("_fields", None) + return tuple_type([str_to_ty(x, c) for x in name], fields) + + if name[0] == "*": + name = name[1:] + const = False + if name[0] == "k": + name = name[1:] + const = True + ty = str_to_ty(name, c) + return pointer_type(element_ty=ty, const=const) + + if name.startswith("tensordesc"): + inner = name.split("<")[1].rstrip(">") + dtype, rest = inner.split("[", maxsplit=1) + block_shape, rest = rest.split("]", maxsplit=1) + block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")] + layout = rest.lstrip(",") + is_gluon = len(layout) + dtype = str_to_ty(dtype, None) + ndim = len(block_shape) + shape_type = tuple_type([int32] * ndim) + # FIXME: Last dim stride should be constexpr(1) + stride_type = tuple_type(([int64] * ndim)) + block = block_type(dtype, block_shape) + if is_gluon: + from triton.experimental.gluon.language._layouts import NVMMASharedLayout, PaddedSharedLayout, SwizzledSharedLayout + from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as nvidia_tensor_descriptor_type + from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor_type as amd_tensor_descriptor_type + layout = eval( + layout, + dict(NVMMASharedLayout=NVMMASharedLayout, PaddedSharedLayout=PaddedSharedLayout, + SwizzledSharedLayout=SwizzledSharedLayout)) + if isinstance(layout, NVMMASharedLayout): + return nvidia_tensor_descriptor_type(block, shape_type, stride_type, layout) + else: + return amd_tensor_descriptor_type(block, shape_type, stride_type, layout) + return tensor_descriptor_type(block, shape_type, stride_type) + + if name.startswith("constexpr"): + return constexpr_type(c) + + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/mthreads/python/triton/language/core.py b/third_party/mthreads/python/triton/language/core.py new file mode 100644 index 0000000000..bc8e8cac7d --- /dev/null +++ b/third_party/mthreads/python/triton/language/core.py @@ -0,0 +1,3561 @@ +from __future__ import annotations + +import math +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps, cached_property +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple +from dataclasses import dataclass +import builtins +from .. import knobs +from ..runtime.jit import JITCallable +import inspect + +from .._C.libtriton import ir +from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth, _tuple_create + +T = TypeVar('T') + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def must_use_result(x, s=True): + """If the result of this function is unused, throw an error.""" + if isinstance(x, str): + return (lambda fn: must_use_result(fn, x)) + x._must_use_result = s + return x + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_semantic" not in kwargs or kwargs["_semantic"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_semantic` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + wrapper.signature = inspect.signature(fn) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _semantic, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.signature = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _semantic=None): + return _semantic.to_tensor(x) + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class base_value: + """Base class of values that exist in the triton IR (i.e. not constexprs). + """ + type: base_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + """Flatten frontend value into a sequence of mlir handles, which are appended + to the output list + """ + raise NotImplementedError + + +class base_type: + + def __eq__(self, other) -> bool: + raise NotImplementedError("Types must implement __eq__") + + def __ne__(self, other) -> bool: + return not (self == other) + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + """Build a frontend value with the current dtype, wrapping a list of existing handles. + cursor is the index of the first handle relevant to this value, and the function + should return the updated cursor position after any handles consumed by the created value. + """ + raise NotImplementedError + + def mangle(self) -> str: + raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}") + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + raise NotImplementedError + + +class constexpr_type(base_type): + + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, constexpr_type) and self.value == other.value + + def __repr__(self) -> str: + return f"constexpr_type[{self.value}]" + + def __hash__(self): + return hash(self.value) + + def mangle(self) -> str: + if hasattr(self.value, "mangle"): + val = self.value.mangle() + else: + val = repr(self.value) + return f"c{val}" + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + return + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return constexpr(self.value), cursor + + +class constexpr(base_value): + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + while isinstance(value, constexpr): + value = value.value + self.value = value + self.type = constexpr_type(value) + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __hash__(self): + return hash((self.value, self.type)) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + return + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _unwrap_if_constexpr + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _unwrap_if_constexpr(other)) + + def __radd__(self, other): + return constexpr(_unwrap_if_constexpr(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _unwrap_if_constexpr(other)) + + def __rsub__(self, other): + return constexpr(_unwrap_if_constexpr(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _unwrap_if_constexpr(other)) + + def __mod__(self, other): + return constexpr(self.value % _unwrap_if_constexpr(other)) + + def __rmul__(self, other): + return constexpr(_unwrap_if_constexpr(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _unwrap_if_constexpr(other)) + + def __rtruediv__(self, other): + return constexpr(_unwrap_if_constexpr(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _unwrap_if_constexpr(other)) + + def __rfloordiv__(self, other): + return constexpr(_unwrap_if_constexpr(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _unwrap_if_constexpr(other)) + + def __rgt__(self, other): + return constexpr(_unwrap_if_constexpr(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _unwrap_if_constexpr(other)) + + def __rge__(self, other): + return constexpr(_unwrap_if_constexpr(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _unwrap_if_constexpr(other)) + + def __rlt__(self, other): + return constexpr(_unwrap_if_constexpr(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _unwrap_if_constexpr(other)) + + def __rle__(self, other): + return constexpr(_unwrap_if_constexpr(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _unwrap_if_constexpr(other)) + + def __ne__(self, other): + return constexpr(self.value != _unwrap_if_constexpr(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _unwrap_if_constexpr(other)) + + def logical_and(self, other): + return constexpr(self.value and _unwrap_if_constexpr(other)) + + def __or__(self, other): + return constexpr(self.value | _unwrap_if_constexpr(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _unwrap_if_constexpr(other)) + + def logical_or(self, other): + return constexpr(self.value or _unwrap_if_constexpr(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_unwrap_if_constexpr(other)) + + def __rpow__(self, other): + return constexpr(_unwrap_if_constexpr(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _unwrap_if_constexpr(other)) + + def __lshift__(self, other): + return constexpr(self.value << _unwrap_if_constexpr(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + def __getitem__(self, *args): + args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args)) + return self.value.__getitem__(*args) + + +CONSTEXPR_0 = constexpr(0) + + +def _unwrap_if_constexpr(o): + if isinstance(o, list): + return [_unwrap_if_constexpr(x) for x in o] + if isinstance(o, builtins.tuple): + return _tuple_create(o, [_unwrap_if_constexpr(x) for x in o]) + if isinstance(o, tuple): + return tuple([_unwrap_if_constexpr(x) for x in o], o.type) + return o.value if isinstance(o, constexpr) else o + + +def _normalize_tuple(t): + normalized_tuple = _unwrap_if_constexpr(t) + if isinstance(normalized_tuple, (list, builtins.tuple)): + normalized_tuple = tuple(normalized_tuple) + return normalized_tuple + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +# ----------------------- +# dtype +# ----------------------- + + +class dtype(base_type): + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + self.primitive_bitwidth = get_primitive_bitwidth(name) + self.itemsize = self.primitive_bitwidth // 8 + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other) -> bool: + other = _unwrap_if_constexpr(other) + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + supported_fp8_dtypes = getattr( + getattr(builder, "options", None), + "supported_fp8_storage_dtypes", + getattr(getattr(builder, "options", None), "supported_fp8_dtypes", ()), + ) + if supported_fp8_dtypes and self.name not in supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {supported_fp8_dtypes}') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return tensor(handles[cursor], self), cursor + 1 + + def mangle(self) -> str: + if self.is_int(): + SIGNED = dtype.SIGNEDNESS.SIGNED + prefix = 'i' if self.int_signedness == SIGNED else 'u' + return prefix + str(self.int_bitwidth) + if self.is_floating(): + return str(self) + if self.is_void(): + return 'V' + return super().mangle() + + def with_element_ty(self, element_ty: dtype): + assert not self.is_block() + return element_ty + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): + element_ty = _unwrap_if_constexpr(element_ty) + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') + self.element_ty = element_ty + self.address_space = address_space + self.const = const + self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def is_const(self): + return self.const + + def __eq__(self, other) -> bool: + other = _unwrap_if_constexpr(other) + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const + + @property + def scalar(self): + return self + + def mangle(self) -> str: + return f"P{self.element_ty.mangle()}" + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert (isinstance(shape, (list, tuple))) + + # shape can be empty ([]) when an input is a 0D tensor. + self.shape = tuple(_unwrap_shape(shape)) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> Tuple[int]: + return self.shape + + def with_element_ty(self, scalar_ty: dtype) -> block_type: + return block_type(scalar_ty, self.shape) + + def __eq__(self, other) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + @property + def scalar(self): + return self.element_ty + + @property + def nbytes(self): + return self.numel * (self.element_ty.primitive_bitwidth // 8) + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = '_'.join(map(str, self.shape)) + return f'{elt}S{shape}S' + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields + + @cached_property + def name(self): + if self.fields is None: + return '[' + ','.join(str(v) for v in self.types) + ']' + return '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle() for ty in self.types) + 'T' + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# tensor +# ----------------------- + + +class tensor(base_value): + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + super().__init__() + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = constexpr(math.prod(self.shape)) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = tuple([constexpr(s) for s in self.shape]) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _semantic=None): + return add(self, other, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __radd__(self, other, _semantic=None): + return add(other, self, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __sub__(self, other, _semantic=None): + return sub(self, other, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __rsub__(self, other, _semantic=None): + return sub(other, self, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __mul__(self, other, _semantic=None): + return mul(self, other, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __rmul__(self, other, _semantic=None): + return mul(other, self, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __truediv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.truediv(self, other) + + @builtin + def __rtruediv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.truediv(other, self) + + @builtin + def __floordiv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.floordiv(self, other) + + @builtin + def __rfloordiv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.floordiv(other, self) + + @builtin + def __mod__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.mod(self, other) + + @builtin + def __rmod__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.mod(other, self) + + # unary operators + @builtin + def __neg__(self, _semantic=None): + return _semantic.minus(self) + + @builtin + def __invert__(self, _semantic=None): + return _semantic.invert(self) + + # bitwise operators + + @builtin + def __and__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.and_(self, other) + + @builtin + def __rand__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.and_(other, self) + + @builtin + def __or__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.or_(self, other) + + @builtin + def __ror__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.or_(other, self) + + @builtin + def __xor__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.xor_(self, other) + + @builtin + def __rxor__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.xor_(other, self) + + @builtin + def __lshift__(self, other, _semantic=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + return _semantic.shl(self, other) + + @builtin + def __rlshift__(self, other, _semantic=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + return _semantic.shl(other, self) + + @builtin + def __rshift__(self, other, _semantic=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return _semantic.ashr(self, other) + else: + return _semantic.lshr(self, other) + + @builtin + def __rrshift__(self, other, _semantic=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return _semantic.ashr(other, self) + else: + return _semantic.lshr(other, self) + + # > + @builtin + def __gt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_than(self, other) + + @builtin + def __rgt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_than(other, self) + + # >= + @builtin + def __ge__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_equal(self, other) + + @builtin + def __rge__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_equal(other, self) + + # < + @builtin + def __lt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_than(self, other) + + @builtin + def __rlt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_than(other, self) + + # <= + @builtin + def __le__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_equal(self, other) + + @builtin + def __rle__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_equal(other, self) + + # == + @builtin + def __eq__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.equal(self, other) + + @builtin + def __req__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.equal(other, self) + + @builtin + def __ne__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.not_equal(self, other) + + @builtin + def __rne__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.not_equal(other, self) + + @builtin + def logical_and(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.logical_and(self, other) + + @builtin + def logical_or(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.logical_or(self, other) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _semantic=None): + return _semantic.not_(self) + + @builtin + def __getitem__(self, slices, _semantic=None): + if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: + slices = [slices] + if isinstance(slices, tuple): + slices = slices.values + ret = self + for dim, sl in enumerate(slices): + if _unwrap_if_constexpr(sl) is None: + ret = _semantic.expand_dims(ret, dim) + elif isinstance(sl, (builtins.slice, slice)) and all( + _unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)): + pass # an unsqueeze + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None): + """ + Alias for :py:func:`tensor.cast`. + """ + return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def gather(self, indices, axis) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def reduce_or(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +def _type_for_tuple_values(values, fields=None): + return tuple_type([constexpr_type(x) if isinstance(x, (int, float, dtype)) else x.type for x in values], fields) + + +class tuple(base_value): + + def __init__(self, args: Sequence, type: Optional[tuple_type] = None): + self.values = [i for i in args] + if isinstance(type, tuple_type): + self.type = type + elif type is not None: # make_template in ASTFunction.deserialize may pass us a list/tuple + self.type = tuple_type(type) + else: + self.type = _type_for_tuple_values(self.values) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + def __getattr__(self, name): + fields = self.type.fields + if fields is None or name not in fields: + raise AttributeError(f"'tuple' object has no attribute {name}") + return self.values[fields.index(name)] + + # TODO: remove + def _setitem(self, idx, value): + idx = _unwrap_if_constexpr(idx) + assert isinstance(idx, int) + self.values[idx] = value + self.type = _type_for_tuple_values(self.values, self.type.fields) + + def __add__(self, other): + other = _normalize_tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + other = _normalize_tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + def _flatten_ir(self, handles: List[ir.value]): + for v in self.values: + v._flatten_ir(handles) + + def __repr__(self): + return f"({', '.join(repr(x) for x in self.values)})" + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() + + +class tensor_descriptor_base_type(base_type): + + def __init__(self, block_type: block_type): + self.block_type = block_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + value = tensor_descriptor_base(handles[cursor], self.block_type) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed)) + + def __str__(self) -> str: + # ex. "tensor_descriptor" + return f"tensor_descriptor<{self.block_type}>" + + def __eq__(self, other) -> bool: + if type(other) is not type(self): + return False + return self.block_type == other.block_type + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}" + + +class tensor_descriptor_base(base_value): + """" + A tensor descriptor with unknown shape and strides + """ + + def __init__(self, handle, block_type: block_type): + """Not called by user code.""" + super().__init__() + + self.handle = handle # IR handle + self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return _semantic.descriptor_load(self, offsets, "", "") + + @builtin + def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return _semantic.descriptor_store(self, value, offsets) + + @builtin + def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_add(self, value, offsets) + + @builtin + def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_min(self, value, offsets) + + @builtin + def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_max(self, value, offsets) + + @builtin + def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_and(self, value, offsets) + + @builtin + def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_or(self, value, offsets) + + @builtin + def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_xor(self, value, offsets) + + @builtin + def gather(self, *args, _semantic=None) -> tensor: + """Gather multiple descriptors worth of data""" + assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}" + x_offsets = args[0] + y_offset = args[1] + return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "") + + @builtin + def scatter(self, value, *args, _semantic=None) -> tensor: + """Scatter multiple descriptors worth of data""" + assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}" + x_offsets = args[0] + y_offset = args[1] + return _semantic.descriptor_scatter(self, value, x_offsets, y_offset) + + +class tensor_descriptor_type(tensor_descriptor_base_type): + + def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): + self.block_type = block_type + self.shape_type = shape_type + self.strides_type = strides_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + shape = shape.values + strides = strides.values + value = tensor_descriptor(handle, shape, strides, self.block_type) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + super()._flatten_ir_types(builder, out) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def __eq__(self, other): + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + == other.strides_type) + + +class tensor_descriptor(tensor_descriptor_base): + """A descriptor representing a tensor in global memory. + """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, block_type) + # Global shape + self.shape = tuple(shape) + self.strides = tuple(strides) + self.type = tensor_descriptor_type( + block_type, + shape_type=self.shape.type, + strides_type=self.strides.type, + ) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) + + +# ----------------------- +# aggregate +# ----------------------- + + +@dataclass(frozen=True) +class _aggregate_type(base_type): + """A generic base type for all Triton aggregate types. + + This class contains a reference to the original user-defined Python class + and a list of class fields with their Triton types. + """ + + base_cls: type + fields: List[Tuple[str, base_type]] + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]: + instance = self.base_cls._get_instance() + for name, ty in self.fields: + value, cursor = ty._unflatten_ir(handles, cursor) + setattr(instance, name, value) + return instance, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + for name, ty in self.fields: + ty._flatten_ir_types(builder, out) + + def mangle(self) -> str: + name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}" + fields = [ty.mangle() for (name, ty) in self.fields] + return f"{name}<{', '.join(fields)}>" + + +def _aggregate(cls): + + # Define the wrapped Triton value type. + class aggregate_value(base_value): + __triton_builtin__ = True + __triton_aggregate__ = True + + @classmethod + def _get_instance(this_cls): + return super().__new__(this_cls) + + def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs): + # Call into the user-defined constructor. + instance = this_cls._get_instance() + extra_kwargs = {} + if isinstance(cls.__init__, JITCallable): + # raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function") + pass + else: + if "_semantic" in inspect.signature(cls.__init__).parameters: + extra_kwargs["_semantic"] = _semantic + if "_generator" in inspect.signature(cls.__init__).parameters: + extra_kwargs["_generator"] = _generator + cls.__init__(instance, *args, **extra_kwargs, **kwargs) + + # Require that the user-defined constructor initialized all fields. + for name in cls.__annotations__.keys(): + if not hasattr(instance, name): + raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'") + + return instance + + # Only allow setting attributes defined in the class annotations. + def __setattr__(self, name, value): + if name not in cls.__annotations__: + raise AttributeError(f"{cls.__name__} has no attribute '{name}'") + if not isinstance(value, cls.__annotations__[name]): + raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}") + super().__setattr__(name, value) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + for name in cls.__annotations__.keys(): + getattr(self, name)._flatten_ir(handles) + + @property + def type(self): + return _aggregate_type(aggregate_value, + [(name, getattr(self, name).type) for name in cls.__annotations__.keys()]) + + hash_attrs = [cls.__init__] + + for (name, member) in inspect.getmembers(cls): + if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable): + if name != "__init__": + setattr(aggregate_value, name, member) + hash_attrs.append(member) + + aggregate_value.hash_attrs = hash_attrs + aggregate_value.__name__ = cls.__name__ + aggregate_value.__module__ = cls.__module__ + aggregate_value.__qualname__ = cls.__qualname__ + aggregate_value.__doc__ = cls.__doc__ + + return aggregate_value + + +# ----------------------- +# SPMD Programming Model +# ----------------------- + + +@builtin +def program_id(axis, _semantic=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = _semantic.program_id(0) + # pid1 = _semantic.program_id(1) + # pid2 = _semantic.program_id(2) + # npg0 = _semantic.num_programs(0) + # npg1 = _semantic.num_programs(1) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _unwrap_if_constexpr(axis) + return _semantic.program_id(axis) + + +@builtin +def num_programs(axis, _semantic=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _unwrap_if_constexpr(axis) + return _semantic.num_programs(axis) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _semantic=None): + start = _unwrap_if_constexpr(start) + end = _unwrap_if_constexpr(end) + return _semantic.arange(start, end) + + +arange.__doc__ = f""" + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 +""" + + +def _unwrap_shape(shape): + shape = _unwrap_if_constexpr(shape) + return [_unwrap_if_constexpr(s) for s in shape] + + +def _shape_check_impl(shape): + shape = _unwrap_shape(shape) + validate_block_shape(shape) + return shape + + +@builtin +def full(shape, value, dtype, _semantic=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param value: A scalar value to fill the array with + :type value: scalar + :param dtype: Data type of the new array, e.g., :code:`tl.float16` + :type dtype: tl.dtype + """ + shape = _shape_check_impl(shape) + value = _unwrap_if_constexpr(value) + dtype = _unwrap_if_constexpr(dtype) + return _semantic.full(shape, value, dtype) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _semantic=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return _semantic.broadcast_impl_value(input, other) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _semantic=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return _semantic.broadcast_impl_shape(input, shape) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _semantic=None): + """ + Permutes the dimensions of a tensor. + + If the parameter :code:`dims` is not specified, the function defaults to + swapping the last two axes, thereby performing an (optionally batched) + 2D transpose. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + dims = _unwrap_iterable(dims) + if not dims: + n = len(input.shape) + if n < 2: + raise ValueError("tl.trans invoked with a 0- or 1-dimensional tensor") + dims = list(builtins.range(n - 2)) + [n - 1, n - 2] + return _semantic.permute(input, dims) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _semantic=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to swap the last two axes. + """ + dims = _unwrap_iterable(dims) + return _semantic.permute(input, dims) + + +@builtin +def cat(input, other, can_reorder=False, dim=0, _semantic=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: Tensor + :param other: The second input tensor. + :type other: Tensor + :param can_reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops). + :type can_reorder: bool + :param dim: The dimension to concatenate along (used when can_reorder is False). + :type dim: int + """ + if can_reorder: + return _semantic.cat(input, other, can_reorder) + + rank = len(input.shape) + assert rank == len(other.shape), f"tensors must have the same rank, got {rank} and {len(other.shape)}" + dim = _wrap_axis(_unwrap_if_constexpr(dim), rank) + assert all(input.shape[i] == other.shape[i] for i in builtins.range(rank) if i != + dim), f"tensor dims must match except in the concat dimension {dim}, got {input.shape} and {other.shape}" + + # Join introduces a new minor dim; move it before the concat dim and merge. + c = join(input, other, _semantic=_semantic) + order = list(builtins.range(rank)) + order.insert(dim, rank) + c = permute(c, order, _semantic=_semantic) + new_shape = list(input.shape) + new_shape[dim] = input.shape[dim] + other.shape[dim] + return reshape(c, new_shape, _semantic=_semantic) + + +@builtin +def join(a, b, _semantic=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return _semantic.join(a, b) + + +def _unsplat(x, _semantic=None, _generator=None): + """ + Convert a single-element tensor to a scalar. + """ + if len(x.shape) == 0: + return x + numel = 1 + for d in x.shape: + numel *= d + assert numel == 1, "can only unsplat single-element tensors" + return _semantic.unsplat(x) + + +@_tensor_member_fn +@builtin +def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But _semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = _semantic.expand_dims(a, 0) + + out_lhs, out_rhs = _semantic.split(a) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator) + out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _semantic=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return _semantic.reshape(input, shape, can_reorder=True) + + +@_tensor_member_fn +@builtin +def item(input, _semantic=None, _generator=None): + """ + Converts a single-element tensor into a scalar. + """ + return _unsplat(input, _semantic=_semantic, _generator=_generator) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + if len(shape) == 0: + return _unsplat(input, _semantic=_semantic, _generator=_generator) + return _semantic.reshape(input, shape, can_reorder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _semantic=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _semantic.to_tensor(input) + axis = _unwrap_if_constexpr(axis) + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = _semantic.expand_dims(ret, a) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: tl.dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional + """ + input = _semantic.to_tensor(input) + dtype = _unwrap_if_constexpr(dtype) + fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding) + bitcast = _unwrap_if_constexpr(bitcast) + if bitcast: + return _semantic.bitcast(input, dtype) + return _semantic.cast(input, dtype, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _semantic=None): + """ + Returns the matrix product of two blocks. + + The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions. + For three-dimensional blocks, `tl.dot` performs the batched matrix product, + where the first dimension of each block represents the batch dimension. + + :param input: The first tensor to be multiplied. + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions + input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and + (allow_tf32 or allow_tf32 is None)) else "ieee") + + input_precision = _unwrap_if_constexpr(input_precision) + out_dtype = _unwrap_if_constexpr(out_dtype) + max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc) + acc = _unwrap_if_constexpr(acc) + + # check shapes make sense: + a_shape = list(input.shape) + b_shape = list(other.shape) + assert len(a_shape) == len(b_shape) >= 2, "input and other must have equal ranks >= 2" + assert a_shape[:-2] == b_shape[:-2], "input and other must have equal batch shapes" + assert a_shape[-1] == b_shape[-2], "input and other must have equal reduction dimensions" + + # compute shape of accumulator: + c_shape = a_shape[:-1] + [b_shape[-1]] + if acc is not None: + assert list(acc.shape) == c_shape, "accumulator shape is incompatible" + rank = len(c_shape) + + if rank >= 4: + batch_size = 1 + for i in builtins.range(rank - 2): + batch_size *= c_shape[i] + input = _semantic.reshape(input, [batch_size] + a_shape[-2:], can_reorder=False) + other = _semantic.reshape(other, [batch_size] + b_shape[-2:], can_reorder=False) + if acc is not None: + acc = _semantic.reshape(acc, [batch_size] + c_shape[-2:], can_reorder=False) + + res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype) + + if rank >= 4: + res = _semantic.reshape(res, c_shape, can_reorder=False) + + assert list(res.shape) == c_shape, "output shape is unexpected" + return res + + +@builtin +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, + rhs_k_pack=True, out_dtype=float32, _semantic=None): + """ + Returns the matrix product of two blocks in microscaling format. + + lhs and rhs use microscaling formats described here: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + + Software emulation enables targeting hardware architectures without native microscaling + operation support. Right now for such case, microscaled lhs/rhs are upcasted to + :code:`bf16` element type beforehand for dot computation, with one exception: + for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type, + the other input is also upcasted to :code:`fp16` element type instead. + This behavior is experimental and may be subject to change in the future. + + :param lhs: The first tensor to be multiplied. + :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. + :param lhs_scale: Scale factor for lhs tensor. Shape should be [M, K//group_size] when lhs is [M, K], where group_size is 32 if scales type are `e8m0`. + :type lhs_scale: e8m0 type represented as an uint8 tensor, or None. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. + :type lhs_format: str + :param rhs: The second tensor to be multiplied. + :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. + :param rhs_scale: Scale factor for rhs tensor. Shape should be [N, K//group_size] where rhs is [K, N]. + Important: Do NOT transpose rhs_scale + :type rhs_scale: e8m0 type represented as an uint8 tensor, or None. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. + :type rhs_format: str + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along M dimension. + :type lhs_k_pack: bool, optional + :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along N dimension. + :type rhs_k_pack: bool, optional + """ + out_dtype = _unwrap_if_constexpr(out_dtype) + acc = _unwrap_if_constexpr(acc) + assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" + return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack, + rhs_k_pack, out_dtype) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _semantic=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be `None`, and + - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value. + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for + cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1), + and ".cv" means don’t cache and fetch again. see + `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _unwrap_if_constexpr(mask) + other = _unwrap_if_constexpr(other) + if mask is not None: + mask = _semantic.to_tensor(mask) + if other is not None: + other = _semantic.to_tensor(other) + padding_option = _unwrap_if_constexpr(padding_option) + cache_modifier = _unwrap_if_constexpr(cache_modifier) + eviction_policy = _unwrap_if_constexpr(eviction_policy) + volatile = _unwrap_if_constexpr(volatile) + return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile) + + +@builtin +def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], + _semantic=None) -> tensor: + """Load a block of data from a tensor descriptor.""" + return desc.load(offsets, _semantic=_semantic) + + +@builtin +def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor, + _semantic=None) -> tensor: + """Store a block of data to a tensor descriptor.""" + return desc.store(offsets, value, _semantic=_semantic) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _semantic=None): + """ + Legacy compatibility API for descriptor load. + + New code should prefer `load_tensor_descriptor`. We keep this symbol so + migrated tests can exercise the same backend descriptor path without + monkeypatching triton.language internals in conftest. + """ + # Legacy signature includes shape; 3.5 descriptors already encode block + # shape in the descriptor type, so shape is ignored. + _ = shape + dtype = _unwrap_if_constexpr(dtype) + value = desc_pointer.load(offsets, _semantic=_semantic) + if value.dtype == dtype: + return value + if value.dtype.primitive_bitwidth == dtype.primitive_bitwidth: + return value.to(dtype, bitcast=True, _semantic=_semantic) + return value.to(dtype, _semantic=_semantic) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _semantic=None): + """ + Legacy compatibility API for descriptor store. + + New code should prefer `store_tensor_descriptor`. + """ + value = _semantic.to_tensor(value) + desc_dtype = desc_pointer.dtype + if value.dtype != desc_dtype and value.dtype.primitive_bitwidth == desc_dtype.primitive_bitwidth: + value = value.to(desc_dtype, bitcast=True, _semantic=_semantic) + return desc_pointer.store(offsets, value, _semantic=_semantic) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for + cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt" + stands for cache write-through, see `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"} + """ + # `value` can be constexpr + value = _semantic.to_tensor(value) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + mask = _semantic.to_tensor(mask) + cache_modifier = _unwrap_if_constexpr(cache_modifier) + eviction_policy = _unwrap_if_constexpr(eviction_policy) + return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order) + + +@must_use_result( + "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable." +) +@_tensor_member_fn +@builtin +def advance(base, offsets, _semantic=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return _semantic.advance(base, offsets) + + +@builtin +def make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + padding_option="zero", + _semantic=None, +) -> tensor_descriptor: + """Make a tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers representing the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2-5 dimensional tensors are supported. + + Example + ******* + .. code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M / M_BLOCK, N / N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + + padding_option = _unwrap_if_constexpr(padding_option) + return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire", + "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, + the function defaults to using "acq_rel" semantics. + :type sem: str, optional + :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation. + Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + :type scope: str, optional + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None): + cmp = _semantic.to_tensor(cmp) + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + return _semantic.atomic_cas(pointer, cmp, val, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_xchg(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_add(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_max(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_min(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_and(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_or(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_xor(pointer, val, mask, sem, scope) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _semantic=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _semantic.to_tensor(condition) + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.where(condition, x, y) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def add(x, y, sanitize_overflow: constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.add(x, y, sanitize_overflow) + + +@builtin +def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.sub(x, y, sanitize_overflow) + + +@builtin +def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.mul(x, y, sanitize_overflow) + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x = _promote_bfloat16_to_float32(x, _semantic=_semantic) + y = _promote_bfloat16_to_float32(y, _semantic=_semantic) + propagate_nan = _unwrap_if_constexpr(propagate_nan) + return _semantic.minimum(x, y, propagate_nan) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x = _promote_bfloat16_to_float32(x, _semantic=_semantic) + y = _promote_bfloat16_to_float32(y, _semantic=_semantic) + propagate_nan = _unwrap_if_constexpr(propagate_nan) + return _semantic.maximum(x, y, propagate_nan) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _semantic.to_tensor(x) + min = _semantic.to_tensor(min) + max = _semantic.to_tensor(max) + x = _promote_bfloat16_to_float32(x, _semantic=_semantic) + min = _promote_bfloat16_to_float32(min, _semantic=_semantic) + max = _promote_bfloat16_to_float32(max, _semantic=_semantic) + + propagate_nan = _unwrap_if_constexpr(propagate_nan) + + return _semantic.clamp(x, min, max, propagate_nan) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None, + dtype_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value + :type {return_indices_arg}: bool""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN + :type {tie_break_arg}: bool""" + if dtype_arg is not None: + docstr += f""" + :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`. + :type {dtype_arg}: tl.dtype""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int | None + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0] + + def make_combine_region(reduce_op): + param_types = [t.type.scalar for t in input] * 2 + region = reduce_op.get_region(0) + builder = _semantic.builder + with _insertion_guard(builder): + to_ir = lambda T: T.to_ir(builder) + block = builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _semantic=_semantic) + return t + + axis = _unwrap_if_constexpr(axis) + keep_dims = _unwrap_if_constexpr(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = _semantic.reduction(input, axis, make_combine_region) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _semantic=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _semantic=_semantic) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None): + axis = _unwrap_if_constexpr(axis) + n = input.shape[axis] + index = arange(0, n, _semantic=_semantic) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _semantic=_semantic) + index = broadcast_to(index, input.shape, _semantic=_semantic) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the scan should be done + :type axis: int + :param reverse: if true, the scan is performed in the reverse direction + :type reverse: bool""" + + if dtype_arg is not None: + docstr += f""" + :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`. + :type {dtype_arg}: tl.dtype""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done + :type axis: int + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param reverse: whether to apply the associative scan in the reverse direction along axis + :type reverse: bool + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0] + + def make_combine_region(scan_op): + param_types = [t.type.scalar for t in input] * 2 + region = scan_op.get_region(0) + builder = _semantic.builder + with _insertion_guard(builder): + to_ir = lambda T: T.to_ir(builder) + block = builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + builder.create_scan_ret(*handles) + + axis = _unwrap_if_constexpr(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return _semantic.associative_scan(input, axis, make_combine_region, reverse) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, mask=None, _semantic=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :type input: Tensor + :param num_bins: number of histogram bins + :type num_bins: int + :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram + :type mask: Block of `triton.int1`, optional + + """ + num_bins = _unwrap_if_constexpr(num_bins) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + mask = _semantic.to_tensor(mask) + return _semantic.histogram(input, num_bins, mask) + + +@_tensor_member_fn +@builtin +def gather(src, index, axis, _semantic=None): + """Gather from a tensor along a given dimension. + + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + + """ + src = _unwrap_if_constexpr(src) + index = _unwrap_if_constexpr(index) + axis = _unwrap_if_constexpr(axis) + return _semantic.gather(src, index, axis) + + +@builtin +def map_elementwise( + scalar_fn: Callable[..., Tuple[tensor, ...]], + *args: tensor, + pack=1, + _semantic=None, + _generator=None, +): + ''' + Map a scalar function over a tensor. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + This may be useful in allowing control flow over single elements in a tensor, + for example a multi-branch function where one branch is more expensive. With + :code:`tl.where` you are forced to calculate both sides of the branch, but + with an if we only execute one side. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def selu_scalar(x, alpha): + if x > 0: + return a + else: + return alpha * (tl.exp(x) - 1) + + @triton.jit + def selu(x, alpha): + return tl.map_elementwise(selu_scalar, x, alpha) + + :param scalar_fn: the function to map over. + :param pack: the number of elements to be processed by one function call. + :return: one tensor or a tuple of tensors, depending on the mapped function. + ''' + # Build the block for the nested region first to discover the return types + assert pack >= 1 + in_scalar_tys = [t.type.scalar for t in args] + builder = _semantic.builder + block = builder.new_block() + scalar_args = [] + original_loc = builder.get_loc() + for i, ty in enumerate(in_scalar_tys): + for j in builtins.range(pack): + block.add_argument_at(ty.to_ir(builder), original_loc) + scalar_args.append(tensor(block.arg(i * pack + j), ty)) + + with _insertion_guard(builder): + builder.set_insertion_point_to_start(block) + scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={}) + + is_single = isinstance(scalar_results, tensor) + if is_single: + scalar_results = scalar_results, + + handles = [r.handle for r in scalar_results] + builder.set_loc(original_loc) + builder.create_map_elementwise_ret(handles) + + fn_result_types = [x.type for x in scalar_results] + scalar_result_types = fn_result_types + if pack > 1: + scalar_result_types = fn_result_types[::pack] + for offset in builtins.range(1, pack): + assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results" + + def make_elementwise_region(elementwise_op): + region = elementwise_op.get_region(0) + region.push_back(block) + + builder.set_loc(original_loc) + result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region) + return result[0] if is_single else result + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_semantic=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return _semantic.debug_barrier() + + +@builtin +def multiple_of(input, values, _semantic=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return _semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _semantic=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return _semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _semantic=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return _semantic.max_constancy(input, values) + + +@builtin +def assume(cond, _semantic=None): + ''' + Allow compiler to assume the :code:`cond` is True. + ''' + return _semantic.assume(_semantic.to_tensor(cond)) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _semantic=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _semantic=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _unwrap_if_constexpr(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_semantic.to_tensor(arg)) + return _semantic.device_print(prefix, new_args, hex) + + +@builtin +def device_assert(cond, msg="", mask=None, _semantic=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _unwrap_if_constexpr(msg) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + mask = _semantic.to_tensor(mask) + return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _semantic=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + `PTX `_ + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + `LLVM format `_ + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _unwrap_if_constexpr(asm) + constraints = _unwrap_if_constexpr(constraints) + pack = _unwrap_if_constexpr(pack) + is_pure = _unwrap_if_constexpr(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_semantic.to_tensor(arg) for arg in args]: + bin_op_type_checking = partial( + _semantic.binary_op_type_checking_impl, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype] + handles = [t.handle for t in dispatch_args] + builder = _semantic.builder + call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range(base_value): + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr" + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr" + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr" + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range(base_value): + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + :param loop_unroll_factor: Tells the Triton IR level loop unroller how many + times to unroll a for loop that this range is used with. Less than 2 for + this value implies no unrolling. + :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot + operation in the loop to be multi-buffered, if applicable. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. + :param warp_specialize: Enable automatic warp specialization on the loop. + The compiler will attempt to partition memory, MMA, and vector + operations in the loop into separate async partitions. This will + increase the total number of warps required by the kernel. + :param disable_licm: Tells the compiler it shouldn't hoist loop invariant + code outside the loop. This is often useful to avoid creating long liveranges + within a loop. + + Note that warp specialization is only supported on Blackwell GPUs and + only works on simple matmul loops. Support for arbitrary loops will be + expanded over time. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor + self.disallow_acc_multi_buffer = disallow_acc_multi_buffer + self.flatten = flatten + self.warp_specialize = warp_specialize + self.disable_licm = disable_licm + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +class condition(base_value): + """ + While loop condition wrapper. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + while tl.condition(c, disable_licm) + ... + :note: This is a special wrapper used to annotate while loops in the context of + :code:`triton.jit` functions. It allows user to pass extra attributes to the compiler. + :param disable_licm: Tells the compiler it shouldn't hoist loop invariant + code outside the loop. This is often useful to avoid creating long liveranges + within a loop. + """ + + def __init__(self, arg1, disable_licm=False): + self.condition = arg1 + self.disable_licm = disable_licm + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool, + _semantic): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_type: the type of the return value + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + builder = _semantic.builder + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _semantic=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _semantic.to_tensor(dispatch_args[i]) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + + arg_types = tuple(arg_types) + ret_type = arg_type_symbol_dict[arg_types][1] + if len(arg_types) > 0: + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_type = broadcast_arg.type.with_element_ty(ret_type) + func = _semantic.builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic) + + +def binary_op_type_legalization(lhs, rhs, semantic): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) + + +_NOTHING = object() + + +def is_negative_zero(x): + return x == 0.0 and math.copysign(1.0, x) < 0 + + +@builtin +def builtin_max(*args, propagate_nan=_NOTHING, _semantic=None): + args = _unwrap_if_constexpr(args) + is_constexpr = all(not isinstance(x, base_value) for x in args) + if is_constexpr: + assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin max" + assert not any(math.isnan(x) for x in args) + assert not any(is_negative_zero(x) for x in args) + return constexpr(builtins.max(_unwrap_if_constexpr(args))) + + if propagate_nan is _NOTHING: + propagate_nan = PropagateNan.NONE + else: + warn("passing propagate_nan to builtin max is deprecated, use tl.minimum instead") + + assert len(args) >= 2, "min requires at least 2 values" + max_val = args[0] + for arg in args[1:]: + max_val = maximum(max_val, arg, propagate_nan=propagate_nan, _semantic=_semantic) + if max_val.type.is_block(): + warn("builtin max on non-scalar tensor values is deprecated, use tl.maximum instead") + return max_val + + +@builtin +def builtin_min(*args, propagate_nan=_NOTHING, _semantic=None): + args = _unwrap_if_constexpr(args) + is_constexpr = all(not isinstance(x, base_value) for x in args) + if is_constexpr: + assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin min" + assert not any(math.isnan(x) for x in args) + assert not any(is_negative_zero(x) for x in args) + return constexpr(builtins.min(_unwrap_if_constexpr(args))) + + if propagate_nan is _NOTHING: + propagate_nan = PropagateNan.NONE + else: + warn("passing propagate_nan to builtin min is deprecated, use tl.minimum instead") + + assert len(args) >= 2, "min requires at least 2 values" + min_val = args[0] + for arg in args[1:]: + min_val = minimum(min_val, arg, propagate_nan=propagate_nan, _semantic=_semantic) + if min_val.type.is_block(): + warn("builtin min on non-scalar tensor values is deprecated, use tl.minimum instead") + return min_val diff --git a/third_party/mthreads/python/triton/language/extra/__init__.py b/third_party/mthreads/python/triton/language/extra/__init__.py new file mode 100644 index 0000000000..3f8c70a716 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/__init__.py @@ -0,0 +1,26 @@ +import pkgutil +from importlib.util import module_from_spec +from sys import modules + +_backends = [] +for module_finder, module_name, is_pkg in pkgutil.iter_modules( + __path__, + prefix=__name__ + ".", +): + # skip .py files (like libdevice.py) + if not is_pkg: + continue + + # import backends (like cuda and hip) that are included during setup.py + spec = module_finder.find_spec(module_name) + if spec is None or spec.loader is None: + continue + module = module_from_spec(spec) + spec.loader.exec_module(module) + + _backends.append(module_name) + modules[module_name] = module + +__all__ = _backends + +del _backends diff --git a/third_party/mthreads/python/triton/language/extra/libdevice.py b/third_party/mthreads/python/triton/language/extra/libdevice.py new file mode 100644 index 0000000000..e29810bfba --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/libdevice.py @@ -0,0 +1,790 @@ +def clz(arg0): + ... + + +def popc(arg0): + ... + + +def byte_perm(arg0, arg1, arg2): + ... + + +def mulhi(arg0, arg1): + ... + + +def mul24(arg0, arg1): + ... + + +def brev(arg0): + ... + + +def sad(arg0, arg1, arg2): + ... + + +def abs(arg0): + ... + + +def floor(arg0): + ... + + +def rcp64h(arg0): + ... + + +def rsqrt(arg0): + ... + + +def ceil(arg0): + ... + + +def trunc(arg0): + ... + + +def exp2(arg0): + ... + + +def saturatef(arg0): + ... + + +def fma_rn(arg0, arg1, arg2): + ... + + +def fma_rz(arg0, arg1, arg2): + ... + + +def fma_rd(arg0, arg1, arg2): + ... + + +def fma_ru(arg0, arg1, arg2): + ... + + +def fast_dividef(arg0, arg1): + ... + + +def div_rn(arg0, arg1): + ... + + +def div_rz(arg0, arg1): + ... + + +def div_rd(arg0, arg1): + ... + + +def div_ru(arg0, arg1): + ... + + +def rcp_rn(arg0): + ... + + +def rcp_rz(arg0): + ... + + +def rcp_rd(arg0): + ... + + +def rcp_ru(arg0): + ... + + +def sqrt_rn(arg0): + ... + + +def sqrt_rz(arg0): + ... + + +def sqrt_rd(arg0): + ... + + +def sqrt_ru(arg0): + ... + + +def sqrt(arg0): + ... + + +def add_rn(arg0, arg1): + ... + + +def add_rz(arg0, arg1): + ... + + +def add_rd(arg0, arg1): + ... + + +def add_ru(arg0, arg1): + ... + + +def mul_rn(arg0, arg1): + ... + + +def mul_rz(arg0, arg1): + ... + + +def mul_rd(arg0, arg1): + ... + + +def mul_ru(arg0, arg1): + ... + + +def double2float_rn(arg0): + ... + + +def double2float_rz(arg0): + ... + + +def double2float_rd(arg0): + ... + + +def double2float_ru(arg0): + ... + + +def double2int_rn(arg0): + ... + + +def double2int_rz(arg0): + ... + + +def double2int_rd(arg0): + ... + + +def double2int_ru(arg0): + ... + + +def double2uint_rn(arg0): + ... + + +def double2uint_rz(arg0): + ... + + +def double2uint_rd(arg0): + ... + + +def double2uint_ru(arg0): + ... + + +def int2double_rn(arg0): + ... + + +def uint2double_rn(arg0): + ... + + +def float2int_rn(arg0): + ... + + +def float2int_rz(arg0): + ... + + +def float2int_rd(arg0): + ... + + +def float2int_ru(arg0): + ... + + +def float2uint_rn(arg0): + ... + + +def float2uint_rz(arg0): + ... + + +def float2uint_rd(arg0): + ... + + +def float2uint_ru(arg0): + ... + + +def int2float_rn(arg0): + ... + + +def int2float_rz(arg0): + ... + + +def int2float_rd(arg0): + ... + + +def int2float_ru(arg0): + ... + + +def uint2float_rn(arg0): + ... + + +def uint2float_rz(arg0): + ... + + +def uint2float_rd(arg0): + ... + + +def uint2float_ru(arg0): + ... + + +def hiloint2double(arg0, arg1): + ... + + +def double2loint(arg0): + ... + + +def double2hiint(arg0): + ... + + +def float2ll_rn(arg0): + ... + + +def float2ll_rz(arg0): + ... + + +def float2ll_rd(arg0): + ... + + +def float2ll_ru(arg0): + ... + + +def float2ull_rn(arg0): + ... + + +def float2ull_rz(arg0): + ... + + +def float2ull_rd(arg0): + ... + + +def float2ull_ru(arg0): + ... + + +def double2ll_rn(arg0): + ... + + +def double2ll_rz(arg0): + ... + + +def double2ll_rd(arg0): + ... + + +def double2ll_ru(arg0): + ... + + +def double2ull_rn(arg0): + ... + + +def double2ull_rz(arg0): + ... + + +def double2ull_rd(arg0): + ... + + +def double2ull_ru(arg0): + ... + + +def ll2float_rn(arg0): + ... + + +def ll2float_rz(arg0): + ... + + +def ll2float_rd(arg0): + ... + + +def ll2float_ru(arg0): + ... + + +def ull2float_rn(arg0): + ... + + +def ull2float_rz(arg0): + ... + + +def ull2float_rd(arg0): + ... + + +def ull2float_ru(arg0): + ... + + +def ll2double_rn(arg0): + ... + + +def ll2double_rz(arg0): + ... + + +def ll2double_rd(arg0): + ... + + +def ll2double_ru(arg0): + ... + + +def ull2double_rn(arg0): + ... + + +def ull2double_rz(arg0): + ... + + +def ull2double_rd(arg0): + ... + + +def ull2double_ru(arg0): + ... + + +def int_as_float(arg0): + ... + + +def float_as_int(arg0): + ... + + +def uint_as_float(arg0): + ... + + +def float_as_uint(arg0): + ... + + +def longlong_as_double(arg0): + ... + + +def double_as_longlong(arg0): + ... + + +def fast_sinf(arg0): + ... + + +def fast_cosf(arg0): + ... + + +def fast_log2f(arg0): + ... + + +def fast_logf(arg0): + ... + + +def fast_expf(arg0): + ... + + +def fast_tanhf(arg0): + ... + + +def fast_tanf(arg0): + ... + + +def fast_exp10f(arg0): + ... + + +def fast_log10f(arg0): + ... + + +def fast_powf(arg0, arg1): + ... + + +def hadd(arg0, arg1): + ... + + +def rhadd(arg0, arg1): + ... + + +def sub_rn(arg0, arg1): + ... + + +def sub_rz(arg0, arg1): + ... + + +def sub_rd(arg0, arg1): + ... + + +def sub_ru(arg0, arg1): + ... + + +def rsqrt_rn(arg0): + ... + + +def ffs(arg0): + ... + + +def rint(arg0): + ... + + +def llrint(arg0): + ... + + +def nearbyint(arg0): + ... + + +def isnan(arg0): + ... + + +def signbit(arg0): + ... + + +def copysign(arg0, arg1): + ... + + +def finitef(arg0): + ... + + +def isinf(arg0): + ... + + +def nextafter(arg0, arg1): + ... + + +def sin(arg0): + ... + + +def cos(arg0): + ... + + +def sinpi(arg0): + ... + + +def cospi(arg0): + ... + + +def tan(arg0): + ... + + +def log2(arg0): + ... + + +def exp(arg0): + ... + + +def exp10(arg0): + ... + + +def cosh(arg0): + ... + + +def sinh(arg0): + ... + + +def tanh(arg0): + ... + + +def atan2(arg0, arg1): + ... + + +def atan(arg0): + ... + + +def asin(arg0): + ... + + +def acos(arg0): + ... + + +def log(arg0): + ... + + +def log10(arg0): + ... + + +def log1p(arg0): + ... + + +def acosh(arg0): + ... + + +def asinh(arg0): + ... + + +def atanh(arg0): + ... + + +def expm1(arg0): + ... + + +def hypot(arg0, arg1): + ... + + +def rhypot(arg0, arg1): + ... + + +def norm3d(arg0, arg1, arg2): + ... + + +def rnorm3d(arg0, arg1, arg2): + ... + + +def norm4d(arg0, arg1, arg2, arg3): + ... + + +def rnorm4d(arg0, arg1, arg2, arg3): + ... + + +def cbrt(arg0): + ... + + +def rcbrt(arg0): + ... + + +def j0(arg0): + ... + + +def j1(arg0): + ... + + +def y0(arg0): + ... + + +def y1(arg0): + ... + + +def yn(arg0, arg1): + ... + + +def jn(arg0, arg1): + ... + + +def cyl_bessel_i0(arg0): + ... + + +def cyl_bessel_i1(arg0): + ... + + +def erf(arg0): + ... + + +def erfinv(arg0): + ... + + +def erfc(arg0): + ... + + +def erfcx(arg0): + ... + + +def erfcinv(arg0): + ... + + +def normcdfinv(arg0): + ... + + +def normcdf(arg0): + ... + + +def lgamma(arg0): + ... + + +def ldexp(arg0, arg1): + ... + + +def scalbn(arg0, arg1): + ... + + +def fmod(arg0, arg1): + ... + + +def remainder(arg0, arg1): + ... + + +def fma(arg0, arg1, arg2): + ... + + +def pow(arg0, arg1): + ... + + +def tgamma(arg0): + ... + + +def round(arg0): + ... + + +def llround(arg0): + ... + + +def fdim(arg0, arg1): + ... + + +def ilogb(arg0): + ... + + +def logb(arg0): + ... + + +def isfinited(arg0): + ... diff --git a/third_party/mthreads/python/triton/language/math.py b/third_party/mthreads/python/triton/language/math.py new file mode 100644 index 0000000000..582cd876cb --- /dev/null +++ b/third_party/mthreads/python/triton/language/math.py @@ -0,0 +1,249 @@ +from . import core +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _semantic=None): + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x, y = core.binary_op_type_legalization(x, y, _semantic) + return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type) + + +@core._tensor_member_fn +@core.builtin +@_add_math_1arg_docstr("absolute value") +def abs(x, _semantic=None): + x = _semantic.to_tensor(x) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic) + return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_semantic.builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_semantic.builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _semantic=None): + ieee_rounding = core._unwrap_if_constexpr(ieee_rounding) + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + return _semantic.fdiv(x, y, ieee_rounding) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") +def div_rn(x, y, _semantic=None): + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x, y = core.binary_op_type_legalization(x, y, _semantic) + return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _semantic=None): + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + z = _semantic.to_tensor(z) + x, y = core.binary_op_type_legalization(x, y, _semantic) + z, x = core.binary_op_type_legalization(z, x, _semantic) + z, y = core.binary_op_type_legalization(z, y, _semantic) + return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/mthreads/python/triton/language/random.py b/third_party/mthreads/python/triton/language/random.py new file mode 100644 index 0000000000..b4790def87 --- /dev/null +++ b/third_party/mthreads/python/triton/language/random.py @@ -0,0 +1,218 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = tl.constexpr(10) # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = tl.mul(B, _c2, sanitize_overflow=False) + c3 = tl.mul(A, _c0, sanitize_overflow=False) + # raise key + k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False) + k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False) + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + tl.static_assert(seed.dtype.is_int()) + seed = seed.to(tl.uint64) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + + offset_lo = offset.to(tl.uint32) + _0 = offset_lo * 0 + + if tl.constexpr(offset.dtype.primitive_bitwidth) > 32: + offset_hi = (offset >> 32).to(tl.uint32) + else: + offset_hi = _0 + + return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/mthreads/python/triton/language/semantic.py b/third_party/mthreads/python/triton/language/semantic.py new file mode 100644 index 0000000000..86d574d2b3 --- /dev/null +++ b/third_party/mthreads/python/triton/language/semantic.py @@ -0,0 +1,1996 @@ +from __future__ import annotations # remove after python 3.11 +import builtins +import warnings + +from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type +import numbers + +from triton.runtime import driver + +from .._C.libtriton import ir +from . import core as tl + +T = TypeVar('T') +TensorTy = TypeVar('TensorTy') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +class TritonSemantic(Generic[TensorTy]): + tensor: Type[TensorTy] = tl.tensor + lang = tl + + builder: ir.builder + + def __init__(self, builder): + self.builder = builder + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + def program_id(self, axis: int) -> TensorTy: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return self.tensor(self.builder.create_get_program_id(axis), tl.int32) + + def num_programs(self, axis: int) -> TensorTy: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return self.tensor(self.builder.create_get_num_programs(axis), tl.int32) + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool, + div_or_mod: bool) -> tl.dtype: + # 0) For scalars we follow semantics similar to PyTorch, namely: + # - If the scalar is of a lower or equal kind (bool < uint < int < fp), + # it doesn't participate in the promotion + if a_is_scalar != b_is_scalar: + scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty) + if scalar_ty.kind().value <= tensor_ty.kind().value: + if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)): + return tl.float32 + return tensor_ty + # For same-kind scalar/tensor pairs (e.g. fp32 scalar + bf16 + # tensor), keep applying generic promotion rules so fp32 is + # preserved instead of being prematurely demoted to bf16. + + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() and b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + else: + return tl.bfloat16 + if a_ty.is_bf16() or b_ty.is_bf16(): + return tl.float32 + # 5) return fp16 if operands are different fp8 + if a_ty.is_fp8() and b_ty.is_fp8(): + return a_ty if a_ty == b_ty else tl.float16 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 6 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return self.integer_promote_impl(a_ty, b_ty) + + def to_tensor(self, x, check_type=True): + if isinstance(x, self.tensor): + return x + x = x.value if isinstance(x, tl.constexpr) else x + if isinstance(x, (int, float, bool)): + dtype = self.to_tensor_type(x) + return self.scalar_constant(x, dtype=dtype) + elif check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x + + def to_tensor_type(self, x): + if isinstance(x, tl.dtype): + return x + elif isinstance(x, tl.constexpr_type): + x = x.value + + if isinstance(x, bool): + return tl.int1 + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tl.int32 + elif 2**31 <= x < 2**32: + return tl.uint32 + elif -2**63 <= x < 2**63: + return tl.int64 + elif 2**63 <= x < 2**64: + return tl.uint64 + raise ValueError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = builtins.abs(x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tl.float32 + else: + return tl.float64 + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number, + allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[TensorTy, TensorTy]: + lhs_is_scalar = isinstance(lhs, numbers.Number) + rhs_is_scalar = isinstance(rhs, numbers.Number) + if lhs_is_scalar: + lhs_scalar = lhs + lhs = self.to_tensor(lhs) + if rhs_is_scalar: + rhs_scalar = rhs + rhs = self.to_tensor(rhs) + + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod) + if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned() + or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()): + raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. " + "Perform a explicit cast on one of them.") + if ret_sca_ty.is_int(): + if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= + ret_sca_ty.get_int_max_value()): + raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}") + if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= + ret_sca_ty.get_int_max_value()): + raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}") + lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty) + rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty) + + # implicit broadcasting + lhs, rhs = self.broadcast_impl_value(lhs, rhs) + return lhs, rhs + + def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable): + if lhs.type.scalar.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow: + return + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + assert lhs_sca_ty == rhs_sca_ty + assert lhs_sca_ty.is_int() + lhs = self.cast(lhs, tl.int64) + rhs = self.cast(rhs, tl.int64) + ret = binary_op(lhs, rhs, False) + max_value = lhs_sca_ty.get_int_max_value() + max_value = self.scalar_constant(max_value, tl.int64) + min_value = lhs_sca_ty.get_int_min_value() + min_value = self.scalar_constant(min_value, tl.int64) + cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value)) + msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" + self.device_assert(cond, msg, None) + + def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, + sanitize_overflow: bool) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + other_handle = other.handle + if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64: + # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive + i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder) + other_handle = self.builder.create_int_cast(other.handle, i64_ty, False) + return self.tensor(self.builder.create_addptr(input.handle, other_handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + if sanitize_overflow: + self.binary_op_sanitize_overflow_impl(input, other, self.add) + return self.tensor(self.builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, + sanitize_overflow: bool) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return self.add(input, self.minus(other), sanitize_overflow=False) + # float - float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + if sanitize_overflow: + self.binary_op_sanitize_overflow_impl(input, other, self.sub) + return self.tensor(self.builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, + sanitize_overflow: bool) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fmul(input.handle, other.handle), input.type) + # int * int + elif scalar_ty.is_int(): + if sanitize_overflow: + self.binary_op_sanitize_overflow_impl(input, other, self.mul) + return self.tensor(self.builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = self.cast(other, input_scalar_ty) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = self.cast(input, other_scalar_ty) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = self.cast(input, tl.float32) + other = self.cast(other, tl.float32) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = self.cast(other, input_scalar_ty) + else: + input = self.cast(input, other_scalar_ty) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type) + + def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = self.integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = self.cast(input, ret_ty) + other = self.cast(other, ret_ty) + if ret_ty.is_int_signed(): + return self.tensor(self.builder.create_sdiv(input.handle, other.handle), input.type) + else: + return self.tensor(self.builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True) + ret = self.builder.create_fdiv(input.handle, other.handle) + return self.tensor(ret, input.type) + + def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type) + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type) + else: + return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + +############## +# other arithmetic ops +############## + + def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan): + x, y = self.binary_op_type_checking_impl(x, y) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return self.tensor(self.builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return self.tensor(self.builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return self.tensor(self.builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return self.tensor(self.builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan): + x, y = self.binary_op_type_checking_impl(x, y) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return self.tensor(self.builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return self.tensor(self.builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return self.tensor(self.builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return self.tensor(self.builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan): + min, max = self.binary_op_type_checking_impl(min, max) + x, min = self.binary_op_type_checking_impl(x, min) + x, max = self.binary_op_type_checking_impl(x, max) + + dtype = x.dtype + if dtype.is_floating(): + return self.tensor(self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + +############## +# bitwise ops +############## + + def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]: + input, other = self.binary_op_type_checking_impl(input, other) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = self.integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = self.cast(input, ret_sca_ty) + if ret_sca_ty != other_sca_ty: + other = self.cast(other, ret_sca_ty) + return input, other + + def and_(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_and(input.handle, other.handle), input.type) + + def or_(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_or(input.handle, other.handle), input.type) + + def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_xor(input.handle, other.handle), input.type) + + def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy: + if not input.type.is_int1(): + input = self.bitcast(input, tl.int1) + if not other.type.is_int1(): + other = self.bitcast(other, tl.int1) + return self.and_(input, other) + + def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy: + if not input.type.is_int1(): + input = self.bitcast(input, tl.int1) + if not other.type.is_int1(): + other = self.bitcast(other, tl.int1) + return self.or_(input, other) + + def not_(self, input: TensorTy): + if not input.type.is_int1(): + input = self.bitcast(input, tl.int1) + return self.invert(input) + + def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type) + + def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type) + + def shl(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type) + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + def plus(self, input: TensorTy) -> TensorTy: + return input + + def minus(self, input: TensorTy) -> TensorTy: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty) + return self.sub(_0, input, True) + + def invert(self, input: TensorTy) -> TensorTy: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty) + return self.xor_(input, _1) + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// + + def _bool_like(self, v: TensorTy) -> tl.block_type: + return v.type.with_element_ty(tl.int1) + + def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOLE(input.handle, other.handle), self._bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSLE(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpULE(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOEQ(input.handle, other.handle), self._bool_like(input)) + # == int + elif scalar_ty.is_int(): + return self.tensor(self.builder.create_icmpEQ(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpUNE(input.handle, other.handle), self._bool_like(input)) + # == int + elif scalar_ty.is_int(): + return self.tensor(self.builder.create_icmpNE(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + if ret_ty is None: + ret_ty = tl.block_type(tl.int32, shape) + ret_ty_ir = ret_ty.to_ir(self.builder) + return self.tensor(self.builder.create_make_range(ret_ty_ir, start, end), ret_ty) + + def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = self.builder.get_null_value(dtype.to_ir(self.builder)) + elif dtype.is_fp8(): + value = self.builder.get_fp32(value) + value = self.builder.create_fp_trunc(value, dtype.to_ir(self.builder)) + else: + get_value_fn = getattr(self.builder, f"get_{dtype.name}") + value = get_value_fn(value) + return self.tensor(value, dtype) + + def make_scalar(self, value, dtype: tl.dtype) -> TensorTy: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + return self.cast(value, dtype) + # scalar + return self.scalar_constant(value, dtype) + + def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy: + return self.splat(self.make_scalar(value, dtype), shape) + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + def splat(self, value: TensorTy, shape: List[int]) -> TensorTy: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty) + + def unsplat(self, value: TensorTy) -> TensorTy: + return self.tensor(self.builder.create_unsplat(value.handle), value.dtype) + + def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return self.tensor(self.builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: + dst_shape = [tl._unwrap_if_constexpr(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return self.splat(input, shape=dst_shape) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty) + + def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type) + + def join(self, a: TensorTy, b: TensorTy) -> TensorTy: + a, b = self.broadcast_impl_value(a, b) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = self.expand_dims(a, 0) + b = self.expand_dims(b, 0) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = self.tensor(self.builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = self.reshape(ret, [2], can_reorder=False) + + return ret + + def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]: + assert (len(a.shape) > 0) + assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = self.builder.create_split(a.handle) + return ( + self.tensor(outLHS, ret_type), + self.tensor(outRHS, ret_type), + ) + + def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy: + if len(input.shape) != len(dims): + raise ValueError( + f"permute dims must have the same length as input shape, got {len(input.shape)} and {len(dims)}") + if sorted(tl._unwrap_if_constexpr(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return self.tensor(self.builder.create_trans(input.handle, dims), ret_type) + + def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy: + if not input.type.is_block(): + return self.splat(input, shape) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return self.tensor(self.builder.create_broadcast(input.handle, shape), ret_ty) + + def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar) + rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar) + lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = self.tensor(self.builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = self.tensor(self.builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = self.tensor(self.builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = self.tensor(self.builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + +####### +# cast +####### + + def _str_to_rounding_mode(self, rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy: + src_ty = input.type + if src_ty.is_block(): + dst_ty = src_ty.with_element_ty(dst_ty.scalar) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return self.cast(input, dst_ty) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy: + src_ty = input.type + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty == dst_sca_ty: + return input + if src_ty.is_block(): + dst_ty = src_ty.with_element_ty(dst_sca_ty) + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = self._str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + + str(dst_sca_ty)) + + custom_fp8_dtypes = set(getattr(self.builder.options, "custom_fp8_dtypes", ())) + if self.builder.codegen_fns.get("convert_custom_types") is not None: + # Backwards compatibility for backends that only route fp8e4b15 + # through custom casts and have not populated options. + if src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15(): + custom_fp8_dtypes.add("fp8e4b15") + + if str(src_sca_ty) in custom_fp8_dtypes or str(dst_sca_ty) in custom_fp8_dtypes: + assert self.builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return self.builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _semantic=self) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return self.tensor( + self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return self.cast(self.cast(input, tl.float32), dst_sca_ty) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(self.builder) + _0 = self.tensor(self.builder.get_null_value(ty), input.dtype) + return self.not_equal(input, _0) + else: + return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend), + dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(self.builder) + _0 = self.tensor(self.builder.get_null_value(ty), input.dtype) + return self.not_equal(input, _0) + elif dst_sca_ty.is_int_signed(): + return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + else: + return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + else: + return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + if bitwidth == 1: + return self.not_equal(self.cast(input, tl.int64), self.tensor(self.builder.get_int64(0), tl.int64)) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + def _str_to_load_cache_modifier(self, cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cv": + cache = ir.CACHE_MODIFIER.CV + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + def _str_to_store_cache_modifier(self, cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + def _str_to_eviction_policy(self, eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + def _str_to_padding_option(self, padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + def _str_to_sem(self, sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + def _str_to_scope(self, scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + def _canonicalize_boundary_check(self, boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = self._canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return self.tensor( + self.builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), + dst_ty) + + def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + ptr, mask = self.broadcast_impl_value(ptr, mask) + if other is not None: + ptr, other = self.broadcast_impl_value(ptr, other) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + is_bool = elt_ty == tl.int1 + if is_bool: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = self.cast(ptr, ptr_ty) + + # Cast `other` into `elt_ty` type + if other is not None: + other = self.cast(other, elt_ty) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + dst_ty = ptr.type.with_element_ty(elt_ty) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + ret = self.tensor( + self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, + eviction, is_volatile), dst_ty) + if is_bool: + ret = self.cast(ret, tl.int1) + return ret + + def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy: + # Cache, eviction and padding options + cache = self._str_to_load_cache_modifier(cache_modifier) + eviction = self._str_to_eviction_policy(eviction_policy) + padding = self._str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return self._load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile) + + def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str, + eviction_policy: str) -> TensorTy: + assert isinstance(desc, tl.tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + + offsets = self._convert_to_ir_values(offsets, require_i64=False) + x = self.builder.create_descriptor_load(desc.handle, offsets, self._str_to_load_cache_modifier(cache_modifier), + self._str_to_eviction_policy(eviction_policy)) + return self.tensor(x, desc.block_type) + + def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None: + assert isinstance(desc, tl.tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + assert value.shape == desc.block_shape + + def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + # implicitly cast to the descriptor's type + value = self.cast(value, desc.dtype) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.ADD + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def _has_native_tma(self, ): + target = driver.active.get_current_target() + return (target.backend == "cuda" and target.arch >= 90) + + def _descriptor_atomic_min_max_supported(self, dtype): + assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype" + if dtype in {tl.float16, tl.bfloat16}: + assert self._has_native_tma(), "16-bit float types require native tma support" + + def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + self._descriptor_atomic_min_max_supported(desc.dtype) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.MIN + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + self._descriptor_atomic_min_max_supported(desc.dtype) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.MAX + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.AND + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.OR + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.XOR + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy: + assert isinstance(desc, tl.tensor_descriptor_base) + assert cache_modifier == "", "cache modifier is not supported yet" + assert eviction_policy == "", "eviction policy is not supported yet" + + # Validate descriptor. + assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" + assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" + + # Validate offsets. + assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}" + + # Validate minimum block size. + assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}" + dtype = desc.dtype + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert desc.block_shape[ + 1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" + + type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]]) + y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0] + x = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(self.builder)) + return self.tensor(x, type) + + def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy: + assert isinstance(desc, tl.tensor_descriptor_base) + + # Validate descriptor. + assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" + assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" + + # Validate offsets. + assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}" + + # Validate minimum block size. + assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}" + dtype = desc.dtype + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert desc.block_shape[ + 1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" + + y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0] + self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset) + return self.tensor(None, tl.void) + + def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = self.broadcast_impl_shape(val, block_shape) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = self.cast(val, elt_ty) + + # Build IR + return self.tensor( + self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void) + + def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + ptr_shape = ptr.shape + if mask is None: + ptr, val = self.broadcast_tensors(ptr, val) + else: + ptr, val, mask = self.broadcast_tensors(ptr, val, mask) + if ptr_shape != ptr.shape: + raise ValueError(f"Expected pointer argument to have shape {ptr.shape} but got {ptr_shape}") + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = self.cast(ptr, ptr_ty) + + # Cast to target data type + val = self.cast(val, elt_ty) + + # Build IR + if mask is None: + return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return self.tensor(self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), + tl.void) + + def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str, + eviction_policy: str) -> TensorTy: + # Cache and eviction options + cache = self._str_to_store_cache_modifier(cache_modifier) + eviction = self._str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return self._store_legacy(ptr, val, mask, boundary_check, cache, eviction) + +######### +# atomic +######### + + def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy: + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return self.tensor(self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, + op: str) -> Tuple[TensorTy, TensorTy, TensorTy]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty is tl.bfloat16 and op != 'add': + raise ValueError("atomic_" + op + " does not support bf16") + if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes()) + if val is not None: + val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes()) + val = self.cast(val, ptr.type.scalar.element_ty) + if mask is None: + mask_ir = self.builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ty = ptr.type.with_element_ty(tl.int1) + mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir) + mask = self.tensor(mask_ir, mask_ty) + return ptr, val, mask + + def _signbit(self, x: TensorTy) -> TensorTy: + bitwidth = x.dtype.primitive_bitwidth + idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False) + ix = self.bitcast(x, idtype) + signbit = self.lshr(ix, bitwidth - 1) + return self.cast(signbit, tl.int1) + + def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + else: + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = self.bitcast(val, i_type) + i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1)) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = self.bitcast(val, ui_type) + ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1)) + neg = self._signbit(val) + pos = self.not_(neg) + pos_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + self.and_(mask, pos).handle, sem, scope), i_val.type) + neg_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + self.and_(mask, neg).handle, sem, scope), ui_val.type) + ret = self.where(pos, pos_ret, neg_ret) + return self.bitcast(ret, sca_ty) + + def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + else: + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = self.bitcast(val, i_type) + i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1)) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = self.bitcast(val, ui_type) + ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1)) + neg = self._signbit(val) + pos = self.not_(neg) + pos_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + self.and_(mask, pos).handle, sem, scope), i_val.type) + neg_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + self.and_(mask, neg).handle, sem, scope), ui_ptr.type) + ret = self.where(pos, pos_ret, neg_ret) + return self.bitcast(ret, sca_ty) + + def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + def _assert_dot_dtypes_valid(self, lhs_dtype: tl.dtype, rhs_dtype: tl.dtype) -> None: + supported_fp8 = set(getattr(self.builder.options, "supported_fp8_dtypes", ())) + deprecated_fp8 = set(getattr(self.builder.options, "deprecated_fp8_dot_operand_dtypes", ())) + + if lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + assert lhs_dtype.is_fp8() and rhs_dtype.is_fp8(), ( + f"FP8 operands must appear in pairs. Got {lhs_dtype} and {rhs_dtype}") + for dtype in (lhs_dtype, rhs_dtype): + dtype_name = str(dtype) + if supported_fp8 and dtype_name not in supported_fp8 and dtype_name not in deprecated_fp8: + raise ValueError(f"{dtype_name} is not supported in this architecture for dot on this target. " + f"Supported fp8 dtypes: {sorted(supported_fp8)}") + return + + allowed = (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32, tl.float64) + assert lhs_dtype in allowed, f"Unsupported lhs dtype {lhs_dtype}" + assert rhs_dtype in allowed, f"Unsupported rhs dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, (f"Both operands must be same dtype. Got {lhs_dtype} and {rhs_dtype}") + + def _str_to_dot_input_precision(self, input_precision): + assert input_precision.lower() in self.builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + if input_precision == "BF16X3": + input_precision = "BF16x3" + if input_precision == "BF16X6": + input_precision = "BF16x6" + return getattr(ir.INPUT_PRECISION, input_precision) + + def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], + max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy: + assert lhs.type.is_block() and rhs.type.is_block() + + self._assert_dot_dtypes_valid(lhs.dtype, rhs.dtype) + + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + if "fp8e4b15" in self.builder.options.deprecated_fp8_dot_operand_dtypes: + warnings.warn( + "the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release" + ) + # We upcast because there's no fp8e4b15 type in MLIR + lhs = self.cast(lhs, tl.float16) + rhs = self.cast(rhs, tl.float16) + + uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8() + uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16() + if uses_fp8e4b8 or uses_fp8e5b16: + type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16" + if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes: + arch = self.builder.options.arch + warnings.warn( + f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. " + f"Please use OCP fp8 variants on {arch} for performance") + lhs = self.cast(lhs, tl.float16) + rhs = self.cast(rhs, tl.float16) + + if input_precision is None: + input_precision = self.builder.options.default_dot_input_precision + + input_precision = self._str_to_dot_input_precision(input_precision) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert self.builder.codegen_fns.get( + "min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." + min_dot_size = self.builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + _0 = self.builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`" + ) + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = self.builder.get_fp32(0) + ret_scalar_ty = tl.float32 + elif lhs.type.scalar.is_fp64(): + _0 = self.builder.get_fp64(0) + ret_scalar_ty = tl.float64 + else: + _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + K = lhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) + else: + acc_handle = acc.handle + assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = self.builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + else: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: + raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") + + return self.tensor( + self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty) + + def _str_to_fp_type(self, float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + def _bitcast_to_fp_type(self, val: TensorTy, float_format: str): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": + tl.float16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return self.bitcast(val, triton_ty) + + def verify_scaled_shape(self, M, N, K, lhs_scale, rhs_scale): + if lhs_scale is not None: + scale_factor = 16 if lhs_scale.dtype.is_fp8e4nv() else 32 + lhs_scale_shape = lhs_scale.type.shape + assert lhs_scale_shape[-2:] == [ + M, K // scale_factor + ], f"lhs_scale must be a tensor of shape [..., {M}, {K // scale_factor}]. Got {lhs_scale_shape}" + if rhs_scale is not None: + scale_factor = 16 if rhs_scale.dtype.is_fp8e4nv() else 32 + rhs_scale_shape = rhs_scale.type.shape + assert rhs_scale_shape[-2:] == [ + N, K // scale_factor + ], f"rhs_scale must be a tensor of shape [..., {N}, {K // scale_factor}]. Got {rhs_scale_shape}" + + def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy, + rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool, + lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy: + assert lhs.type.is_block() and rhs.type.is_block() + #TODO: validate types. + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value + lhs_format_enum = self._str_to_fp_type(lhs_format) + rhs_format_enum = self._str_to_fp_type(rhs_format) + allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"} + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + lhs = self._bitcast_to_fp_type(lhs, lhs_format) + rhs = self._bitcast_to_fp_type(rhs, rhs_format) + + assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + M, K_LHS = lhs.type.shape[-2:] + K_RHS, N = rhs.type.shape[-2:] + PACKED_A = 2 if lhs_format == "e2m1" else 1 + PACKED_B = 2 if rhs_format == "e2m1" else 1 + PACKED_A_DIM = PACKED_A * K_LHS if lhs_k_pack else K_LHS + PACKED_B_DIM = PACKED_B * K_RHS if rhs_k_pack else K_RHS + assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + B = lhs.type.shape[0] if lhs_rank == 3 else None + K = K_LHS + if not lhs_k_pack: + M = M * PACKED_A + else: + K = K * PACKED_A + if not rhs_k_pack: + N = N * PACKED_B + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = self.builder.get_fp32(0) + if acc is None: + acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) + else: + acc_handle = acc.handle + assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype + rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle + self.verify_scaled_shape(M, N, K, None if lhs_scale_is_none else lhs_scale, + None if rhs_scale_is_none else rhs_scale) + return self.tensor( + self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty) + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy: + if condition.dtype != tl.int1: + warnings.warn( + f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}" + ) + condition = self.cast(condition, tl.int1) + x, y = self.binary_op_type_checking_impl(x, y, True, True) + # x, y are broadcasted + if condition.type.is_block(): + condition, x = self.broadcast_impl_value(condition, x) + x, y = self.broadcast_impl_value(x, y) + else: + condition, _ = self.broadcast_impl_value(condition, x) + ret_ty = x.type + return self.tensor(self.builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + def wrap_tensor(self, x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return self.tensor(x, res_ty) + + def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]: + if axis is None: + inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + assert reduce_op.verify() + + return tuple( + self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn, + reverse: bool) -> Tuple[TensorTy, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + assert scan_op.verify() + + return tuple(self.wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + +# ===----------------------------------------------------------------------=== +# Gather +# ===----------------------------------------------------------------------=== + + def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: + assert index.dtype.is_int(), "index must be an integer tensor" + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = self.builder.create_gather(src.handle, index.handle, axis) + return self.wrap_tensor(gather, src.type.scalar, index.type.shape) + +# ===----------------------------------------------------------------------=== +# Map Elementwise +# ===----------------------------------------------------------------------=== + + def broadcast_tensors(self, *inputs): + if not inputs: + return () + head, *tail = inputs + for i in range(len(tail)): + head, tail[i] = self.broadcast_impl_value(head, tail[i]) + for i in range(len(tail) - 1): + head, tail[i] = self.broadcast_impl_value(head, tail[i]) + return (head, *tail) + + def map_elementwise(self, inputs: Sequence[tl.tensor], result_types: Sequence[tl.dtype], pack: int, + region_builder_fn) -> Tuple[tl.tensor, ...]: + inputs = self.broadcast_tensors(*inputs) + + assert len(inputs) > 0, "map_elementwise must have at least 1 input tensor" + result_types = [inputs[0].type.with_element_ty(ty.scalar) for ty in result_types] + elementwise_op = self.builder.create_map_elementwise( + [t.handle for t in inputs], + [ty.to_ir(self.builder) for ty in result_types], + pack, + ) + region_builder_fn(elementwise_op) + assert elementwise_op.verify() + + return tuple(self.tensor(elementwise_op.get_result(i), ty) for i, ty in enumerate(result_types)) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + if mask is not None: + mask = self.broadcast_impl_shape(mask, input.shape) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + mask = mask.handle + return self.tensor(self.builder.create_histogram(input.handle, num_bins, mask), + tl.block_type(tl.int32, [num_bins])) + + def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + def debug_barrier(self) -> TensorTy: + return self.tensor(self.builder.create_barrier(), tl.void) + + def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + is_signed = [arg.dtype.is_int_signed() for arg in args] + return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void) + + def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy: + if not self.builder.options.debug: + return + if mask is not None: + cond = self.or_(cond, self.not_(mask)) + return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void) + + def assume(self, cond) -> TensorTy: + return self.tensor(self.builder.create_assume(cond.handle), tl.void) + + def _convert_elem_to_ir_value(self, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if isinstance(elem.value, bool): + return self.builder.get_int1(elem.value) + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return self.builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return self.builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(), + elem.dtype.is_int_signed()) + elif elem.dtype == tl.int64 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + def _convert_to_ir_values(self, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [self._convert_elem_to_ir_value(elem, require_i64) for elem in list_like] + return [self._convert_elem_to_ir_value(list_like, require_i64)] + + def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = self._convert_to_ir_values(shape) + strides = self._convert_to_ir_values(strides) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = self.cast(base, tl.pointer_type(tl.int8, base.type.address_space)) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return self.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + def advance(self, base: TensorTy, offsets) -> TensorTy: + # Convert dynamic offsets to IR values + offsets = self._convert_to_ir_values(offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return self.tensor(self.builder.create_advance(base.handle, offsets), base.type) + + def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: List[TensorTy], + block_shape: List[tl.constexpr], padding_option: str = "zero") -> tl.tensor_descriptor: + ndim = len(shape) + if not (1 <= ndim <= 5): + raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + assert isinstance(base.dtype, tl.pointer_type) + elem_size = base.dtype.element_ty.primitive_bitwidth // 8 + contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1]) + if contig_dim_size * elem_size < 16: + raise ValueError( + f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" + ) + + last_stride = tl._unwrap_if_constexpr(strides[-1]) + if last_stride != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") + + shape = [self.make_scalar(x, tl.int32) for x in shape] + strides = [self.make_scalar(tl._unwrap_if_constexpr(x), tl.int64) for x in strides] + + # Check whether `block_shape` is static + block_shape = tl._unwrap_shape(block_shape) + + assert isinstance(base.type, tl.pointer_type) + type = tl.block_type(base.type.element_ty, block_shape) + base_handle = base.handle + is_signed_int = base.type.element_ty.is_int_signed() + + padding = self._str_to_padding_option(padding_option) + + if base.type.element_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer blocks") + + handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], + [s.handle for s in strides], block_shape, is_signed_int, + padding) + return tl.tensor_descriptor(handle, shape, strides, type) diff --git a/third_party/mthreads/python/triton/language/standard.py b/third_party/mthreads/python/triton/language/standard.py new file mode 100644 index 0000000000..490dc75821 --- /dev/null +++ b/third_party/mthreads/python/triton/language/standard.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +from ..runtime.jit import jit, constexpr_function +from . import core +from . import math + +# constexpr utilities + + +@constexpr_function +def _log2(i): + log2 = 0 + n = i + while n > 1: + n >>= 1 + log2 += 1 + return log2 + + +@constexpr_function +def _is_power_of_two(i): + return (i & (i - 1)) == 0 and i != 0 + + +_get_int_dtype = constexpr_function(core.get_int_dtype) + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :type div: Block + """ + return (x + (div - 1)) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, dim=None, keep_dims=False, ieee_rounding=False): + if dim is None: + _dim: core.constexpr = 0 + else: + _dim: core.constexpr = dim + z = x - max(x, _dim, keep_dims=keep_dims) + num = math.exp(z) + den = sum(num, _dim, keep_dims=keep_dims) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x, can_reorder=False): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=can_reorder) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms the indices of a row-major `size_i * size_j` matrix into + the indices of a column-major matrix for each group of `size_g` rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # linear index with respect to the first element in this group + ij = ij % size_gj + # new row and column indices + new_i = off_i + ij % size_g + new_j = ij // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Returns a tensor of zeros with the same shape and type as a given tensor. + + :param input: input tensor + :type input: Tensor + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@constexpr_function +def _pick_sum_dtype(in_dtype, dtype): + if dtype is not None: + return dtype + + # For integer bitwidths less than 32, pick int32 with the same sign to + # avoid overflow. + out_dtype = None + if in_dtype.is_int_signed(): + out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None + elif in_dtype.is_int_unsigned(): + out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None + return out_dtype + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum", dtype_arg="dtype") +def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None): + # Pick a default dtype for the reduction if one was not specified. + out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype) + + if out_dtype is not None: + input = input.to(out_dtype) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False): + core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers") + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims) + + +# or reduction + + +@jit +def _or_combine(x, y): + return x | y + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("reduce_or") +def reduce_or(input, axis, keep_dims=False): + core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers") + return core.reduce(input, axis, _or_combine, keep_dims=keep_dims) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum", dtype_arg="dtype") +def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None): + # todo rename this to a generic function name + + input = core._promote_bfloat16_to_float32(input) + out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype) + + if out_dtype is not None: + input = input.to(out_dtype) + + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _indicator(n_dims: core.constexpr, j: core.constexpr): + ar = core.arange(0, 2) + ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j) + return ar + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr): + # compare-and-swap on the ith *innermost* dimension + n_dims: core.constexpr = _log2(x.numel) + + # flip along middle dimension (the bitwise XORs will be optimised away): + idtype = _get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ix = x.to(idtype, bitcast=True) + iy = ix ^ xor_sum(ix, n_dims - 1 - i, True) + y = iy.to(x.dtype, bitcast=True) + + # determines whether we are in the right (rather than left) position along the axis: + is_right = _indicator(n_dims, i) + + # conditional swap: + ret = core.where((x > y) != (flip ^ is_right), y, x) + return ret + + +@jit +def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + flip = _indicator(_log2(x.numel), stage) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, stage - 1 - i) + return x + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + h = core.reshape(x, [2] * _log2(x.numel)) + h = _bitonic_merge_hypercube(h, stage, order) + x = core.reshape(h, x.shape) + return x + + +@jit +def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + """ + Sorts a tensor along a specified dimension. + + :param x: The input tensor to be sorted. + :type x: Tensor + :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported. + :type dim: int, optional + :param k: the number of top elements to select. If none, assume k = x.shape[dim] + :type k: int, optional + :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order. + :type descending: bool, optional + """ + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + + log_n: core.constexpr = _log2(x.shape[_dim]) + log_k: core.constexpr = log_n if k is None else _log2(k) + + n_dims: core.constexpr = _log2(x.numel) + + # reshape to hypercube: + h = core.reshape(x, [2] * n_dims if n_dims else [1]) + + # run first log_k bitonic sort iterations: + for i in core.static_range(1, log_k + 1): + h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending) + + # select top k elements using bitonic top-k + # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf + for i in core.static_range(log_k + 1, log_n + 1): + h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k)) + h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending) + + # reshape back: + x = core.reshape(h, x.shape[:-1] + [2**log_k]) + return x + + +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + return sort_impl(x, dim=dim, descending=descending) + + +@jit +def topk(x, k: core.constexpr, dim: core.constexpr = None): + return sort_impl(x, k=k, dim=dim, descending=True) + + +@jit +def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + n_dims: core.constexpr = _log2(x.shape[-1]) + return _bitonic_merge(x, n_dims, descending, n_dims) + + +@constexpr_function +def _get_flip_dim(dim, shape): + if dim is None: + dim = len(shape) - 1 + if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index + dim += len(shape) + return dim + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along + :type dim: int + """ + core.static_assert(-len(x.shape) <= dim and dim < len(x.shape)) + _dim: core.constexpr = _get_flip_dim(dim, x.shape) + core.static_assert(_is_power_of_two(x.shape[_dim])) + steps: core.constexpr = _log2(x.shape[_dim]) + + # reshape the swap dimension to (2, 2, ..., 2) + idtype = _get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:]) + for i in core.static_range(steps): + y = y ^ xor_sum(y, _dim + i, True) + x = core.reshape(y, x.shape).to(x.dtype, bitcast=True) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape. + Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])` + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + c = core.join(a, b) + + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) + + +@jit +def squeeze(x, dim: core.constexpr): + core.static_assert(x.shape[dim] == 1) + return x.reshape(x.shape[:dim] + x.shape[dim + 1:]) + + +@jit +def unsqueeze(x, dim: core.constexpr): + return x.reshape(x.shape[:dim] + (1, ) + x.shape[dim:]) diff --git a/third_party/mthreads/python/triton/language/target_info.py b/third_party/mthreads/python/triton/language/target_info.py new file mode 100644 index 0000000000..2c1a277f04 --- /dev/null +++ b/third_party/mthreads/python/triton/language/target_info.py @@ -0,0 +1,54 @@ +from triton.runtime import driver +from triton.runtime.jit import constexpr_function + +__all__ = ["current_target"] + + +def current_target(): + try: + active_driver = driver.active + except RuntimeError: + # If there is no active driver, return None + return None + return active_driver.get_current_target() + + +current_target.__triton_builtin__ = True + + +@constexpr_function +def is_cuda(): + target = current_target() + return target is not None and target.backend == "cuda" + + +@constexpr_function +def cuda_capability_geq(major, minor=0): + """ + Determines whether we have compute capability >= (major, minor) and + returns this as a constexpr boolean. This can be used for guarding + inline asm implementations that require a certain compute capability. + """ + target = current_target() + if target is None or target.backend != "cuda": + return False + assert isinstance(target.arch, int) + return target.arch >= major * 10 + minor + + +@constexpr_function +def is_hip(): + target = current_target() + return target is not None and target.backend == "hip" + + +@constexpr_function +def is_hip_cdna3(): + target = current_target() + return target is not None and target.arch == "gfx942" + + +@constexpr_function +def is_hip_cdna4(): + target = current_target() + return target is not None and target.arch == "gfx950" diff --git a/third_party/mthreads/python/triton/runtime/__init__.py b/third_party/mthreads/python/triton/runtime/__init__.py new file mode 100644 index 0000000000..0b3979d28d --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/mthreads/python/triton/runtime/_allocation.py b/third_party/mthreads/python/triton/runtime/_allocation.py new file mode 100644 index 0000000000..f3ef7d56c4 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/_allocation.py @@ -0,0 +1,64 @@ +from typing import Optional, Protocol +from contextvars import ContextVar + + +class Buffer(Protocol): + + def data_ptr(self) -> int: + ... + + +class Allocator(Protocol): + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + ... + + +class NullAllocator: + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " + + "Use triton.set_allocator to specify an allocator.") + + +_NULL_ALLOCATOR = NullAllocator() + +_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=_NULL_ALLOCATOR) + + +def set_allocator(allocator: Allocator) -> None: + """ + The allocator function is called during kernel launch for kernels that + require additional global memory workspace. + """ + _allocator.set(allocator) + + +class _AllocatorWrapper: + """ + Wrapper to provide ContextVar-like .get()/.set() methods. profile_allocator is + used in same way as allocator so it is useful to maintain the interface. + """ + + def __init__(self, allocator: Allocator) -> None: + self._allocator = allocator + + def get(self) -> Allocator: + return self._allocator + + def set(self, allocator: Allocator) -> None: + self._allocator = allocator + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + return self._allocator(size, alignment, stream) + + +_profile_allocator = _AllocatorWrapper(_NULL_ALLOCATOR) + + +def set_profile_allocator(allocator: Optional[Allocator]) -> None: + """ + The profile allocator function is called before kernel launch for kernels + that require additional global memory workspace. + """ + _profile_allocator.set(allocator if allocator is not None else _NULL_ALLOCATOR) diff --git a/third_party/mthreads/python/triton/runtime/_async_compile.py b/third_party/mthreads/python/triton/runtime/_async_compile.py new file mode 100644 index 0000000000..69af447424 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/_async_compile.py @@ -0,0 +1,67 @@ +from __future__ import annotations +from typing import Callable, Optional +from concurrent.futures import Executor, as_completed, Future +from contextvars import ContextVar + +active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None) + + +class FutureKernel: + + def __init__(self, finalize_compile: Callable, future: Future): + self.finalize_compile = finalize_compile + self.kernel = None + self.future = future + + def result(self, ignore_errors: bool = False): + if self.kernel is not None: + return self.kernel + + try: + kernel = self.future.result() + except Exception: + if ignore_errors: + return + else: + raise + self.finalize_compile(kernel) + self.kernel = kernel + return kernel + + def __getattr__(self, name): + # Defer to the compiled kernel so users can interact with this object + # like a normal CompiledKernel without needing to call result() first. + return getattr(self.result(), name) + + +class AsyncCompileMode: + + def __init__(self, executor: Executor, *, ignore_errors=False): + self.executor = executor + self.ignore_errors = ignore_errors + self.raw_futures = [] + self.future_kernels = {} + + def submit(self, key, compile_fn, finalize_fn): + future = self.future_kernels.get(key) + if future is not None: + return future + + future = self.executor.submit(compile_fn) + future._key = key + self.raw_futures.append(future) + future_kernel = FutureKernel(finalize_fn, future) + self.future_kernels[key] = future_kernel + return future_kernel + + def __enter__(self): + if active_mode.get() is not None: + raise RuntimeError("Another AsyncCompileMode is already active") + active_mode.set(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + # Finalize any outstanding compiles + for future in as_completed(self.raw_futures): + self.future_kernels[future._key].result(self.ignore_errors) + active_mode.set(None) diff --git a/third_party/mthreads/python/triton/runtime/autotuner.py b/third_party/mthreads/python/triton/runtime/autotuner.py new file mode 100644 index 0000000000..883bdcf2de --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/autotuner.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import builtins +import time +import inspect +import hashlib +import json +from functools import cached_property +from typing import Dict, Tuple, List, Optional + +from .. import knobs +from .jit import KernelInterface, JITFunction +from .errors import OutOfResources, PTXASError, AutotunerError +from .driver import driver +from .cache import get_cache_manager, triton_key +from triton._C.libtriton import get_cache_invalidating_env_vars + + +class Autotuner(KernelInterface): + + def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None, + prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None, + cache_results=False): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune': a function used to prune configs. It should have the signature + `prune_configs_by( configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Dict[str, Any]) -> List[triton.Config]:` + and return pruned configs. It should return at least one config. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)] + else: + self.configs = configs + self.keys = key + self.cache: Dict[Tuple, Config] = {} + self.arg_names = arg_names + self.cache_results = (cache_results or knobs.autotuning.cache) and not knobs.runtime.interpret + + # Reset to zero or restore values + self.reset_to_zero = [] + if reset_to_zero is not None: + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + # Hook to reset or restore for required tensors + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False + if pre_hook: + self.pre_hook = pre_hook + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): + + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() + if not reset_only: + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + + self._do_bench = do_bench + self.num_warmups = warmup + self.num_reps = rep + try: + import torch + cuda_graph_available = bool(torch.cuda.is_available()) + except Exception: + cuda_graph_available = False + self.use_cuda_graph = use_cuda_graph and cuda_graph_available + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. + if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph and cuda_graph_available: + from ..testing import do_bench_cudagraph + self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + @cached_property + def do_bench(self): + if self._do_bench is None: + return driver.active.get_benchmarker() + return self._do_bench + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + verbose = knobs.autotuning.print + if verbose: + print(f"Autotuning kernel {self.base_fn.__name__} with config {config}") + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: + if verbose: + print(f"Autotuning failed with {e}") + return [float("inf"), float("inf"), float("inf")] + + def check_disk_cache(self, tuning_key, configs, bench_fn): + # We can't serialize prehooks, so just give up and run the benchmarks. + if not tuning_key or any(cfg.pre_hook for cfg in configs): + bench_fn() + return False + + from triton.compiler.compiler import make_backend + + fn = self.fn + while not isinstance(fn, JITFunction): + fn = fn.fn + + env_vars = get_cache_invalidating_env_vars() + cache_key = [ + triton_key(), + make_backend(driver.active.get_current_target()).hash(), + fn.cache_key, + str(sorted(env_vars.items())), + str(tuning_key), + ] + [str(c) for c in configs] + cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest() + cache = get_cache_manager(cache_key) + file_name = f"{fn.__name__[:150]}.autotune.json" + path = cache.get_file(file_name) + if path: + with open(path, "r") as cached_configs: + timings = json.load(cached_configs)["configs_timings"] + timings = {Config(**config): timing for config, timing in timings} + self.cache[tuning_key] = builtins.min(timings, key=timings.get) + self.configs_timings = timings + return True + + bench_fn() + cache.put( + json.dumps({ + "key": + tuning_key, + "configs_timings": + [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook], + }), file_name, binary=False) + return False + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + + def benchmark(): + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + + if self.cache_results: + used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark) + else: + benchmark() + + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if knobs.autotuning.print and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n" + f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};") + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs: Dict) -> List[Config]: + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if not pruned_configs: + raise AutotunerError( + "No valid autotuner configs after pruning. `early_config_prune` should return at least one config.") + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + elif not isinstance(top_k, int): + # Slice index must be an integer + raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int") + + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for autotune_config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **autotune_config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_stages: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type num_ctas: int + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}). + """ + + def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + self.ir_override = ir_override + + def __setstate__(self, state): + self.kwargs = state.get("kwargs", {}) + self.num_warps = state.get("num_warps", 4) + self.num_stages = state.get("num_stages", 3) + self.num_ctas = state.get("num_ctas", 1) + self.maxnreg = state.get("maxnreg", None) + self.pre_hook = state.get("pre_hook", None) + self.ir_override = state.get("ir_override", None) + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ("ir_override", self.ir_override), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + def __hash__(self): + return hash((*self.all_kwargs().items(), self.pre_hook)) + + def __eq__(self, other): + self_tuple = tuple(( + *self.all_kwargs().items(), + self.pre_hook, + )) + other_tuple = tuple(( + *other.all_kwargs().items(), + other.pre_hook, + )) + return self_tuple == other_tuple + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): + ... + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune': a function used to prune configs. It should have the signature + `prune_configs_by( configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Dict[str, Any]) -> List[triton.Config]:` + and return pruned configs. It should return at least one config. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + :param cache_results: whether to cache autotune timings to disk. Defaults to False. + "type cache_results: bool + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + # smallest power-of-two >= x_size + @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])}) + @triton.jit + def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): + ... + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[dict[str, Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/mthreads/python/triton/runtime/build.py b/third_party/mthreads/python/triton/runtime/build.py new file mode 100644 index 0000000000..786f51e54d --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/build.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import functools +import hashlib +import importlib.util +import logging +import os +import shutil +import subprocess +import sysconfig +import tempfile +import re + +from types import ModuleType + +from .cache import get_cache_manager +from .. import knobs + + +def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str], + ccflags: list[str]) -> str: + if impl := knobs.build.impl: + return impl(name, src, srcdir, library_dirs, include_dirs, libraries) + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + cc = os.environ.get("CC") + if cc is None: + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError( + "Failed to find C compiler. Please specify via CC environment variable or set triton.knobs.build.impl.") + scheme = sysconfig.get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + custom_backend_dirs = knobs.build.backend_dirs + include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + cc_cmd += [_library_flag(lib) for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + cc_cmd.extend(ccflags) + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + return so + + +def _library_flag(lib: str) -> str: + # Match .so files with optional version numbers (e.g., .so, .so.1, .so.513.50.1) + if re.search(r'\.so(\.\d+)*$', lib) or lib.endswith(".a"): + return f"-l:{lib}" + return f"-l{lib}" + + +@functools.lru_cache +def platform_key() -> str: + from platform import machine, system, architecture + return ",".join([machine(), system(), *architecture()]) + + +def _load_module_from_path(name: str, path: str) -> ModuleType: + spec = importlib.util.spec_from_file_location(name, path) + if not spec or not spec.loader: + raise RuntimeError(f"Failed to load newly compiled {name} from {path}") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None, + include_dirs: list[str] | None = None, libraries: list[str] | None = None, + ccflags: list[str] | None = None) -> ModuleType: + key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + suffix = sysconfig.get_config_var("EXT_SUFFIX") + cache_path = cache.get_file(f"{name}{suffix}") + + if cache_path is not None: + try: + return _load_module_from_path(name, cache_path) + except (RuntimeError, ImportError): + log = logging.getLogger(__name__) + log.warning(f"Triton cache error: compiled module {name}.so could not be loaded") + + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, name + ".c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or []) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True) + + return _load_module_from_path(name, cache_path) diff --git a/third_party/mthreads/python/triton/runtime/cache.py b/third_party/mthreads/python/triton/runtime/cache.py new file mode 100644 index 0000000000..186567fed5 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/cache.py @@ -0,0 +1,317 @@ +import json +import os +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List, Optional +import base64 +import hashlib +import functools +import sysconfig + +from triton import __version__, knobs + + +class CacheManager(ABC): + + def __init__(self, key, override=False, dump=False): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = knobs.cache.dump_dir + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = knobs.cache.override_dir + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = knobs.cache.dir + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + try: + with open(grp_filepath) as f: + grp_data = json.load(f) + except Exception: + # exit on corrupted cache. + return None + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + os.removedirs(temp_dir) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = knobs.cache.redis.key_format + self._redis = redis.Redis( + host=knobs.cache.redis.host, + port=knobs.cache.redis.port, + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_cls = knobs.cache.remote_manager_class + if not remote_cache_cls: + raise RuntimeError( + "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class") + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + try: + with open(grp_filepath) as f: + grp_data = json.load(f) + except Exception: + # exit on corrupted cache. + return None + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +def _base32(key): + # Assume key is a hex string. + return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + +def get_cache_manager(key) -> CacheManager: + cls = knobs.cache.manager_class or FileCacheManager + return cls(_base32(key)) + + +def get_override_manager(key) -> CacheManager: + cls = knobs.cache.manager_class or FileCacheManager + return cls(_base32(key), override=True) + + +def get_dump_manager(key) -> CacheManager: + cls = knobs.cache.manager_class or FileCacheManager + return cls(_base32(key), dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return _base32(key) + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] + with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def get_cache_key(src, backend, backend_options, env_vars): + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}" + return key diff --git a/third_party/mthreads/python/triton/runtime/driver.py b/third_party/mthreads/python/triton/runtime/driver.py new file mode 100644 index 0000000000..5253176923 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/driver.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import os + +from ..backends import backends, DriverBase + + +def _create_driver() -> DriverBase: + selected = os.environ.get("TRITON_DEFAULT_BACKEND", None) + if selected == "musa": + selected = "mthreads" + if selected: + if selected not in backends: + raise RuntimeError(f"Unknown backend device '{selected}'. Available backends: {list(backends.keys())}") + driver = backends[selected].driver + if not driver.is_active(): + raise RuntimeError(f"Backend device '{selected}' is not active.") + return driver() + else: + active_drivers = [x.driver for x in backends.values() if x.driver.is_active()] + if len(active_drivers) != 1: + raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.") + return active_drivers[0]() + + +class DriverConfig: + + def __init__(self) -> None: + self._default: DriverBase | None = None + self._active: DriverBase | None = None + + @property + def default(self) -> DriverBase: + if self._default is None: + self._default = _create_driver() + return self._default + + @property + def active(self) -> DriverBase: + if self._active is None: + self._active = self.default + return self._active + + def set_active(self, driver: DriverBase) -> None: + self._active = driver + + def reset_active(self) -> None: + self._active = self.default + + +driver = DriverConfig() diff --git a/third_party/mthreads/python/triton/runtime/errors.py b/third_party/mthreads/python/triton/runtime/errors.py new file mode 100644 index 0000000000..d9a1b60bd6 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/errors.py @@ -0,0 +1,46 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) + + +class PTXASError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + error_message = self.error_message or "" + return f"PTXAS error: {error_message}" + + +class AutotunerError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + error_message = self.error_message or "" + return f"Autotuner error: {error_message}" diff --git a/third_party/mthreads/python/triton/runtime/interpreter.py b/third_party/mthreads/python/triton/runtime/interpreter.py new file mode 100644 index 0000000000..a68255bd64 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/interpreter.py @@ -0,0 +1,1483 @@ +from __future__ import annotations +import ast +import textwrap +import inspect +from typing import Tuple, List, Dict, Callable, TypeVar, Optional + +import math +import numpy as np + +import triton +import triton.language as tl +import dataclasses +from dataclasses import dataclass + +from triton.language.semantic import TritonSemantic +from triton.runtime.jit import KernelInterface +from triton.tools.tensor_descriptor import TensorDescriptor +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter # type: ignore +from .._C.libtriton import ir as _ir # type: ignore +from .._utils import _tuple_create + +T = TypeVar("T") + + +@dataclass +class TensorHandle: + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already available in the data field + attr: a dictionary of attributes + ''' + data: np.ndarray + dtype: tl.dtype + attr: Dict = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if not _validate_np_data_size(self.data, self.dtype): + raise ValueError(f"numpy data itemsize ({self.data.itemsize * 8} bits) exceeds dtype primitive_bitwidth " + f"({self.dtype.primitive_bitwidth} bits) for triton type {self.dtype}") + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, block_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.block_shape = block_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + ptrs_data = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) + ptrs_data = ptrs_data + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = masks & (off < self.shape[dim].data) & (off >= 0) + ptrs_handle = TensorHandle(ptrs_data, self.base.dtype.scalar) + return ptrs_handle, masks + + +class TensorDescHandle: + + def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle], + block_shape: List[int], padding): + self.base = base + self.ndim = len(shape) + self.shape = shape + self.strides = strides + self.block_shape = block_shape + self.padding = padding + + def validate(self): + assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned" + assert len(self.strides) == self.ndim + assert len(self.block_shape) == self.ndim + assert self.ndim >= 1, "descriptor cannot be 0 dimensional" + + scalar_ty = self.base.dtype.element_ty + itemsize = scalar_ty.primitive_bitwidth // 8 + for stride in self.strides[:-1]: + byte_stride = stride.data.item() * itemsize + assert byte_stride % 16 == 0, "stride must be 16-byte aligned" + assert self.strides[-1].data.item() == 1, "last dim must be contiguous" + + def materialize_pointers(self, offsets: List[TensorHandle]): + assert len(offsets) == self.ndim + scalar_ty = self.base.dtype.element_ty + itemsize = scalar_ty.primitive_bitwidth // 8 + assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned" + + ptrs_data = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) + ptrs_data = ptrs_data + (itemsize * off * self.strides[dim].data).astype(np.uint64) + masks = masks & (0 <= off) & (off < self.shape[dim].data) + assert ptrs_data.dtype == np.uint64 + ptrs_handle = TensorHandle(ptrs_data, self.base.dtype.scalar) + return ptrs_handle, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: Optional[dict] = None + debug: bool = False + sanitize_overflow: bool = True + arch: Optional[str] = None + supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") + deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = () + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + backend_name: str = "interpreter" + + +def _validate_np_data_size(np_array, tl_dtype): + if isinstance(tl_dtype, tl.pointer_type): + return True + + np_dtype_bitwidth = np_array.itemsize * 8 + tl_dtype_bitwidth = tl_dtype.primitive_bitwidth + + # numpy lowest itemsize is at least 8 bits + if tl_dtype_bitwidth < 8: + tl_dtype_bitwidth = 8 + + if np_dtype_bitwidth > tl_dtype_bitwidth: + return False + return True + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic): + return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1) + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int1_ty(self): + return tl.int1 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + output = op(lhs.data, rhs.data) + tl_dtype = lhs.dtype.scalar + + if not _validate_np_data_size(output, tl_dtype): + output = output.astype(_get_np_dtype(tl_dtype)) + + return TensorHandle(output, tl_dtype) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + create_int_to_ptr = create_bitcast + create_ptr_to_int = create_bitcast + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + output = op(lhs.data, rhs.data, other.data) + tl_dtype = other.dtype.scalar + + if not _validate_np_data_size(output, tl_dtype): + output = output.astype(_get_np_dtype(tl_dtype)) + + return TensorHandle(output, tl_dtype) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, ret_ty, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins, mask): + if mask is None: + mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1) + + # By default np.histogram returns int64 dtype values + # Docs specify that returned dtype is taken based on optional weights.dtype + # This is fix for interpreter cases where for example int32 tensor is being passed + # But unexpectedly int64 values are being returned causing + # tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption + dummy_weights = np.ones_like(data.data, dtype=data.data.dtype) + + # force all masked elements to zero + data = np.where(mask.data, data.data, np.zeros_like(data.data)) + histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0] + # remove overcounted elements + histogram[0] -= np.logical_not(mask.data).sum() + return TensorHandle(histogram, tl.int32) + + def create_gather(self, src, indices, axis): + return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, ret_ty, arg): + shape = ret_ty.shape + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_unsplat(self, arg): + return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values, isSigned): + # NOTE: the `isSigned` variable is not really used here; because Signness is already known + # by `values` themselves in python interpreter, thus not really needed here; + # it is only used for triton PrintOpToLLVM to correctly construct the format specifier. + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message}" + + def create_assume(self, condition): + assert condition, "Assume failed" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle], + tensor_shape: List[int], is_signed: bool, padding: str = "zero"): + desc = TensorDescHandle(base, shape, strides, tensor_shape, padding) + desc.validate() + return desc + + def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier, + eviction_policy): + assert isinstance(desc, TensorDescHandle) + ptrs, mask = desc.materialize_pointers(indices) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + padding = desc.padding + if padding == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding {padding}") + return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier, + eviction_policy=eviction_policy, is_volatile=False) + + def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]): + ptrs, mask = desc.materialize_pointers(indices) + return self.create_masked_store(ptrs, value, mask, None, None) + + def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type): + dtype = desc.base.dtype.element_ty + np_dtype = _get_np_dtype(dtype) + result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype) + cache_modifier = None + eviction_policy = None + for i, x_offset in enumerate(x_offsets.data): + indices = [TensorHandle(x_offset, tl.int32), y_offset] + result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data + return TensorHandle(result, dtype) + + def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle, + y_offset: TensorHandle): + for i, x_offset in enumerate(x_offsets.data): + slice = TensorHandle(value.data[i], value.dtype) + indices = [TensorHandle(x_offset, tl.int32), y_offset] + self.create_descriptor_store(desc, slice, indices) + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + elif np_type == np.bool_: + return TensorHandle(np.full(1, True, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +_MISSING = object() +interpreter_builder = InterpreterBuilder() +interpreter_semantic: TritonSemantic = TritonSemantic(interpreter_builder) + + +class _LangPatchScope: + """Tracks patched attributes so they can be restored.""" + + def __init__(self) -> None: + self._changes: list[tuple[object, str, object]] = [] + + def set_attr(self, obj: object, name: str, value: object) -> None: + original = getattr(obj, name, _MISSING) + self._changes.append((obj, name, original)) + setattr(obj, name, value) + + def restore(self) -> None: + while self._changes: + obj, name, original = self._changes.pop() + if original is _MISSING: + delattr(obj, name) + else: + setattr(obj, name, original) + + +def _patch_attr(obj, name, member, builder, scope: _LangPatchScope): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_semantic"}, _semantic=interpreter_semantic)) + scope.set_attr(obj, name, new_member) + + +def _patch_builtin(pkg, builder, scope: _LangPatchScope): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder, scope) + + +def _patch_lang_tensor(tensor, scope: _LangPatchScope): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype) + assert self.type.is_block() + block_shape = list(self.type.shape) + block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1] + res_ty = tl.core.block_type(self.dtype, block_shape) + return tl.core.tensor(handle, res_ty) + + scope.set_attr(tensor, "__index__", lambda self: int(self.handle.data.squeeze())) + scope.set_attr(tensor, "__bool__", lambda self: _get_bool(self)) + scope.set_attr(tensor, "__repr__", lambda self: repr(self.handle.data)) + scope.set_attr(tensor, "__str__", lambda self: str(self.handle.data)) + scope.set_attr(tensor, "T", property(_get_transpose)) + + +class ReduceScanOpInterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + np_dtype = _get_np_dtype(dtype) + if hasattr(ret, "shape") and ret.shape: + ret = ret.astype(np_dtype) + ret_type = tl.block_type(dtype, list(ret.shape)) + else: + ret = np.array([ret], dtype=np_dtype) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl must be implemented by subclasses") + + def apply(self, input): + if not isinstance(input, tuple): + return self.apply((input, ))[0] + self.check_tensor(input) + ret = self.apply_impl(input) + return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, ) + + +class ReduceOps(ReduceScanOpInterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpInterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return ret + + +def _patch_reduce_scan(scope: _LangPatchScope): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + scope.set_attr(tl, "reduce", _new_reduce) + scope.set_attr(tl, "associative_scan", _new_scan) + scope.set_attr(tl.core, "reduce", _new_reduce) + scope.set_attr(tl.core, "associative_scan", _new_scan) + + +def _patch_lang_core(lang, scope: _LangPatchScope): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + scope.set_attr(lang, "range", _new_range) + scope.set_attr(lang, "static_range", _new_range) + scope.set_attr(lang, "static_assert", _new_static_assert) + scope.set_attr(lang, "static_print", print) + scope.set_attr(lang.dtype, "to_ir", _new_to_ir) + scope.set_attr(lang, "multiple_of", partial(_set_attr, name="tt.divisibility")) + scope.set_attr(lang, "max_contiguous", partial(_set_attr, name="tt.contiguity")) + scope.set_attr(lang, "max_constancy", partial(_set_attr, name="tt.constancy")) + + _patch_reduce_scan(scope) + + +def _patch_lang(fn): + scope = _LangPatchScope() + langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]] + assert len(langs) >= 1, "triton.language must be visible from within jit'd function" + for lang in langs: + _patch_builtin(lang, interpreter_builder, scope) + _patch_builtin(lang.tensor, interpreter_builder, scope) + if lang == tl: + _patch_builtin(lang.math, interpreter_builder, scope) + _patch_lang_tensor(lang.tensor, scope) + _patch_lang_core(lang, scope) + _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder, scope) + return scope + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + elif isinstance(arg, tuple): + return _tuple_create(arg, map(_implicit_cvt, arg)) + elif isinstance(arg, TensorDescriptor): + strides = [_implicit_cvt(s) for s in arg.strides] + assert arg.strides[-1] == 1 + strides[-1] = tl.constexpr(1) + return interpreter_semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base), + shape=[_implicit_cvt(s) for s in arg.shape], strides=strides, + block_shape=[tl.constexpr(b) for b in arg.block_shape], + padding_option=arg.padding) + return arg + + +def _unwrap_tensor(t): + if isinstance(t, triton.runtime.jit.TensorWrapper): + return t.base + return t + + +def _rewrap_tensor(t, original_tensor): + if isinstance(original_tensor, triton.runtime.jit.TensorWrapper): + return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype) + return t + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid, pre_run_hooks=[]): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + self.pre_run_hooks = pre_run_hooks + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + storages = {} + + def _to_cpu(arg): + if isinstance(arg, tuple): + return _tuple_create(arg, map(_to_cpu, arg)) + elif isinstance(arg, TensorDescriptor): + return TensorDescriptor( + _to_cpu(arg.base), + arg.shape, + arg.strides, + arg.block_shape, + arg.padding, + ) + elif not hasattr(arg, "data_ptr"): + return arg + + unwrapped_arg = _unwrap_tensor(arg) + if unwrapped_arg.untyped_storage().data_ptr() not in storages: + storage = unwrapped_arg.untyped_storage() + storages[storage.data_ptr()] = storage.cpu() + + storage = storages[unwrapped_arg.untyped_storage().data_ptr()] + cpu_arg = unwrapped_arg.new_empty(0, device='cpu') + cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride()) + cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg) + return cpu_arg + + args_hst = [_to_cpu(arg) for arg in args_dev] + + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + kwargs_hst[key] = _to_cpu(value) + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + storages = {} + + def _from_cpu(arg_dev, arg_hst): + if hasattr(arg_dev, "data_ptr"): + # No need to rewrap because this just modifies internal + arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst) + storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage()) + elif isinstance(arg_dev, tuple): + for (arg_dev, arg_hst) in zip(arg_dev, arg_hst): + _from_cpu(arg_dev, arg_hst) + elif isinstance(arg_dev, TensorDescriptor): + _from_cpu(arg_dev.base, arg_hst.base) + + for arg_dev, arg_hst in zip(args_dev, args_hst): + _from_cpu(arg_dev, arg_hst) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + _from_cpu(kwarg_dev, kwarg_hst) + + for (arg_dev, arg_hst) in storages.values(): + arg_dev.copy_(arg_hst) + + def __call__(self, *args_dev, **kwargs): + # Removes not used reserved keywords from kwargs + # Triton doesn't support keyword-only, variable positional or variable keyword arguments + # It's safe to inspect only positional or keyword arguments (i.e., argspec.args) + argspec = inspect.getfullargspec(self.fn) + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # run pre-run hooks + for hook in self.pre_run_hooks: + hook(*args_hst, **kwargs_hst) + # remaps core language functions to interpreted ones + patch_scope = _patch_lang(self.fn) + try: + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + if triton.knobs.compilation.front_end_debugging: + raise + raise InterpreterError(repr(e)) from e + finally: + patch_scope.restore() + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class ASTTransformer(ast.NodeTransformer): + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise ValueError("Multiple assignments are not supported") + # Modify the assignment x = value to + # interpreter_semantic.to_tensor(value, False) + node.value = ast.Call( + func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor", + ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[]) + return node + + +class FunctionRewriter: + ast_transformer = ASTTransformer() + + def __init__(self, fn, **kwargs): + self.fn = fn + self.kwargs = kwargs + self.filename: str = "" + # Absolute line number in the file + self.def_file_lineno: int = 0 + + def rewrite_ast(self): + # If exception is raise, it means the function does not have source code available, + # e.g., dynamically generated functions, we cannot rewrite it so just return the original function + try: + lines, _ = inspect.getsourcelines(self.fn) + except Exception: + return self.fn + + # truncate lines before def + # @triton.autotune(...) + # ... + # @triton.jit + # ... + # def foo(...): <- this line is the function definition + self.filename, self.def_file_lineno = self._get_jit_fn_file_line() + self.def_lineno = self._find_def(lines) + src = self._prepare_source(lines) + transformed_ast = self._transform_ast(src) + return self._compile_and_exec(transformed_ast) + + def _get_jit_fn_file_line(self): + from .jit import get_jit_fn_file_line, JITFunction + return get_jit_fn_file_line(JITFunction(self.fn)) + + def _find_def(self, lines): + def_lineno = 0 + # Line numbers start from 1 + for i, line in enumerate(lines): + if line.strip().startswith("def "): + def_lineno = i + 1 + return def_lineno + + def _prepare_source(self, lines): + lines = lines[self.def_lineno - 1:] + src = ''.join(lines) + return textwrap.dedent(src) + + def _transform_ast(self, src): + # src is like: + # 1: def foo(...): + # 2: ... + parsed_ast = ast.parse(src) + transformed_ast = self.ast_transformer.visit(parsed_ast) + ast.fix_missing_locations(transformed_ast) + inc_lineno = self.def_file_lineno - 1 + ast.increment_lineno(transformed_ast, inc_lineno) + return transformed_ast + + def _compile_and_exec(self, transformed_ast): + compiled_code = compile(transformed_ast, filename=self.filename, mode='exec') + local_namespace = {**self.kwargs} + fn_globals = self.fn.__globals__ + for key, value in globals().items(): + if key not in fn_globals: + fn_globals[key] = value + exec(compiled_code, fn_globals, local_namespace) + return local_namespace[self.fn.__name__] + + +class InterpretedFunction(KernelInterface[T]): + # Cache all rewritten functions + rewritten_fn: Dict[Callable, Callable] = {} + + def __init__(self, fn, **kwargs) -> None: + self.fn = fn + self.rewriter = FunctionRewriter(fn, **kwargs) + self.kwargs = kwargs + self.pre_run_hooks = [] + + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def run(self, *args, grid, warmup, **kwargs): + if warmup: + return + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid, self.pre_run_hooks)(*args, **kwargs) + + def add_pre_run_hook(self, hook): + assert callable(hook) + self.pre_run_hooks.append(hook) + + def rewrite(self): + if self.fn not in self.rewritten_fn: + self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast() + return self.rewritten_fn[self.fn] + + @property + def __name__(self): + return self.fn.__name__ + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + fn = self.rewrite() + try: + return fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/mthreads/python/triton/runtime/jit.py b/third_party/mthreads/python/triton/runtime/jit.py new file mode 100644 index 0000000000..ba21a58e19 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/jit.py @@ -0,0 +1,1134 @@ +from __future__ import annotations, division +import ast +import copy +import hashlib +import inspect +import itertools +import threading +import re +import textwrap +from collections import defaultdict +from dataclasses import dataclass +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, overload, Dict, Any, Tuple + +from triton.backends import BaseBackend +from types import ModuleType +from .. import knobs +from .driver import driver +from . import _async_compile +from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, is_namedtuple +from .cache import get_cache_key +from triton._C.libtriton import get_cache_invalidating_env_vars, native_specialize_impl, ir + +TRITON_MODULE = "triton.language" +GLUON_MODULE = "triton.experimental.gluon.language" + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, nonlocals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + self.nonlocals = nonlocals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + self.supported_modules = { + GLUON_MODULE, + TRITON_MODULE, + "copy", + "math", + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + assert isinstance(func, JITCallable) + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + + def record_reference(self, val, var_dict=None, name=None): + from ..language.core import constexpr + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if val is None or type(val) is ModuleType: + return + + if getattr(val, "__triton_aggregate__", False): + for attr in val.hash_attrs: + self.record_reference(attr) + return + + if getattr(val, "__triton_builtin__", False): + return + + # Stubs that aren't real functions + if getattr(val, "__module__", "") == "triton.language.extra.libdevice": + return + + if isinstance(val, JITCallable): + self._update_hash(val) + return + + if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr): + raise RuntimeError(f"Unsupported function referenced: {val}") + + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + if self.visiting_arg_default_value: + return + + if var_dict is not None: + self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict) + return + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + def name_lookup(name): + val = self.globals.get(name, None) + if val is not None: + return val, self.globals + val = self.nonlocals.get(name, None) + if val is not None: + return val, self.nonlocals + return None, None + + val, var_dict = name_lookup(node.id) + if node.id in self.supported_python_builtins: + return val + + self.record_reference(val, var_dict, node.id) + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + lhs_name = getattr(lhs, "__name__", "") + if lhs is None or lhs_name in self.supported_modules: + return None + ret = getattr(lhs, node.attr) + self.record_reference(ret) + return ret + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + import triton.language.core as core + if isinstance(ty, str): + ty = ty.strip() + if ty.startswith("const "): + ty = ty.removeprefix("const") + ty = _normalize_ty(ty) + assert ty.startswith("*") + return "*k" + ty[1:] + if ty.endswith("*"): + return "*" + _normalize_ty(ty[:-1]) + if ty.startswith("*"): + return "*" + _normalize_ty(ty[1:]) + if ty.startswith("tl."): + return _normalize_ty(ty.removeprefix("tl.")) + elif isinstance(ty, core.pointer_type): + return f"*{_normalize_ty(ty.element_ty)}" + elif isinstance(ty, core.dtype): + ty = ty.name + elif isinstance(ty, type): + ty = ty.__name__ + else: + ty = str(ty) + return type_canonicalisation_dict.get(ty.replace("_t", ""), ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self) -> str: + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self) -> str: + a = self.annotation + if a.startswith("*k"): + a = a[2:] + elif a.startswith("*"): + a = a[1:] + if a in set(type_canonicalisation_dict.values()): + return self.annotation + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + if self.is_constexpr: + return False + return "const" in self.annotation or self.annotation.startswith("*k") + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def mangle_type(arg, specialize=False): + is_const = False + align = True + return native_specialize_impl(BaseBackend, arg, is_const, specialize, align)[0] + + +class KernelInterface(Generic[T]): + run: T + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def run(self, *args, grid, warmup, **kwargs): + raise NotImplementedError("run not implemented") + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key, target): + constants = { + key: str(value) if value.__class__.__name__ == "dtype" else {"constexpr": value.value} + if value.__class__.__name__ == "constexpr" else {"jit_function": f"{value.module}:{value.fn.__qualname__}"} + if value.__class__.__name__ == "JITFunction" else value + for key, value in constants.items() + } + + import json + obj = { + 'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals': + list(constants.values()), 'attrs_keys': [list(x) for x in attrs.keys()], 'attrs_vals': list(attrs.values()), + 'options': options.__dict__, 'key': key, 'target': target.__dict__ + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams, backend): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + assert len(sig.parameters) == len(kparams) + # Create the function argument list and the dict entries for the return statement + specialization = [] + # signature + for name, kp in zip(sig.parameters.keys(), kparams): + if kp.is_constexpr: + specialization.append(f'("constexpr", {name})') + else: + is_const = 'True' if kp.is_const else 'False' + specialize = 'False' if kp.do_not_specialize else 'True' + align = 'False' if kp.do_not_specialize_on_alignment else 'True' + ret = f"specialize_impl(backend, {name}, {is_const}, {specialize}, {align})" + if kp.annotation_type: + if isinstance(kp.annotation_type, str): + if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]: + # we do not specialize non-constexpr floats and bools: + specialize = False + if specialize: + specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]') + else: + # skip runtime specialization: + specialization.append(f'("{kp.annotation_type}", None)') + else: + specialization.append(f"{ret}") + + # compute argument string for a given parameter + arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}" + func_body = f""" +def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}): + params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}} + specialization = [{','.join(specialization)}] + return params, specialization, options +""" + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + specialize_impl = native_specialize_impl + func_namespace["specialize_impl"] = specialize_impl + func_namespace["backend"] = backend + func_namespace["JITCallable"] = JITCallable + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +def get_full_name(fn): + return f"{fn.__module__}.{fn.__qualname__}" + + +class JITCallable: + + def __init__(self, fn): + self.fn = fn + self.signature = inspect.signature(fn) + try: + self.raw_src, self.starting_line_number = inspect.getsourcelines(fn) + except OSError as e: + raise ValueError("@jit functions should be defined in a Python file") from e + self._fn_name = get_full_name(fn) + self._hash_lock = threading.RLock() + + # function source code (without decorators) + src = textwrap.dedent("".join(self.raw_src)) + src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():] + self._src = src + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__qualname__ = fn.__qualname__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + def get_capture_scope(self): + fn = self.fn + if fn.__closure__ is None: + return self.__globals__ + nonlocals = {name: cell.cell_contents for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)} + return self.__globals__ | nonlocals + + @property + def cache_key(self) -> str: + # TODO : hash should be attribute of `self` + with self._hash_lock: + if self.hash is not None: + return self.hash + # Set a placeholder hash to break recursion in case the function + # transitively calls itself. The full hash is set after. + self.hash = f"recursion:{self._fn_name}" + nonlocals = inspect.getclosurevars(self.fn).nonlocals + dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals, + src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + + from triton.language.core import constexpr + self.hash += str([(name, val) + for (name, _), (val, _) in self.used_global_vals.items() + if isinstance(val, constexpr)]) + self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest() + return self.hash + + def __hash__(self): + return hash(self.cache_key) + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self._src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + @property + def type(self): + from triton.language.core import constexpr_type + return constexpr_type(self) + + def _flatten_ir(self, handles: list[ir.value]) -> None: + pass + + def _unsafe_update_src(self, new_src): + """ + The only method allowed to modify src. + Bypasses the __setattr__ restriction by calling super().__setattr__ directly. + + Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None. + """ + self.hash = None + self._src = new_src + + def _set_src(self): + raise AttributeError("Cannot set attribute 'src' directly. " + "Use '_unsafe_update_src()' and manually clear `.hash` of all callers" + "instead.") + + def _get_src(self): + return self._src + + src = property(fget=_get_src, fset=_set_src) + + +_triton_jit_function_registry = {} + + +@dataclass +class JitFunctionInfo: + module: ModuleType + name: str + jit_function: JITFunction + + +def compute_cache_key(kernel_key_cache, specialization, options): + key = (tuple(specialization), str(options)) + cache_key = kernel_key_cache.get(key, None) + if cache_key is not None: + return cache_key + + # Replace JITCallable objects with their hash, so the cache key will change if the src is updated + def replace_callables(obj): + if isinstance(obj, list): + return [replace_callables(arg) for arg in obj] + elif is_namedtuple(obj): + results = [replace_callables(arg) for arg in obj] + return obj.__class__(*results) + elif isinstance(obj, tuple): + return tuple(replace_callables(arg) for arg in obj) + elif isinstance(obj, JITCallable): + return obj.cache_key + return obj + + cache_key = str(replace_callables(specialization)) + str(options) + kernel_key_cache[key] = cache_key + return cache_key + + +def convert_to_tuple_if_list(item): + # If the incoming item is a list, recursively iterate through it to convert all lists therein into tuples + if not isinstance(item, list): + return item + + # The value must be a list at this point + for i, nested_value in enumerate(item): + item[i] = convert_to_tuple_if_list(nested_value) + + return tuple(item) + + +class JITFunction(JITCallable, KernelInterface[T]): + + def is_gluon(self): + return False + + def _call_hook( + self, + hook, + key, + signature, + target, + device, + constants, + options, + configs, + is_warmup, + ) -> bool | None: + if not hook: + return None + + name = self.fn.__qualname__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})" + full_name = get_full_name(self.fn) + + specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key, + target) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'launch_cooperative_grid': options.launch_cooperative_grid, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + 'is_warmup': is_warmup, + } + + return hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=is_warmup, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + target = driver.active.get_current_target() + backend = make_backend(target) + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + binder = create_function_from_signature(self.signature, self.params, backend) + return {}, {}, target, backend, binder + + def _pack_args(self, backend, kwargs, bound_args, specialization, options): + # options + options = backend.parse_options(kwargs) + # signature + sigkeys = [x.name for x in self.params] + sigvals = [x[0] for x in specialization] + signature = {k: v for (k, v) in zip(sigkeys, sigvals)} + # check arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in kwargs: + if k not in options.__dict__ and k not in sigkeys: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + # constexprs + constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr") + constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs} + # attributes + attrvals = ['' if x[0] == 'constexpr' else x[1] for x in specialization] + attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str)) + attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs} + + return options, signature, constexprs, attrs + + def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug + kwargs["instrumentation_mode"] = knobs.compilation.instrumentation_mode + + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device] + # specialization is list[tuple[str, Any]], where first element of tuple is + # the type and the second parameter is the 'specialization' value. + bound_args, specialization, options = binder(*args, **kwargs) + + # add a cache field to the kernel specializations for kernel specific + # pass pipelines + if knobs.runtime.add_stages_inspection_hook is not None: + inspect_stages_key, inspect_stages_hash = knobs.runtime.add_stages_inspection_hook() + specialization.append(f'("custom_pipeline", {inspect_stages_hash})') + + key = compute_cache_key(kernel_key_cache, specialization, options) + kernel = kernel_cache.get(key, None) + + # Kernel is not cached; we have to compile. + if kernel is None: + options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization, + options) + + kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup) + if kernel is None: + return None + + # Check that used global values have not changed. + not_present = object() + for (name, _), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values()) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values()) + return kernel + + def repr(self, _): + return self._fn_name if self._repr is None else self._repr(_) + + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] + + super().__init__(fn) + self.module = fn.__module__ + self.version = version + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + self._repr = repr + self.launch_metadata = launch_metadata + # Register for simple deserialization of JITFunction constants + _triton_jit_function_registry[f"{self.module}:{self.fn.__qualname__}"] = self + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) + + # cache of just-in-time compiled kernels + self.device_caches = defaultdict(self.create_binder) + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + def preload(self, specialization_data): + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self._fn_name: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}") + constant_keys = map(tuple, deserialized_obj['constant_keys']) + constant_vals = deserialized_obj['constant_vals'] + _, _, target, backend, _ = self.device_caches[device] + deserialized_target = deserialized_obj['target'] + # TODO: we could support loading a kernel signature serialized on a different target however + # currently options are target specific so we would need to change that. + if target.__dict__ != deserialized_target: + raise RuntimeError(f"Specialization data is for {deserialized_target} but trying to preload for {target}") + + def _decode_constant(value): + if tl.dtype.is_dtype(value): + return tl.dtype(value) + if isinstance(value, dict): + if 'constexpr' in value: + return tl.constexpr(value['constexpr']) + if 'jit_function' in value: + jf_key = value['jit_function'] + if jf_key in _triton_jit_function_registry: + return _triton_jit_function_registry[jf_key] + raise RuntimeError(f"Unable to resolve JITFunction {jf_key} for preload") + return value + + constexprs = {key: _decode_constant(value) for key, value in zip(constant_keys, constant_vals)} + attrs_keys = map(tuple, deserialized_obj['attrs_keys']) + attrs_vals = deserialized_obj['attrs_vals'] + attrs = dict(zip(attrs_keys, attrs_vals)) + # JSON serializes tuples as lists, so they need to be converted back; + # This can be done unconditionally, since lists are not accepted in Triton kernel signatures. + signature = {key: convert_to_tuple_if_list(value) for key, value in deserialized_obj['signature'].items()} + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + options = backend.parse_options(options) + return self._do_compile( + key, + signature, + device, + constexprs, + options, + attrs, + warmup=True, + ) + + def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup): + kernel_cache, _, target, backend, _ = self.device_caches[device] + + if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, target, device, constexprs, options, [attrs], + warmup): + return None + src = self.ASTSource(self, signature, constexprs, attrs) + + async_mode = _async_compile.active_mode.get() + if async_mode is not None: + + env_vars = get_cache_invalidating_env_vars() + cache_key = get_cache_key(src, backend, options, env_vars) + + def async_compile(): + return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars) + + def finalize_compile(kernel): + kernel_cache[key] = kernel + self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, target, device, constexprs, + options, [attrs], warmup) + + kernel = async_mode.submit(cache_key, async_compile, finalize_compile) + kernel_cache[key] = kernel + else: + kernel = self.compile(src, target=target, options=options.__dict__) + kernel_cache[key] = kernel + self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, target, device, constexprs, options, + [attrs], warmup) + return kernel + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__qualname__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> KernelInterface[T]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if knobs.runtime.interpret: + from .interpreter import InterpretedFunction + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype, shape=None): + if shape is None: + shape = [1] + self.dtype = dtype + self.shape = shape + + def stride(self): + strides = [1] + for size in self.shape[1:]: + strides.append(strides[-1] * size) + return tuple(reversed(strides)) + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, *args): + return self.base.stride(*args) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + def new_empty(self, sizes): + return TensorWrapper(self.base.new_empty(sizes), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITCallable): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + begin_line = base_fn.starting_line_number + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(base_fn.raw_src): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line + + +class BoundConstexprFunction(JITCallable): + + def __init__(self, instance, fn): + self.__self__ = instance + self.__func__ = fn + + @property + def cache_key(self): + return self.__func__.cache_key + + def __call__(self, *args, **kwargs): + return self.__func__(self.__self__, *args, **kwargs) + + +class ConstexprFunction(JITCallable): + + def __init__(self, fn): + super().__init__(fn) + + def __get__(self, obj, objclass): + # Create a bound function to support constexpr_function methods + if obj is not None: + return BoundConstexprFunction(obj, self) + return self + + def __call__(self, *args, _semantic=None, **kwargs): + from triton.language.core import _unwrap_if_constexpr, constexpr + # de-constexpr arguments and discard the _semantic keyword argument: + args = [_unwrap_if_constexpr(x) for x in args] + kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()} + + # call the raw Python function f: + res = self.fn(*args, **kwargs) + + if _semantic is None: + # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function + return res + + # convert result back to a Triton constexpr: + if knobs.runtime.interpret: + return res # No constexpr in interpreter + return constexpr(res) + + +def constexpr_function(fn): + """ + Wraps an arbitrary Python function so that it can be called at + compile-time on constexpr arguments in a Triton function and + returns a constexpr result. + """ + return ConstexprFunction(fn) diff --git a/third_party/mthreads/python/triton/testing.py b/third_party/mthreads/python/triton/testing.py new file mode 100644 index 0000000000..91493f1bd1 --- /dev/null +++ b/third_party/mthreads/python/triton/testing.py @@ -0,0 +1,571 @@ +import functools +import math +import os +import statistics +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl +from . import runtime +from .backends import backends as _available_backends + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +# pure Python implementation of np.quantile/torch.quantile +# to avoid unnecessary runtime dependency on numpy/torch + + +def _quantile(a, q): + n = len(a) + a = sorted(a) + + def get_quantile(q): + if not (0 <= q <= 1): + raise ValueError("Quantiles must be in the range [0, 1]") + point = q * (n - 1) + lower = math.floor(point) + upper = math.ceil(point) + t = point - lower + return (1 - t) * a[lower] + t * a[upper] + + return [get_quantile(q) for q in q] + + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = _quantile(times, quantiles) + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times + elif return_mode == "min": + return min(times) + elif return_mode == "max": + return max(times) + elif return_mode == "mean": + return statistics.mean(times) + elif return_mode == "median": + return statistics.median(times) + + +_DEVICE_TYPE_TO_BACKEND = { + "cuda": "nvidia", "nvidia": "nvidia", "hip": "amd", "amd": "amd", "musa": "mthreads", "mthreads": "mthreads" +} + + +@functools.lru_cache(maxsize=None) +def _get_backend_driver(backend_name: str): + if backend_name not in _available_backends: + available = ", ".join(sorted(_available_backends.keys())) + raise RuntimeError(f"Unsupported device_type/backend '{backend_name}'. " + f"Available Triton backends: [{available}]") + driver_cls = _available_backends[backend_name].driver + if not driver_cls.is_active(): + raise RuntimeError(f"Backend '{backend_name}' is not active.") + return driver_cls() + + +def _resolve_benchmark_driver(device_type: str | None): + if device_type is None: + return runtime.driver.active + backend_name = _DEVICE_TYPE_TO_BACKEND.get(device_type.lower(), device_type.lower()) + return _get_backend_driver(backend_name) + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". + :type return_mode: str + """ + import torch + assert return_mode in ["min", "max", "mean", "median", "all"] + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + # Rewrite to avoid possible division by 0 issues with fast benchmarks + if estimate_ms == 0: + n_repeat = 1000 + else: + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(ret, quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type=None): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". + :type return_mode: str + :param device_type: Optional device/backend selector for benchmarking (e.g. "cuda", "hip", "musa"). + When omitted, use the current active Triton driver. + :type device_type: str, optional + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + + benchmark_driver = _resolve_benchmark_driver(device_type) + di = benchmark_driver.get_device_interface() + + fn() + di.synchronize() + + cache = benchmark_driver.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + benchmark_driver.clear_cache(cache) + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + benchmark_driver.clear_cache(cache) + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + return _summarize_statistics(times, quantiles, return_mode) + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True, err_msg=err_msg) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean_labels = [f'{x} ({bench.ylabel})' for x in bench.line_names] + y_min_labels = [f'{x}-min ({bench.ylabel})' for x in bench.line_names] + y_max_labels = [f'{x}-max ({bench.ylabel})' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean_labels + y_min_labels + y_max_labels) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, (mean_label, min_label, max_label) in enumerate(zip(y_mean_labels, y_min_labels, y_max_labels)): + y_min, y_max = df[min_label], df[max_label] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[mean_label], label=mean_label, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + y_mean_labels] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + try: + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + finally: + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + with open(os.path.join(save_path, "results.html"), "w") as html: + html.write("\n") + for bench in benchmarks[:len(result_dfs)]: + html.write(f"\n") + html.write("\n") + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + + from .runtime import driver + if device is None: + device = driver.active.get_device_interface().current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/mthreads/python/triton/tools/__init__.py b/third_party/mthreads/python/triton/tools/__init__.py new file mode 100644 index 0000000000..fb4e3a7a82 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/__init__.py @@ -0,0 +1 @@ +from triton._C.libtriton.linear_layout import LinearLayout diff --git a/third_party/mthreads/python/triton/tools/build_extern.py b/third_party/mthreads/python/triton/tools/build_extern.py new file mode 100644 index 0000000000..8f0168d59d --- /dev/null +++ b/third_party/mthreads/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/mthreads/python/triton/tools/compile.py b/third_party/mthreads/python/triton/tools/compile.py new file mode 100644 index 0000000000..03a210a50a --- /dev/null +++ b/third_party/mthreads/python/triton/tools/compile.py @@ -0,0 +1,211 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import triton +import triton.backends + + +@dataclass +class CompileArgs: + ''' + A class to contain arguments from command-line parser. + ''' + path: str = '' + kernel_name: str = '' + signature: str = '' + grid: str = '' + target: str | None = None + num_warps: int = 1 + num_stages: int = 3 + out_name: str | None = None + out_path: Path | None = None + + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + + +def main(): + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument( + "--target", "-t", type=str, default=None, + help="The target to compile towards, in format of '::'; " + "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target") + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + cli_args = parser.parse_args() + args = CompileArgs(**vars(cli_args)) # A sanity check to ensure class CompileArgs is updated as well. + compile_kernel(args) + + +def compile_kernel(args: CompileArgs): + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + for key, value in hints.items(): + if value == 1: + constants[kernel.arg_names[key[0]]] = value + signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} + for key in constants: + signature[key] = 'constexpr' + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{k}={v}" for k, v in constants.items()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16} + kernel.create_binder() + src = kernel.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs) + target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \ + if args.target else triton.runtime.driver.active.get_current_target() + backend = triton.compiler.make_backend(target) + kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages} + options = backend.parse_options(kwargs) + ccinfo = triton.compile(src, target=target, options=options.__dict__) + + if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0: + raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented") + if ccinfo.metadata.profile_scratch_size > 0: + raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented") + + arg_names = [] + arg_types = [] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif hints.get((i, ), None) == 1: + arg_names.append(arg_name) + arg_types.append("i32") + + # dump C stub code + suffix = '' + for i, ty in enumerate(signature.values()): + if hints.get((i, ), None) == 1: + suffix += f'{i}c' + if hints.get((i, ), None) == 16: + suffix += f'{i}d' + func_name = '_'.join([out_name, sig_hash, suffix]) + asm = ccinfo.asm[backend.binary_ext] # store binary data once + + hex_ = str(binascii.hexlify(asm))[2:-1] + + ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type + backend_name = target.backend + + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(asm), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), + "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]), + "num_args": len(arg_names_not_1) + 2, # +2 for global and profile scratch + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": "_".join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + "warp_size": target.warp_size, + "backend_name": backend_name, + } + output_files = [] + template_dir = Path(__file__).parent / "extra" / backend_name + for template_path in template_dir.glob('compile.*'): + ext = template_path.suffix + output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}") + with output_file.open("w") as fp: + fp.write(template_path.read_text().format(**params)) + output_files.append(output_file) + + return func_name, output_files + + +if __name__ == "__main__": + main() diff --git a/third_party/mthreads/python/triton/tools/disasm.py b/third_party/mthreads/python/triton/tools/disasm.py new file mode 100644 index 0000000000..c2301fd2ea --- /dev/null +++ b/third_party/mthreads/python/triton/tools/disasm.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def path_to_cuobjdump(): + from triton import knobs + return knobs.nvidia.cuobjdump.path + + +def extract(file_path, fun): + cuobjdump = path_to_cuobjdump() + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/mthreads/python/triton/tools/experimental_descriptor.py b/third_party/mthreads/python/triton/tools/experimental_descriptor.py new file mode 100644 index 0000000000..0803684c80 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/experimental_descriptor.py @@ -0,0 +1,52 @@ +import torch + +from triton.tools.tensor_descriptor import TensorDescriptor + + +class _RawPointerTensor: + + def __init__(self, ptr: int, dtype: torch.dtype, device: str = "musa"): + self._ptr = int(ptr) + self.dtype = dtype + self.device = torch.device(device) + + def data_ptr(self) -> int: + return self._ptr + + def element_size(self) -> int: + return int(self.dtype.itemsize) + + +def _dtype_from_element_size(element_size: int) -> torch.dtype: + if element_size == 1: + return torch.uint8 + if element_size == 2: + return torch.int16 + if element_size == 4: + return torch.int32 + raise ValueError(f"unsupported descriptor element_size={element_size}") + + +def _contiguous_strides(shape): + strides = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return strides + + +def _create_descriptor(ptr, shape, block_shape, element_size): + dtype = _dtype_from_element_size(int(element_size)) + base = _RawPointerTensor(ptr, dtype=dtype, device="musa") + return TensorDescriptor(base, list(shape), _contiguous_strides(shape), list(block_shape)) + + +def create_1d_tma_descriptor(ptr, dim, block_dim, element_size): + return _create_descriptor(ptr, (dim, ), (block_dim, ), element_size) + + +def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): + return _create_descriptor(ptr, (dim1, dim0), (block_dim1, block_dim0), element_size) + + +def create_3d_tma_descriptor(ptr, dim2, dim1, dim0, block_dim2, block_dim1, block_dim0, element_size): + return _create_descriptor(ptr, (dim2, dim1, dim0), (block_dim2, block_dim1, block_dim0), element_size) diff --git a/third_party/mthreads/python/triton/tools/link.py b/third_party/mthreads/python/triton/tools/link.py new file mode 100644 index 0000000000..9c070160ff --- /dev/null +++ b/third_party/mthreads/python/triton/tools/link.py @@ -0,0 +1,335 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]*)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + # [backend_name] + self.backend_name_re = re.compile("//[\\s]*tt-linker-backend:[\\s]*([\\w]+)") + + self.kernels = defaultdict(list) + self.backend_name = None + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + else: + m = self.backend_name_re.match(ln) + if _exists(m): + backend_name = m.group(1) + if self.backend_name is None: + self.backend_name = backend_name + elif self.backend_name != backend_name: + raise RuntimeError(f"differing backend {self.backend_name} vs. {backend_name}") + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, suffix only includes indexes followed by d or c. + for i in range(len(args)): + pos = 0 + idx_matched = suffix.startswith(str(i)) + if not idx_matched: + continue + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + suffix = suffix[pos:] + + if len(suffix) > 0: + raise Exception(f"Has invalid extra suffix: {suffix}") + sizes.extend([None] * (len(args) - len(sizes))) + + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +TT_ResultTy {name}(TT_StreamTy stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +TT_ResultTy {meta.orig_kernel_name}_default(TT_StreamTy stream, {gen_signature_with_full_args(meta)}); +TT_ResultTy {meta.orig_kernel_name}(TT_StreamTy stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"TT_ResultTy {meta.orig_kernel_name}_default(TT_StreamTy stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"TT_ResultTy {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(TT_StreamTy stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"TT_ResultTy {name}(TT_StreamTy stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"((uintptr_t){val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return TT_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"TT_ResultTy {meta.orig_kernel_name}(TT_StreamTy stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef TT_ResultTy (*kernel_func_t)(TT_StreamTy stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + backend_prelude = (Path(__file__).parent / "extra" / parser.backend_name / "link.h").read_text() + with args.out.with_suffix(".h").open("w") as fp: + out = backend_prelude + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = backend_prelude + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) diff --git a/third_party/mthreads/python/triton/tools/mxfp.py b/third_party/mthreads/python/triton/tools/mxfp.py new file mode 100644 index 0000000000..1b129c1aef --- /dev/null +++ b/third_party/mthreads/python/triton/tools/mxfp.py @@ -0,0 +1,301 @@ +""" +Helper classes for working with low precision floating point types that +align with the opencompute (OCP) microscaling (MX) specification. + * MXFP4Tensor: 4-bit E2M1 floating point data + * MXScaleTensor: 8-bit E8M0 floating point data +Reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +""" + +import torch + + +class MXFP4Tensor: + + def __init__(self, data=None, size=None, device=None): + """ + Tensor class for working with four bit E2M1 floating point data as defined by the + opencompute microscaling specification. + + + Parameters: + - data: A torch tensor of float32 numbers to convert to fp4e2m1 microscaling format. + - size: The size of the tensor to create. + - device: The device on which to create the tensor. + """ + self.device = device + if data is not None: + assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor" + self.device = data.device + self.data = self._from_float(data) + elif size is not None: + self.size = size if isinstance(size, tuple) else (size, ) + else: + raise ValueError("Either parameter data or size must be provided") + + def random(self): + S = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device) + E = torch.randint(0, 4, size=self.size, dtype=torch.uint8, device=self.device) + M = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device) + + self.data = ((S << 3) | (E << 1) | M).type(torch.uint8) + return self + + def to(self, dtype): + """ + Convert fp4e2m1 data to float32. + + Returns: + - A torch tensor of type dtype representing the fp4e2m1 data. + """ + assert dtype == torch.float32, "Currently only float32 is supported for fp4e2m1 to float conversion" + + data = self.data + S = ((data >> 3) & 0x1).type(dtype) + E = ((data >> 1) & 0x3).type(dtype) + M = (data & 0x1).type(dtype) + + # The MXF4 E2M1 spec defines 0bS000 as zero + value = torch.zeros_like(S) + is_zero = (E == 0) & (M == 0) + non_zero_mask = ~is_zero + if non_zero_mask.any(): + S_nz = S[non_zero_mask] + E_nz = E[non_zero_mask] + M_nz = M[non_zero_mask] + + sign = torch.pow(-1, S_nz) + # Normal and subnormal handling for the exponent and mantissa + exponent = torch.where(E_nz == 0, E_nz, E_nz - 1) + mantissa = torch.where(E_nz == 0, M_nz * 0.5, 1.0 + M_nz * 0.5) + value_nz = sign * torch.pow(2, exponent) * mantissa + + value[non_zero_mask] = value_nz + + # For zeros, the values must remain zero with the correct sign + value[is_zero & (S == 1)] *= -1 + return value.type(torch.float32) + + def _from_float(self, values): + """ + Convert float32 numbers to mxf4 e2m1 format. + * No encodings are reserved for Inf or NaN in mxf4. + * Conversion from float supports roundTiesToEven rounding mode. + * If a value exceeds the mxf4 representable range after rounding, + clamps to the maximum mxf4 magnitude, preserving the sign. + * If a value has magnitude less than the minimum subnormal magnitude + in mxf4 after rounding, converts to zero. + + Parameters: + - values: A torch tensor of float32 numbers to convert to fp4 format. + """ + S = torch.signbit(values).type(torch.uint8) + abs_values = torch.abs(values) + + is_zero = (abs_values == 0) + is_invalid = torch.isnan(values) | torch.isinf(values) + + # Enumerate all possible E2M1 exponent and mantissa values. We will + # use these to compare the distance between float32 and all possible + # E2M1 floats to find the nearest E2M1 representable value + E_bits = torch.tensor([0, 1, 2, 3], dtype=torch.uint8, device=self.device) + M_bits = torch.tensor([0, 1], dtype=torch.uint8, device=self.device) + + candidate_values = [] + candidate_E = [] + candidate_M = [] + + for E in E_bits: + if E == 0: + # Subnormals + exponent = 0 + for M in M_bits: + significand = M * 0.5 + value = significand * (2**exponent) + candidate_values.append(value) + candidate_E.append(E) + candidate_M.append(M) + else: + # Normals + exponent = E.item() - 1 + for M in M_bits: + significand = 1.0 + M * 0.5 + value = significand * (2**exponent) + candidate_values.append(value) + candidate_E.append(E) + candidate_M.append(M) + + candidates = torch.tensor(candidate_values, dtype=torch.float32, device=self.device) + candidate_E = torch.tensor(candidate_E, dtype=torch.uint8, device=self.device) + candidate_M = torch.tensor(candidate_M, dtype=torch.uint8, device=self.device) + + abs_values_flat = abs_values.view(-1) + N = abs_values_flat.shape[0] + abs_values_expanded = abs_values_flat.unsqueeze(1) + + # Clamp invalid values to the max e2m1 representable value + max_candidate_value = candidates.max().item() + abs_values_flat[is_invalid.view(-1)] = max_candidate_value + + # Compute distance between all abs_values and candidate e2m1 values + errors = torch.abs(abs_values_expanded - candidates.unsqueeze(0)) + + # To implement roundTiesToEven, we need to break ties by preferring + # even mantissas (M == 0). We do so by adding an epsilon bias to shift + # the closest candidate with an even mantissa closer to the float value + min_errors, _ = torch.min(errors, dim=1, keepdim=True) + is_tie = (errors == min_errors) + # More than one candidate has the min error for some float value + if is_tie.sum() > 1: + M_bits_expanded = candidate_M.unsqueeze(0).expand(N, -1) + tie_breaker = (M_bits_expanded == 0).type(torch.int32) + + errors = errors - (tie_breaker * 1e-6) + + best_indices = torch.argmin(errors, dim=1) + + E_selected = candidate_E[best_indices] + M_selected = candidate_M[best_indices] + E = E_selected.view(abs_values.shape) + M = M_selected.view(abs_values.shape) + + E[is_zero] = 0 + M[is_zero] = 0 + + return ((S << 3) | (E << 1) | M).type(torch.uint8) + + def to_packed_tensor(self, dim): + """ + Packs two e2m1 elements into a single uint8 along the specified dimension. + + Parameters: + - dim: The dimension along which to pack the elements. + + Returns: + - A torch tensor of dtype uint8 with two e2m1 elements packed into one uint8. + """ + data = self.data + assert 0 <= dim < data.ndim, \ + "The dimension to pack along is not within the range of tensor dimensions" + + size_along_dim = data.size(dim) + new_size_along_dim = (size_along_dim + 1) // 2 + + # If the size is odd, we pad the data along dim with zeros at the end + if size_along_dim % 2 != 0: + pad_sizes = [0] * (2 * data.ndim) + pad_index = (data.ndim - dim - 1) * 2 + 1 + pad_sizes[pad_index] = 1 + data = torch.nn.functional.pad(data, pad_sizes, mode='constant', value=0) + + new_shape = list(data.shape) + new_shape[dim] = new_size_along_dim + new_shape.insert(dim + 1, 2) # packed dimension of length 2 + data = data.reshape(*new_shape) + + low = data.select(dim + 1, 0) + high = data.select(dim + 1, 1) + packed = (high << 4) | low + + return packed + + def unpack_packed_tensor(self, packed_tensor, dim, original_shape): + """ + Unpacks a tensor where two fp4 elements are packed into a single uint8. + + Parameters: + - packed_tensor: The packed tensor + - dim: The dimension along which the tensor was packed. + - original_shape: The shape of the original tensor before packing. + + Returns: + - A tensor with the original data unpacked into uint8 elements containing one + fp4e2m1 element in the least significant bits. + """ + high = (packed_tensor >> 4) & 0xF + low = packed_tensor & 0xF + + stacked = torch.stack((low, high), dim=dim + 1) + + # Flatten along dim and dim+1 and then merge + shape = list(stacked.shape) + new_shape = shape[:dim] + [shape[dim] * 2] + shape[dim + 2:] + data = stacked.reshape(*new_shape) + + # Remove any padding + if original_shape[dim] % 2 != 0: + indices = [slice(None)] * data.ndim + indices[dim] = slice(0, original_shape[dim]) + data = data[tuple(indices)] + + return data.type(torch.uint8) + + +class MXScaleTensor: + + def __init__(self, data=None, size=None, device=None): + """ + Tensor class for working with microscaling E8M0 block scale factors. + + Parameters: + - data: A torch tensor of float32 numbers to convert to fp8e8m0 microscaling format. + - size: The size of the tensor to create. + - device: The device on which to create the tensor. + """ + self.device = device + if data is not None: + assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor" + self.device = data.device + self.data = self._from_float(data) + elif size is not None: + self.size = size if isinstance(size, tuple) else (size, ) + else: + raise ValueError("Either parameter data or size must be provided") + + def random(self, low=None, high=None): + """ + Generate random E8M0 data within a specified range. + * Excludes the NaN encoding (255). + """ + bias = 127 + + min_exponent = 0 if low is None else max(0, int(torch.log2(torch.tensor(low))) + bias) + max_exponent = 254 if high is None else min(254, max(0, int(torch.log2(torch.tensor(high))) + bias)) + assert min_exponent <= max_exponent, "Low must be less than or equal to high" + + E = torch.randint(min_exponent, max_exponent + 1, size=self.size, dtype=torch.uint8, device=self.device) + self.data = E + return self + + def to(self, dtype): + assert dtype == torch.float32, "Currently only float32 is supported for f8e8m0 to float conversion" + data = self.data.type(dtype) + is_nan = (data == 255) + e_biased = data.clone() + e_biased[is_nan] = 0 + e = e_biased - 127 + value = torch.pow(2.0, e) + value[is_nan] = torch.nan + return value.type(dtype) + + def _from_float(self, values): + """ + Convert float32 numbers to E8M0 format. + * Values <= 0, NaNs, and Infs are converted to the NaN encoding (255). + * Positive values are converted by computing the floor of log2(value) to get the exponent. + + Parameters: + - values: A torch tensor of float32 numbers to convert to E8M0 format. + """ + result = torch.empty_like(values, dtype=torch.uint8, device=self.device) + + is_invalid = torch.isnan(values) | torch.isinf(values) | (values <= 0) + result[is_invalid] = 255 + + valid_values = values[~is_invalid] + e = torch.floor(torch.log2(valid_values)) + e_biased = e + 127 + e_biased_int = e_biased.type(torch.int32) + e_biased_clamped = torch.clamp(e_biased_int, 0, 254) + result[~is_invalid] = e_biased_clamped.type(torch.uint8) + + return result diff --git a/third_party/mthreads/python/triton/tools/ragged_tma.py b/third_party/mthreads/python/triton/tools/ragged_tma.py new file mode 100644 index 0000000000..c7730e182a --- /dev/null +++ b/third_party/mthreads/python/triton/tools/ragged_tma.py @@ -0,0 +1,108 @@ +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# fmt: off + + +def create_ragged_descriptor(T, block_shape, ragged_dim=0): + """ + Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor' + which behaves like a concatenation (along the first axis) of subarrays + of potentially unequal size. + + The load_ragged and store_ragged device functions can be used to read + and write from subarrays T[slice_off : slice_off + slice_size] + with hardware bounds-checking preventing any sort of leakage outside + the subarray. + """ + + block_shape = list(block_shape) + tensor_shape = list(T.shape) + rank = len(tensor_shape) + + if ragged_dim < 0: + ragged_dim += rank + + assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged" + assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions" + + assert len(block_shape) == rank, "block shape must have same length as tensor shape" + + max_int = 0x7fff0000 + billion = 0x40000000 # == 2**30 + + assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30" + tensor_shape[ragged_dim] = billion + ragged_stride = T.stride(ragged_dim) + + # we prepend an extra two dimensions and rely on the fact that pointers + # have 64-bit wraparound semantics: + tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)] + tma_shape = [max_int, max_int] + tensor_shape + box_shape = [1, 1] + block_shape + + return TensorDescriptor(T, tma_shape, tma_stride, box_shape) + + +@triton.jit +def to_ragged_indices(slice_off, slice_size, row): + """ + Helper function for load_ragged and store_ragged. + """ + + billion = 0x40000000 # == 2**30 + x = billion - slice_size + row + y = slice_off + slice_size + + return billion, y, x + + +@triton.jit +def load_ragged(TMA, slice_off, slice_size, coords, ragged_dim: tl.constexpr = 0): + """ + Read from a subarray T[slice_off : slice_off + slice_size] with + hardware bounds-checking, where reading outside the subarray gives zeros. + + Coords should be an appropriately-sized list of integers, just like in + TMA.load(). + """ + + tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor") + + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) + data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:]) + data = tl.reshape(data, data.shape[2:]) + return data + + +@triton.jit +def store_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0): + """ + Write to a subarray T[slice_off : slice_off + slice_size] with + hardware bounds-checking, where writes outside the subarray are masked + correctly. + + Coords should be an appropriately-sized list of integers, just like in + TMA.store(). + """ + + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) + data = tl.reshape(data, [1, 1] + data.shape) + TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data) + + +@triton.jit +def atomic_add_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0): + """ + Atomic add into a subarray T[slice_off : slice_off + slice_size] with + hardware bounds-checking, where adds outside the subarray are masked + correctly. + + Coords should be an appropriately-sized list of integers, just like in + TMA.atomic_add(). + """ + + c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim]) + data = tl.reshape(data, [1, 1] + data.shape) + TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data) diff --git a/third_party/mthreads/python/triton/tools/tensor_descriptor.py b/third_party/mthreads/python/triton/tools/tensor_descriptor.py new file mode 100644 index 0000000000..21c359aa30 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/tensor_descriptor.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import List, Any +from triton._utils import validate_block_shape + + +@dataclass +class TensorDescriptor: + base: Any + shape: List[int] + strides: List[int] + block_shape: List[int] + padding: str = "zero" + + def __post_init__(self): + rank = len(self.shape) + assert len(self.strides) == rank, f"rank mismatch: {self}" + assert len(self.block_shape) == rank, f"rank mismatch: {self}" + assert rank > 0, "rank must not be zero" + assert rank <= 5, "rank cannot be more than 5" + ty = type(self.base) + if ty.__name__ not in ("FakeTensor", "FunctionalTensor"): + assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned" + validate_block_shape(self.block_shape) + elem_bytes = self.base.dtype.itemsize + for stride in self.strides[:-1]: + assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned" + for shape_dim in self.shape: + assert shape_dim > 0, "shape must be positive" + assert self.strides[-1] == 1, "Last dimension must be contiguous" + assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding" + if self.padding == "nan": + assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors" + + @staticmethod + def from_tensor(tensor: Any, block_shape: List[int], padding="zero"): + return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding) diff --git a/third_party/mthreads/python/triton/tools/triton_to_gluon_translater/translator.py b/third_party/mthreads/python/triton/tools/triton_to_gluon_translater/translator.py new file mode 100644 index 0000000000..0fe9106fb3 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/triton_to_gluon_translater/translator.py @@ -0,0 +1,383 @@ +# Experimental Triton to Gluon AST translator. +# This file takes a Triton JIT entry point and generates a Gluon equivalent including all +# its dependencies. This generates highly inefficient Gluon code and is only used for +# functional testing. +# +import ast +from typing import Optional +import triton +import triton.language.core as tlc +import triton.experimental.gluon.language as ttgl +import sys +import importlib +import importlib.util +import copy + +GLUON_IMPORT_LINES = ("from triton.experimental import gluon\n" + "from triton.experimental.gluon import language as ttgl\n" + "from triton.tools.triton_to_gluon_translater.translator_helpers import *\n") + + +class TritonToGluonTransformer(ast.NodeTransformer): + """Transforms Triton kernel source into a functionally equivalent Gluon source. + + This transformer rewrites builtins, dtype/tensor attributes, constexpr annotations, + and records nested JIT callables to be converted and appended to the output. + """ + + def __init__(self, globals_map: dict, shared_jit_set: set, shared_queue: list, is_jit, constexpr_globals: dict): + super().__init__() + # Resolution scope (globals ∪ nonlocals) + self.scope: dict = globals_map or {} + # Track discovered JIT functions to inline/append later + self.jit_functions: set = shared_jit_set + self.queue: list = shared_queue + self.is_jit = is_jit + # Maps module_file -> {name: value} to pull constexpr globals from the original source code + self.constexpr_globals: dict = constexpr_globals + + def is_triton_constexpr_annotation(self, ann: ast.expr) -> bool: + # Resolve the annotation to a Python object and compare by identity + obj = self.resolve_value(ann) + return obj is tlc.constexpr + + def as_ttgl_constexpr(self) -> ast.expr: + # Build ttgl.constexpr + return self.ttgl_attr("constexpr") + + def maybe_rewrite_constexpr_annotation(self, ann: Optional[ast.expr]) -> Optional[ast.expr]: + if ann is None: + return None + if self.is_triton_constexpr_annotation(ann): + return self.as_ttgl_constexpr() + return ann + + def ttgl_attr(self, name: str) -> ast.AST: + return ast.Attribute(value=ast.Name(id="ttgl", ctx=ast.Load()), attr=name, ctx=ast.Load()) + + def resolve_value(self, expr: ast.expr): + if isinstance(expr, ast.Name): + value = self.scope.get(expr.id) or sys.modules.get(expr.id) + return value + if isinstance(expr, ast.Attribute): + base = self.resolve_value(expr.value) + if base is None: + return None + return getattr(base, expr.attr, None) + return None + + def forward_call(self, node: ast.Call, target_func: ast.expr, filter_keywords: list[str] = []) -> ast.Call: + new_keywords = [kw for kw in node.keywords if kw.arg not in filter_keywords] + return ast.Call(func=target_func, args=list(node.args), keywords=list(new_keywords)) + + def visit_Call(self, node: ast.Call) -> ast.AST: + node = self.generic_visit(node) + resolved_callable = self.resolve_value(node.func) + if resolved_callable is not None: + resolved_callable = triton.language.core._unwrap_if_constexpr(resolved_callable) + base_function = getattr(resolved_callable, "fn", resolved_callable) + function_name = getattr(base_function, "__qualname__", getattr(base_function, "__name__", + str(base_function))) + if triton.language.core.is_builtin(resolved_callable): + builtin_name = function_name.split(".")[-1] + builtin_mapping: dict[str, ast.expr] = { + "arange": ast.Name(id="tl_arange", ctx=ast.Load()), + "full": ast.Name(id="tl_full", ctx=ast.Load()), + "trans": ast.Name(id="tl_trans", ctx=ast.Load()), + "dot": ast.Name(id="tl_dot", ctx=ast.Load()), + "dot_scaled": ast.Name(id="tl_dot_scaled", ctx=ast.Load()), + "make_tensor_descriptor": ast.Name(id="tl_make_tensor_descriptor", ctx=ast.Load()), + "load_tensor_descriptor": ast.Name(id="tl_load_tensor_descriptor", ctx=ast.Load()), + "store_tensor_descriptor": ast.Name(id="tl_store_tensor_descriptor", ctx=ast.Load()), + "num_threads": ast.Name(id="get_num_threads_per_program", ctx=ast.Load()), + } + mapped_target = builtin_mapping.get(builtin_name) + if mapped_target is None and hasattr(ttgl, builtin_name): + mapped_target = self.ttgl_attr(builtin_name) + + filter_keywords = [] + # for reshape drop the can_reorder keyword, it is just an optimization and doesn't help much in Gluon. + if builtin_name == "reshape": + filter_keywords = ["can_reorder"] + if mapped_target is not None: + node = self.forward_call(node, mapped_target, filter_keywords) + # For split, apply on the source argument rather than wrapping destination + if builtin_name == "split": + source_arg = node.args[0] + wrapped_src = ast.Call(func=ast.Name(id="set_split_src_layout", ctx=ast.Load()), + args=[source_arg], keywords=[]) + node.args[0] = ast.copy_location(wrapped_src, source_arg) + # For shape/layout changing ops, wrap to reset layout + if builtin_name in {"reshape", "trans", "permute", "join", "reduce", "split"}: + reset_layout_wrapped = ast.Call(func=ast.Name(id="reset_to_default_layout", ctx=ast.Load()), + args=[node], keywords=[]) + node = ast.copy_location(reset_layout_wrapped, node) + return node + # Track JITFunction callees + if isinstance(resolved_callable, triton.runtime.jit.JITCallable): + if resolved_callable not in self.jit_functions: + self.jit_functions.add(resolved_callable) + self.queue.append(resolved_callable) + # Strip namespace: rewrite to local function name + return self.forward_call(node, ast.Name(id=getattr(base_function, "__name__", ""), ctx=ast.Load())) + if resolved_callable is triton.language.core.range: + # skip all keywords except arg1, arg2, and step and replace with range. + allowed = {"arg1", "arg2", "step"} + new_keywords = [kw for kw in node.keywords if kw.arg in allowed] + new_args = list(node.args[:3]) + return ast.copy_location( + ast.Call(func=ast.Name(id="range", ctx=ast.Load()), args=new_args, keywords=new_keywords), + node, + ) + if resolved_callable is triton.language.core.static_range: + return self.forward_call(node, self.ttgl_attr("static_range")) + else: + if isinstance(node.func, ast.Attribute) and node.func.attr in ["store", "load", "gather", "scatter"]: + helper_name = "tl_obj_" + node.func.attr + return ast.Call( + func=ast.Name(id=helper_name, ctx=ast.Load()), + args=[node.func.value] + list(node.args), + keywords=list(node.keywords), + ) + if isinstance(node.func, + ast.Attribute) and node.func.attr in ["reshape", "trans", "split", "join", "reduce"]: + if node.func.attr == "split": + receiver_expr = node.func.value + wrapped_receiver = ast.Call(func=ast.Name(id="set_split_src_layout", ctx=ast.Load()), + args=[receiver_expr], keywords=[]) + new_func = ast.Attribute(value=ast.copy_location(wrapped_receiver, receiver_expr), + attr=node.func.attr, ctx=ast.Load()) + node = ast.copy_location( + ast.Call(func=new_func, args=list(node.args), keywords=list(node.keywords)), node) + wrapped = ast.Call( + func=ast.Name(id="reset_to_default_layout", ctx=ast.Load()), + args=[node], + keywords=[], + ) + return ast.copy_location(wrapped, node) + return node + + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: + node = self.generic_visit(node) + last_part = node.attr + # Only rewrite dtypes when the resolved object is a tl.dtype instance + # or the tl.dtype class itself (e.g., tl.float16 or tl.dtype.float16 / tl.dtype) + resolved_obj = self.resolve_value(node) + if resolved_obj is not None: + if isinstance(resolved_obj, tlc.dtype): + return self.ttgl_attr(last_part) + if resolved_obj is tlc.dtype and last_part == "dtype": + return self.ttgl_attr("dtype") + if resolved_obj is tlc.tensor and last_part == "tensor": + return self.ttgl_attr("tensor") + if resolved_obj is tlc.constexpr and last_part == "constexpr": + return self.ttgl_attr("constexpr") + if last_part == "tensor_descriptor": + return self.ttgl_attr("nvidia.hopper.tma.tensor_descriptor") + return node + + def visit_Name(self, node): + node = self.generic_visit(node) + resolved_obj = self.resolve_value(node) + if resolved_obj is not None: + # Track standalone references to JITCallable and normalize name + if isinstance(resolved_obj, triton.runtime.jit.JITCallable): + if resolved_obj not in self.jit_functions: + self.jit_functions.add(resolved_obj) + self.queue.append(resolved_obj) + base_function = getattr(resolved_obj, "fn", resolved_obj) + normalized_name = getattr(base_function, "__name__", + getattr(base_function, "__qualname__", getattr(node, "id", ""))) + return ast.copy_location(ast.Name(id=normalized_name, ctx=node.ctx), node) + if isinstance(resolved_obj, triton.language.core.constexpr): + identifier = getattr(node, "id", None) + if identifier is not None: + # Use the current capture scope's file for the defining module + module_file = self.scope.get("__file__") + if isinstance(module_file, str): + bucket = self.constexpr_globals.setdefault(module_file, {}) + bucket[identifier] = resolved_obj + return node + + def visit_Subscript(self, node: ast.Subscript) -> ast.AST: + node = self.generic_visit(node) + # TODO: generalize to + # For patterns like x[None, :] or x[:, None], ensure x has a SliceLayout along the expanded dim + expanded_dim = None + if isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2: + first, second = node.slice.elts + if isinstance(first, ast.Constant) and first.value is None: + expanded_dim = 0 + elif isinstance(second, ast.Constant) and second.value is None: + expanded_dim = 1 + if expanded_dim is not None: + value_expr = node.value + # Construct a 2D parent shape with a dummy dimension of size 1 at the expanded dim + # Use value.type.shape[0] as the vector length + type_attr = ast.Attribute(value=value_expr, attr="type", ctx=ast.Load()) + shape_attr = ast.Attribute(value=type_attr, attr="shape", ctx=ast.Load()) + len_expr = ast.Subscript(value=shape_attr, slice=ast.Constant(value=0), ctx=ast.Load()) + if expanded_dim == 0: + parent_shape = ast.List(elts=[len_expr, ast.Constant(value=1)], ctx=ast.Load()) + else: + parent_shape = ast.List(elts=[ast.Constant(value=1), len_expr], ctx=ast.Load()) + # Build SliceLayout(dim, default_blocked_layout(parent_shape, ttgl.num_warps())) + slice_layout = ast.Call( + func=self.ttgl_attr("SliceLayout"), + args=[ + ast.Constant(value=expanded_dim), + ast.Call( + func=ast.Name(id="default_blocked_layout", ctx=ast.Load()), + args=[parent_shape, + ast.Call(func=self.ttgl_attr("num_warps"), args=[], keywords=[])], + keywords=[], + ), + ], + keywords=[], + ) + converted_value = ast.Call( + func=self.ttgl_attr("convert_layout"), + args=[value_expr, slice_layout], + keywords=[], + ) + return ast.Subscript(value=converted_value, slice=node.slice, ctx=node.ctx) + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + # Rewrite parameter annotations: triton.language.constexpr -> ttgl.constexpr + # Positional-only and regular args + for arg in list(getattr(node.args, "posonlyargs", [])) + list(node.args.args): + arg.annotation = self.maybe_rewrite_constexpr_annotation(arg.annotation) + # Vararg / kwarg + if node.args.vararg is not None: + node.args.vararg.annotation = self.maybe_rewrite_constexpr_annotation(node.args.vararg.annotation) + if node.args.kwarg is not None: + node.args.kwarg.annotation = self.maybe_rewrite_constexpr_annotation(node.args.kwarg.annotation) + # Keyword-only args + for arg in node.args.kwonlyargs: + arg.annotation = self.maybe_rewrite_constexpr_annotation(arg.annotation) + if self.is_jit: + node.decorator_list.insert( + 0, ast.Attribute(value=ast.Name(id="gluon", ctx=ast.Load()), attr="jit", ctx=ast.Load())) + else: + node.decorator_list.insert( + 0, ast.Attribute(value=ast.Name(id="gluon", ctx=ast.Load()), attr="constexpr_function", ctx=ast.Load())) + # Process body + return self.generic_visit(node) + + +def unparse_original_assignments(constexpr_globals: dict) -> list[str]: + """Reconstruct original assignments for captured constexpr globals. + + We parse each defining module once to extract assignments, and rewrite tl.constexpr + calls to ttgl.constexpr so the generated code remains consistent. + """ + + # Build assignment strings for captured globals by parsing each module once. + def collect_names(target_node, names_out): + if isinstance(target_node, ast.Name): + names_out.append(target_node.id) + elif isinstance(target_node, (ast.Tuple, ast.List)): + for element in target_node.elts: + collect_names(element, names_out) + + def parse_assigns_and_imports(path: str) -> tuple[dict[str, ast.AST], dict[str, str]]: + try: + with open(path, "r") as f: + module_ast = ast.parse(f.read()) + except Exception: + return {}, {} + assigns: dict[str, ast.AST] = {} + imports: dict[str, str] = {} + for stmt in getattr(module_ast, "body", []): + if isinstance(stmt, ast.Assign): + names: list[str] = [] + for target in stmt.targets: + collect_names(target, names) + for identifier in names: + assigns[identifier] = stmt + elif isinstance(stmt, ast.AnnAssign): + names: list[str] = [] + collect_names(stmt.target, names) + if stmt.value is not None: + for identifier in names: + assigns[identifier] = stmt + elif isinstance(stmt, ast.ImportFrom) and stmt.level == 0 and isinstance(stmt.module, str): + for alias in stmt.names: + alias_name = alias.asname or alias.name.split(".")[-1] + imports[alias_name] = stmt.module + return assigns, imports + + def rewrite_constexpr_to_ttgl(node: ast.AST) -> ast.AST: + + class ConstexprToTtglRewriter(ast.NodeTransformer): + + def visit_Call(self, call_node: ast.Call) -> ast.AST: + call_node = self.generic_visit(call_node) + if isinstance(call_node.func, ast.Attribute) and call_node.func.attr == "constexpr": + call_node.func = ast.copy_location( + ast.Attribute(value=ast.Name(id="ttgl", ctx=ast.Load()), attr="constexpr", ctx=ast.Load()), + call_node.func) + return call_node + + return ConstexprToTtglRewriter().visit(node) + + results: list[str] = [] + imported_cache: dict[str, dict[str, ast.AST]] = {} + for mod_file, name_to_obj in constexpr_globals.items(): + assigns, imports = parse_assigns_and_imports(mod_file) + for identifier in sorted(name_to_obj.keys()): + node = assigns.get(identifier) + if node is None: + imported_module_name = imports.get(identifier) + if imported_module_name: + try: + module_spec = importlib.util.find_spec(imported_module_name) + origin = getattr(module_spec, "origin", None) if module_spec is not None else None + except Exception: + origin = None + if origin: + assignment_map = imported_cache.get(origin) + if assignment_map is None: + assignment_map, _ = parse_assigns_and_imports(origin) + imported_cache[origin] = assignment_map + node = assignment_map.get(identifier) + if node is not None: + edited_node = rewrite_constexpr_to_ttgl(copy.deepcopy(node)) + ast.fix_missing_locations(edited_node) + results.append(ast.unparse(edited_node)) + else: + results.append(f"{identifier} = {repr(name_to_obj[identifier])}") + return results + + +def convert_triton_to_gluon(src: list[triton.runtime.jit.JITCallable]) -> str: + """Convert a Triton JIT entry point into a Gluon source string.""" + shared_jit_set: set = set() + function_queue: list = list(src) + constexpr_globals: dict = {} + out = "" + # Process discovered callee JITFunctions, converting and appending them + while function_queue: + callee = function_queue.pop(0) + callee_src = callee._src + callee_tree = ast.parse(callee_src) + callee_scope = getattr(callee, "__globals__", {}) or {} + jit = isinstance(callee, triton.runtime.JITFunction) + callee_transformer = TritonToGluonTransformer(globals_map=callee_scope, shared_jit_set=shared_jit_set, + shared_queue=function_queue, is_jit=jit, + constexpr_globals=constexpr_globals) + callee_new = callee_transformer.visit(callee_tree) + ast.fix_missing_locations(callee_new) + out += "\n\n" + ast.unparse(callee_new) + + out = "\n\n" + out + + # Pull constexpr globals from the original source code + for line in unparse_original_assignments(constexpr_globals): + out = line + "\n" + out + + # Prepend required Gluon imports + out = GLUON_IMPORT_LINES + "\n\n" + out + + return out diff --git a/third_party/mthreads/python/triton/tools/triton_to_gluon_translater/translator_helpers.py b/third_party/mthreads/python/triton/tools/triton_to_gluon_translater/translator_helpers.py new file mode 100644 index 0000000000..2b946ee3bf --- /dev/null +++ b/third_party/mthreads/python/triton/tools/triton_to_gluon_translater/translator_helpers.py @@ -0,0 +1,618 @@ +from triton.experimental import gluon +from triton.experimental.gluon import language as ttgl +from triton.experimental.gluon.language.nvidia.hopper import mbarrier +from triton.experimental.gluon.language.nvidia.blackwell import ( + TensorMemoryLayout, + TensorMemoryScalesLayout, + allocate_tensor_memory, + get_tmem_reg_layout, + tcgen05_mma, + tcgen05_mma_scaled, + tcgen05_commit, +) +from triton.experimental.gluon.language.nvidia.ampere import mma_v2 +from triton.experimental.gluon.language.nvidia.hopper import tma, fence_async_shared +from triton.experimental.gluon.language.nvidia.blackwell import tma as tma_blackwell + + +@gluon.constexpr_function +def tl_dot_mma_sync_layout(shape, num_warps): + rank = len(shape) + assert rank in [2, 3], "MMA sync only supports 2D shapes or 3D shapes with a batch outer dimension" + if rank == 2: + return ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[num_warps, 1], instr_shape=[16, 8]) + return ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[num_warps, 1, 1], instr_shape=[1, 16, 8]) + + +@gluon.constexpr_function +def tl_dot_mma_sync_k_width(a_ty, b_ty): + a_bitwidth = a_ty.element_ty.primitive_bitwidth + b_bitwidth = b_ty.element_ty.primitive_bitwidth + min_bitwidth = min(a_bitwidth, b_bitwidth) + return max(32 // min_bitwidth, 1) + + +@gluon.jit +def tl_dot_mma_sync(a, b, acc_init=None, input_precision=None, out_dtype=ttgl.float32): + mma_layout: ttgl.constexpr = tl_dot_mma_sync_layout(a.type.shape, ttgl.num_warps()) + k_width: ttgl.constexpr = tl_dot_mma_sync_k_width(a.type, b.type) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=0, k_width=k_width) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=1, k_width=k_width) + a = ttgl.convert_layout(a, a_layout) + b = ttgl.convert_layout(b, b_layout) + if acc_init is not None: + acc = ttgl.convert_layout(acc_init, mma_layout) + else: + acc = ttgl.full([a.shape[0], a.shape[1], b.shape[2]], 0.0, out_dtype, layout=mma_layout) + result = mma_v2(a, b, acc, input_precision) + if acc_init is not None: + result = ttgl.convert_layout(result, acc_init.type.layout) + return result + + +@gluon.constexpr_function +def tl_dot_mmav5_supported(a_ty, b_ty, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): + assert max_num_imprecise_acc is None, "max_num_imprecise_acc only applies to Hopper warp_group_dot" + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None and (allow_tf32 or allow_tf32 is None): + input_precision = "tf32" + + M = a_ty.shape[0] + N = b_ty.shape[1] + K = a_ty.shape[1] + min_K = 256 // a_ty.element_ty.primitive_bitwidth + if a_ty.element_ty.is_int() or b_ty.element_ty.is_int(): + return False + if min(a_ty.element_ty.primitive_bitwidth, b_ty.element_ty.primitive_bitwidth) >= 32 and input_precision != "tf32": + return False + return num_warps in [4, 8] and len(a_ty.shape) == 2 and len(b_ty.shape) == 2 and K >= min_K and M >= 64 and N >= 16 + + +@gluon.constexpr_function +def get_shared_memory_mma_layout(type, operand_index, allow_transpose, is_fp4_padded=False, force_transpose=False): + if not allow_transpose: + if operand_index == 1: + transposed = True + else: + transposed = False + if force_transpose: + transposed = not transposed + else: + transposed = operand_index == 1 + + shape = type.shape + swizzle_byte_width = 0 + ele_bit_width = type.element_ty.primitive_bitwidth + packing_factor = 2 if is_fp4_padded else 1 + + contig_dim_size_in_byte = (shape[0] if transposed else shape[1]) * packing_factor * ele_bit_width // 8 + if contig_dim_size_in_byte >= 128 and contig_dim_size_in_byte % 128 == 0: + swizzle_byte_width = 128 + elif contig_dim_size_in_byte >= 64 and contig_dim_size_in_byte % 64 == 0: + swizzle_byte_width = 64 + elif contig_dim_size_in_byte >= 32 and contig_dim_size_in_byte % 32 == 0: + swizzle_byte_width = 32 + else: + swizzle_byte_width = 0 + + flatten_outer_dim = 1 + for dim in shape: + flatten_outer_dim *= dim + if len(shape) < 2 or flatten_outer_dim < 8: + swizzle_byte_width = 0 + return ttgl.NVMMASharedLayout(swizzle_byte_width=swizzle_byte_width, transposed=transposed, + element_bitwidth=ele_bit_width, rank=len(shape), fp4_padded=is_fp4_padded) + + +@gluon.jit +def get_shared_memory_mma_operand(value, operand_index, allow_transpose, is_fp4_padded=False, force_transpose=False): + layout: ttgl.constexpr = get_shared_memory_mma_layout(value.type, operand_index, allow_transpose, is_fp4_padded, + force_transpose) + return ttgl.allocate_shared_memory(value.dtype, value.shape, layout, value) + + +@gluon.jit +def tl_dot_blackwell(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, + out_dtype=ttgl.float32): + M: ttgl.constexpr = a.type.shape[0] + N: ttgl.constexpr = b.type.shape[1] + + allow_transpose = not a.type.element_ty.is_fp32() + a_smem = get_shared_memory_mma_operand(a, 0, allow_transpose) + b_smem = get_shared_memory_mma_operand(b, 1, allow_transpose) + + # MMA instruction shape + m: ttgl.constexpr = 128 if M >= 128 else 64 + n: ttgl.constexpr = 256 if N >= 256 else N + + acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype + col_stride: ttgl.constexpr = 32 // acc_dtype.primitive_bitwidth + acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([m, n], col_stride=col_stride) + + tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout(acc_dtype, (M, N), acc_tmem_layout, ttgl.num_warps()) + if acc is not None: + acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) + else: + acc_temp = ttgl.zeros([M, N], out_dtype, layout=tmem_reg_layout) + acc_tmem = allocate_tensor_memory(acc_temp.dtype, [M, N], acc_tmem_layout, acc_temp) + fence_async_shared() + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + tcgen05_mma(a_smem, b_smem, acc_tmem, use_acc=True) + tcgen05_commit(bar) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + + # Load back from TMEM using a register layout and convert to acc layout + out = acc_tmem.load(tmem_reg_layout) + ret_layout: ttgl.constexpr = default_blocked_layout([M, N], ttgl.num_warps()) + out = ttgl.convert_layout(out, ret_layout) + return out + + +@gluon.jit +def tl_dot(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=ttgl.float32): + num_warps: ttgl.constexpr = ttgl.num_warps() + if tl_dot_mmav5_supported(a.type, b.type, num_warps, input_precision, allow_tf32, max_num_imprecise_acc): + return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype) + else: + return tl_dot_mma_sync(a, b, acc, input_precision, out_dtype) + + +@gluon.constexpr_function +def tl_dot_scaled_mmav5_supported(a_ty, b_ty, num_warps): + M = a_ty.shape[0] + N = b_ty.shape[1] + K = a_ty.shape[1] + min_K = 256 // a_ty.element_ty.primitive_bitwidth + return num_warps in [4, 8] and len(a_ty.shape) == 2 and len(b_ty.shape) == 2 and K >= min_K and M >= 128 and N >= 16 + + +@gluon.constexpr_function +def get_swizzle_byte_width(bitwidth): + swizzle = min(bitwidth, 128) + swizzle = 0 if swizzle < 32 else swizzle + return swizzle + + +@gluon.constexpr_function +def get_int_type(bitwidth): + if bitwidth == 64: + return ttgl.int64 + elif bitwidth == 32: + return ttgl.int32 + elif bitwidth == 16: + return ttgl.int16 + elif bitwidth == 8: + return ttgl.int8 + else: + assert False, f"Unsupported bitwidth: {bitwidth}" + + +@gluon.jit +def tl_dot_decomposed_scale_to_16(scale, compute_type): + large_fp_type: ttgl.constexpr = ttgl.float32 if compute_type == ttgl.float16 else compute_type + int_width: ttgl.constexpr = large_fp_type.primitive_bitwidth + int_type: ttgl.constexpr = get_int_type(int_width) + + zexted = ttgl.cast(scale, int_type) + shift_value: ttgl.constexpr = large_fp_type.fp_mantissa_width + shl_res = zexted << shift_value + scale_fp = ttgl.cast(shl_res, large_fp_type, bitcast=True) + if large_fp_type != compute_type: + scale_fp = ttgl.cast(scale_fp, compute_type) + return scale_fp + + +@gluon.constexpr_function +def tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank): + shape = scale_ty.shape.values + [1] + blocked = default_blocked_layout(shape, num_warps) + slice = ttgl.SliceLayout(rank, blocked) + return slice + + +@gluon.constexpr_function +def tl_dot_get_permute_order(rank, dim): + order = list(range(rank)) + order.insert(dim + 1, rank) + return order + + +@gluon.constexpr_function +def tl_dot_get_reshape_shape(scale_ty, dim): + shape = list(scale_ty.shape.values) + shape.pop() + shape[dim] *= 32 + return shape + + +@gluon.jit +def tl_dot_decomposed_broadcast_scale(scale, dim): + scale_ty: ttgl.constexpr = scale.type + rank: ttgl.constexpr = len(scale_ty.shape) + + num_warps: ttgl.constexpr = ttgl.num_warps() + slice_enc: ttgl.constexpr = tl_dot_get_expand_dims_layout(scale_ty, num_warps, rank) + scale = ttgl.convert_layout(scale, slice_enc) + expand_scale = scale.expand_dims(rank) + broadcast_scale = expand_scale.broadcast_to(scale.type.shape + (32, )) + permute_order: ttgl.constexpr = tl_dot_get_permute_order(rank, dim) + transposed_scale = broadcast_scale.permute(permute_order.value) + reshape_shape: ttgl.constexpr = tl_dot_get_reshape_shape(broadcast_scale.type, dim) + return transposed_scale.reshape(reshape_shape) + + +@gluon.constexpr_function +def tl_dot_decomposed_get_transposed_order(rank): + assert rank >= 2 + order = list(range(rank - 2)) + order += [rank - 1, rank - 2] + return order + + +@gluon.jit +def tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index): + rank: ttgl.constexpr = len(v.type.shape) + k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 + + if operand_index == 1: + order: ttgl.constexpr = tl_dot_decomposed_get_transposed_order(rank) + scale = ttgl.permute(scale, order.value) + + scale16 = tl_dot_decomposed_scale_to_16(scale, compute_type) + reshape_scale = tl_dot_decomposed_broadcast_scale(scale16, k_dim) + return ttgl.convert_layout(reshape_scale, v.type.layout), scale + + +@gluon.jit +def tl_dot_decomposed_mask_nan(mxfp, scale, fast_math): + ttgl.static_assert(fast_math, "TODO: support non-fast-math") + return mxfp + + +@gluon.jit +def tl_dot_decomposed_scale_arg(v, scale, arg_format, operand_index, compute_type, fast_math): + is_fp4: ttgl.constexpr = arg_format == "e2m1" + rank: ttgl.constexpr = len(v.type.shape) + k_dim: ttgl.constexpr = rank - 1 if operand_index == 0 else rank - 2 + + if is_fp4: + v = ttgl.fp4_to_fp(v, compute_type, k_dim) + else: + v = ttgl.cast(v, compute_type) + if scale is None: + return v + else: + reshape_scale, scale = tl_dot_decomposed_extend_and_broadcast_scale(v, scale, compute_type, operand_index) + mxfp = ttgl.mul(v, reshape_scale) + return tl_dot_decomposed_mask_nan(mxfp, scale, fast_math) + + +@gluon.jit +def tl_dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, + rhs_k_pack=True, out_dtype=ttgl.float32): + if tl_dot_scaled_mmav5_supported(lhs.type, rhs.type, + ttgl.num_warps() and lhs_scale is not None and rhs_scale is not None): + return tl_dot_scaled_blackwell(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, + lhs_k_pack, rhs_k_pack, out_dtype) + else: + return tl_dot_decomposed_block_scales(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, + lhs_k_pack, rhs_k_pack, out_dtype) + + +@gluon.jit +def tl_dot_decomposed_block_scales(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, + lhs_k_pack=True, rhs_k_pack=True, out_dtype=ttgl.float32): + if lhs_scale is None and rhs_scale is not None: + lhs_trans = tl_trans(lhs) + rhs_trans = tl_trans(rhs) + if acc is not None: + orig_layout: ttgl.constexpr = acc.type.layout + acc = tl_trans(acc) + result = tl_dot_scaled(rhs_trans, rhs_scale, rhs_format, lhs_trans, lhs_scale, lhs_format, acc, fast_math, + lhs_k_pack, rhs_k_pack, out_dtype) + result = tl_trans(result) + if acc is not None: + result = ttgl.convert_layout(result, orig_layout) + return result + else: + ttgl.static_assert(not (not lhs_k_pack or not rhs_k_pack), "TODO: support m/n packed formats") + compute_type: ttgl.constexpr = ttgl.float16 if (lhs_format == "fp16" or rhs_format == "fp16") else ttgl.bfloat16 + + scale_a = tl_dot_decomposed_scale_arg(lhs, lhs_scale, lhs_format, 0, compute_type, fast_math) + scale_b = tl_dot_decomposed_scale_arg(rhs, rhs_scale, rhs_format, 1, compute_type, fast_math) + + return tl_dot(scale_a, scale_b, acc, out_dtype=out_dtype) + + +@gluon.jit +def tl_dot_scaled_blackwell(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, + lhs_k_pack=True, rhs_k_pack=True, out_dtype=ttgl.float32): + is_a_fp4: ttgl.constexpr = lhs_format == "e2m1" + is_b_fp4: ttgl.constexpr = rhs_format == "e2m1" + + mixed_prec: ttgl.constexpr = lhs_format != rhs_format + is_a_mixed_prec_fp4: ttgl.constexpr = mixed_prec and is_a_fp4 + is_b_mixed_prec_fp4: ttgl.constexpr = mixed_prec and not is_a_fp4 and is_b_fp4 + + is_mmav5_fp4_padded_a: ttgl.constexpr = is_a_mixed_prec_fp4 or not lhs_k_pack + is_mmav5_fp4_padded_b: ttgl.constexpr = is_b_mixed_prec_fp4 or not rhs_k_pack + + a_smem = get_shared_memory_mma_operand(lhs, 0, allow_transpose=not is_a_fp4, is_fp4_padded=is_mmav5_fp4_padded_a, + force_transpose=not lhs_k_pack) + b_smem = get_shared_memory_mma_operand(rhs, 1, allow_transpose=not is_b_fp4, is_fp4_padded=is_mmav5_fp4_padded_b, + force_transpose=not rhs_k_pack) + + M: ttgl.constexpr = lhs.type.shape[0] + N: ttgl.constexpr = rhs.type.shape[1] + + m: ttgl.constexpr = 128 + n: ttgl.constexpr = 256 if N >= 256 else N + + acc_dtype: ttgl.constexpr = acc.dtype if acc is not None else out_dtype + col_stride: ttgl.constexpr = 32 // acc_dtype.primitive_bitwidth + acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([m, n], col_stride=col_stride) + tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout(acc_dtype, (M, N), acc_tmem_layout, ttgl.num_warps()) + if acc is not None: + acc_temp = ttgl.convert_layout(acc, tmem_reg_layout) + else: + acc_temp = ttgl.zeros([M, N], out_dtype, layout=tmem_reg_layout) + acc_tmem = allocate_tensor_memory(acc_temp.dtype, [M, N], acc_tmem_layout, acc_temp) + fence_async_shared() + + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + scale_layout: ttgl.constexpr = TensorMemoryScalesLayout() + scale_layout_reg_lhs: ttgl.constexpr = get_tmem_reg_layout(lhs_scale.dtype, lhs_scale.type.shape, scale_layout, + ttgl.num_warps()) + scale_layout_reg_rhs: ttgl.constexpr = get_tmem_reg_layout(rhs_scale.dtype, rhs_scale.type.shape, scale_layout, + ttgl.num_warps()) + lhs_scale = ttgl.convert_layout(lhs_scale, scale_layout_reg_lhs) + rhs_scale = ttgl.convert_layout(rhs_scale, scale_layout_reg_rhs) + a_scale_tmem = allocate_tensor_memory(lhs_scale.dtype, lhs_scale.shape, scale_layout, lhs_scale) + b_scale_tmem = allocate_tensor_memory(rhs_scale.dtype, rhs_scale.shape, scale_layout, rhs_scale) + + tcgen05_mma_scaled(a_smem, b_smem, acc_tmem, a_scale_tmem, b_scale_tmem, lhs_format, rhs_format, use_acc=True) + tcgen05_commit(bar) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + # Load back from TMEM using a register layout and convert to acc layout + out = acc_tmem.load(tmem_reg_layout) + ret_layout: ttgl.constexpr = default_blocked_layout([M, N], ttgl.num_warps()) + out = ttgl.convert_layout(out, ret_layout) + return out + + +@gluon.constexpr_function +def get_num_threads_per_warp() -> ttgl.constexpr: + return ttgl.constexpr(32) + + +@ttgl._core.builtin +def get_num_threads_per_program(_semantic=None, _generator=None): + return ttgl.num_warps(_semantic=_semantic, _generator=_generator) * get_num_threads_per_warp(_semantic=_semantic) + + +@gluon.constexpr_function +def default_blocked_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ttgl.constexpr: + rank = len(shape) + # 1 element per thread for all dimensions + size_per_thread = [1 for _ in range(rank)] + # Distribute 32 threads per warp across dimensions (simple heuristic: last-fastest) + threads_per_warp = [1 for _ in range(rank)] + # TODO: pick a better layout based on shape. Using this allows to not have to convert layout when broadcasting but may blow up register pressure. + threads_per_warp[rank - 1] = get_num_threads_per_warp() + # remaining_threads = get_num_threads_per_warp() + # for dim in range(rank - 1, -1, -1): + # threads_per_warp[dim] = min(remaining_threads, shape[dim]) + # remaining_threads = remaining_threads // threads_per_warp[dim] + # Use provided num_warps to distribute warps per CTA (put all on first dim) + warps_per_cta = [1 for _ in range(rank)] + warps_per_cta[0] = num_warps + # Natural order [rank-1, rank-2, ..., 0] + order = [i for i in range(rank - 1, -1, -1)] + return ttgl.BlockedLayout(size_per_thread=size_per_thread, threads_per_warp=threads_per_warp, + warps_per_cta=warps_per_cta, order=order) + + +@gluon.jit +def tl_obj_store(obj, offsets, value): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + return tl_store_tensor_descriptor(obj, offsets, value) + else: + return obj.store(offsets, value) + + +@gluon.jit +def tl_obj_load(obj, offsets): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + return tl_load_tensor_descriptor(obj, offsets) + else: + return obj.load(offsets) + + +@gluon.jit +def tl_obj_gather(obj, x_offsets, y_offset): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0])) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + mbarrier.expect(bar, x_offsets.shape[0] * obj.block_type.nbytes) + tma_blackwell.async_gather(desc, x_offsets, y_offset, bar, alloc) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + # Load from shared memory into a register tensor using a reasonable default layout + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + out = alloc.load(ret_layout) + return out + else: + return obj.gather(x_offsets, y_offset) + + +@gluon.jit +def tl_obj_scatter(obj, value, x_offsets, y_offset): + if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor): + desc = obj + desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]] + alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value) + fence_async_shared() + x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout( + 0, ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0])) + x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout) + tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc) + tma.store_wait(0) + else: + obj.scatter(value, x_offsets, y_offset) + + +@ttgl._core.builtin +def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option="zero", _semantic=None): + layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty) + return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option, _semantic=_semantic) + + +@gluon.jit +def tl_store_tensor_descriptor(desc, offsets, value): + alloc = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) + fence_async_shared() + tma.async_copy_shared_to_global(desc, offsets, alloc) + tma.store_wait(0) + alloc._keep_alive() + + +@gluon.jit +def tl_load_tensor_descriptor(desc, offsets): + smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + # Issue async copy from global (descriptor) to shared memory and wait for completion + mbarrier.expect(bar, desc.block_type.nbytes) + tma.async_copy_global_to_shared(desc, offsets, bar, smem) + mbarrier.wait(bar, phase=0) + mbarrier.invalidate(bar) + # Load from shared memory into a register tensor using a reasonable default layout + ret_layout: ttgl.constexpr = default_blocked_layout(desc.block_shape, ttgl.num_warps()) + out = smem.load(ret_layout) + return out + + +@gluon.jit +def tl_arange(start: ttgl.constexpr, stop: ttgl.constexpr = None): + layout: ttgl.constexpr = default_blocked_layout([stop - start], ttgl.num_warps()) + return ttgl.arange(start, stop, layout=layout) + + +@gluon.jit +def tl_full(shape, value, dtype=None): + layout: ttgl.constexpr = default_blocked_layout(shape, ttgl.num_warps()) + return ttgl.full(shape, value, dtype, layout=layout) + + +@ttgl._core.builtin +def tl_trans(value, *dims, _semantic=None): + return value.trans(*dims, _semantic=_semantic) + + +@ttgl._core.builtin +def cat(input, other, can_reorder=False, layout=None, _semantic=None): + """ + Concatenate the two tensors. + + Args: + input (tensor): The first input tensor. + other (tensor): The second input tensor. + can_reorder (bool): Compiler hint. If true, the compiler is allowed to reorder elements while concatenating inputs. Only use if the order does not matter (e.g., result is only used in reduction ops). Current implementation of `cat` supports only can_reorder=True. + layout (DistributedLayout): The destination layout of the output tensor. + + Returns: + tensor: The concatenated tensor. + """ + can_reorder = ttgl._core._unwrap_if_constexpr(can_reorder) + layout = ttgl._core._unwrap_if_constexpr(layout) + return _semantic.cat(input, other, can_reorder, layout) + + +@gluon.jit +def tl_cat(lhs, rhs, can_reorder=False): + return cat(lhs, rhs, can_reorder, layout=default_blocked_layout([lhs.shape[0] + rhs.shape[0]], ttgl.num_warps())) + + +@gluon.jit +def reset_to_default_layout(value): + ty: ttgl.constexpr = value.type + if isinstance(ty, ttgl.tuple_type): + out = () + for i in ttgl.static_range(len(value)): + r = ttgl.convert_layout(value[i], layout=default_blocked_layout(value[i].type.shape, ttgl.num_warps())) + out = out + (r, ) + return out + elif isinstance(value, ttgl.tensor) and isinstance(value.type, ttgl.distributed_type): + layout: ttgl.constexpr = default_blocked_layout(ty.shape, ttgl.num_warps()) + return ttgl.convert_layout(value, layout=layout) + else: + return value + + +@gluon.constexpr_function +def get_split_src_layout(shape: ttgl.constexpr, num_warps: ttgl.constexpr) -> ttgl.constexpr: + rank = len(shape) + size_per_thread = [1 if i != rank - 1 else 2 for i in range(rank)] + # Distribute 32 threads per warp across dimensions (simple heuristic: last-fastest) + threads_per_warp = [1 for _ in range(rank)] + remaining_threads = get_num_threads_per_warp() + for dim in range(rank - 2, -1, -1): + threads_per_warp[dim] = min(shape[dim], remaining_threads) + remaining_threads = remaining_threads // threads_per_warp[dim] + # Use provided num_warps to distribute warps per CTA (put all on first dim) + warps_per_cta = [1 for _ in range(rank)] + warps_per_cta[0] = num_warps + # Natural order [rank-1, rank-2, ..., 0] + order = [i for i in range(rank - 1, -1, -1)] + return ttgl.BlockedLayout(size_per_thread=size_per_thread, threads_per_warp=threads_per_warp, + warps_per_cta=warps_per_cta, order=order) + + +@gluon.jit +def set_split_src_layout(value): + layout: ttgl.constexpr = get_split_src_layout(value.type.shape, ttgl.num_warps()) + return ttgl.convert_layout(value, layout=layout) + + +def convert_host_descriptor(desc): + + def torch_dtype_to_triton(dtype): + import torch + if dtype == torch.float8_e5m2: + return ttgl.float8e5 + if dtype == torch.float8_e4m3fn: + return ttgl.float8e4nv + return getattr(ttgl, str(dtype).split('.')[1]) + + from triton.tools.tensor_descriptor import TensorDescriptor + assert isinstance(desc, TensorDescriptor) + block_shape = desc.block_shape + dtype = desc.base.dtype + tensor = desc.base + layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(dtype)) + return gluon.nvidia.hopper.TensorDescriptor(tensor, desc.shape, desc.strides, block_shape, layout) + + +# hacks to workaround limited dependencies tracking. +# TODO: fix this by pulling imports into the generated file. +def current_target(): + from triton.runtime import driver + try: + active_driver = driver.active + except RuntimeError: + # If there is no active driver, return None + return None + return active_driver.get_current_target() + + +current_target.__triton_builtin__ = True diff --git a/third_party/mthreads/triton_mthreads.cc b/third_party/mthreads/triton_mthreads.cc new file mode 100644 index 0000000000..671d340473 --- /dev/null +++ b/third_party/mthreads/triton_mthreads.cc @@ -0,0 +1,189 @@ +#include "Dialect/MTGPU/IR/Dialect.h" +#include "Dialect/MUSA/IR/Dialect.h" +#include "MTGPUToLLVM/Passes.h" +#include "TritonMUSAGPUToLLVM/Passes.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "passes.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include +#include +#include + +namespace py = pybind11; + +namespace { + +llvm::Function *findPrimaryKernel(llvm::Module &module, + llvm::StringRef kernelNameHint) { + if (!kernelNameHint.empty()) { + if (llvm::Function *fn = module.getFunction(kernelNameHint)) { + if (!fn->isDeclaration()) + return fn; + } + } + for (llvm::Function &fn : module) { + if (!fn.isDeclaration() && + fn.getLinkage() == llvm::GlobalValue::ExternalLinkage) + return &fn; + } + for (llvm::Function &fn : module) { + if (!fn.isDeclaration()) + return &fn; + } + return nullptr; +} + +bool hasMusaAnnotation(llvm::NamedMDNode *annotations, const llvm::Function &fn, + llvm::StringRef key) { + if (!annotations) + return false; + for (llvm::MDNode *node : annotations->operands()) { + if (!node || node->getNumOperands() < 3) + continue; + auto *valueMD = llvm::dyn_cast(node->getOperand(0)); + auto *keyMD = llvm::dyn_cast(node->getOperand(1)); + if (!valueMD || !keyMD) + continue; + auto *annotatedFn = llvm::dyn_cast(valueMD->getValue()); + if (annotatedFn != &fn) + continue; + if (keyMD->getString() == key) + return true; + } + return false; +} + +void addMusaAnnotation(llvm::Module &module, llvm::Function &fn, + llvm::StringRef key, int32_t value) { + llvm::NamedMDNode *annotations = + module.getOrInsertNamedMetadata("musa.annotations"); + if (hasMusaAnnotation(annotations, fn, key)) + return; + + llvm::LLVMContext &ctx = module.getContext(); + llvm::MDNode *node = llvm::MDNode::get( + ctx, {llvm::ValueAsMetadata::get(&fn), llvm::MDString::get(ctx, key), + llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), value))}); + annotations->addOperand(node); +} + +bool moduleUsesMulhiHelper(const llvm::Module &module) { + for (const llvm::Function &fn : module) { + if (fn.isDeclaration()) + continue; + for (const llvm::BasicBlock &block : fn) { + for (const llvm::Instruction &inst : block) { + auto *call = llvm::dyn_cast(&inst); + if (!call) + continue; + const llvm::Function *callee = call->getCalledFunction(); + if (!callee) + continue; + llvm::StringRef calleeName = callee->getName(); + if (calleeName == "__mt_umulhi" || calleeName == "__mt_umul64hi") + return true; + } + } + } + return false; +} + +} // namespace + +void init_triton_musa_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton; + m.def("add_mtgpu_to_llvm", [](mlir::PassManager &pm, int32_t capability) { + pm.addPass(mlir::triton::createConvertMTGPUToLLVMPass(capability)); + }); + m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability) { + pm.addPass(mlir::triton::createConvertTritonMUSAGPUToLLVMPass(capability)); + }); + m.def("add_allocate_shared_memory", [](mlir::PassManager &pm, + int32_t capability) { + pm.addPass(mlir::triton::createAllocateMUSASharedMemoryPass(capability)); + }); + ADD_PASS_OPTION_WRAPPER_2("add_pipeline", mlir::createTritonMUSAGPUPipeline, + int, bool); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", + mlir::createTritonMUSAGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0( + "add_canonicalize_sqmma_result_conversions", + mlir::createTritonMUSAGPUCanonicalizeSqmmaResultConversions); + ADD_PASS_WRAPPER_0("add_convert_sqmma_to_mtgpu", + mlir::createTritonMUSAGPUConvertSqmmaToMTGPU); + ADD_PASS_WRAPPER_0("add_finalize_barriers", + mlir::createTritonMUSAGPUFinalizeBarriers); + ADD_PASS_WRAPPER_0("add_issue_barrier_insertion", + mlir::createTritonMUSAGPUIssueBarrierInsertion); + ADD_PASS_WRAPPER_0("add_mark_inplace_loads", + mlir::createTritonMUSAGPUMarkInplaceLoads); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + mlir::createTritonMUSAGPUOptimizeAccumulatorInit); + ADD_PASS_WRAPPER_0("add_optimize_dot_operands", + mlir::createTritonMUSAGPUOptimizeDotOperands); + ADD_PASS_WRAPPER_0("add_tme_lowering", mlir::createTritonMUSAGPUTMELowering); + ADD_PASS_WRAPPER_0("add_optimize_descriptor_encoding", + mlir::createTritonMUSAGPUOptimizeDescriptorEncoding); + ADD_PASS_WRAPPER_0("add_optimize_sqmma_accumulator_layout", + mlir::createTritonMUSAGPUOptimizeSqmmaAccumulatorLayout); +} + +void init_triton_mthreads(py::module &&m) { + auto passes = m.def_submodule("passes"); + init_triton_musa_passes_ttgpuir(passes.def_submodule("ttgpuir")); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry + .insert(); + mlir::registerLLVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + m.def("attach_datalayout", [](llvm::Module &module) { + const std::string dataLayout = "e-p:64:64:64:64-" + "p1:64:64:64:64-" + "p2:64:64:64:64-" + "p3:32:32-" + "p4:32:32-" + "p5:64:64-" + "i64:64-" + "v16:16-" + "v24:32-" + "v32:32-" + "v48:64-" + "v96:128"; + module.setDataLayout(dataLayout); + }); + + m.def("decorate_kernel_abi", + [](llvm::Module &module, const std::string &kernelNameHint, + int32_t maxntidx) -> std::string { + llvm::Function *kernel = findPrimaryKernel(module, kernelNameHint); + if (!kernel) + return ""; + + kernel->setCallingConv(llvm::CallingConv::MTGPU_KERNEL); + addMusaAnnotation(module, *kernel, "kernel", 1); + addMusaAnnotation(module, *kernel, "maxntidx", + std::max(1, maxntidx)); + return kernel->getName().str(); + }); + + m.def("module_uses_mulhi_helper", + [](llvm::Module &module) { return moduleUsesMulhiHelper(module); }); +}