diff --git a/.codex/skills/generate-vpto-release-doc/SKILL.md b/.codex/skills/generate-vpto-release-doc/SKILL.md new file mode 100644 index 000000000..fde4ceb97 --- /dev/null +++ b/.codex/skills/generate-vpto-release-doc/SKILL.md @@ -0,0 +1,86 @@ +--- +name: generate-vpto-release-doc +description: Generate or refresh `docs/release/vpto-spec-v*.md` by merging `docs/vpto-spec.md` with `docs/isa/*.md`, following the release-doc naming and layout. Use when the user asks to create or update a merged VPTO release spec, inline ISA Markdown into one release document, add TOC and version bullets, move `Quick Reference by Category` to the end, or strip update, appendix, and correspondence content from the merged release doc. +--- + +# Generate VPTO Release Doc + +Use this skill when the task is specifically about: +- creating a new merged release document under `docs/release/` +- refreshing an existing `vpto-spec-v*.md` release doc from `docs/vpto-spec.md` and `docs/isa/*.md` +- keeping the merged release doc aligned with the naming and structure used in `docs/release/vpto-spec-v0.1.md` + +## Canonical Workflow + +1. Pick the target version and output path. + +Default output path: + +```bash +docs/release/vpto-spec-v<version>.md +``` + +2. Run the bundled generator script. + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py --version 0.2 +``` + +If you need an explicit note for the new version bullet: + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py \ + --version 0.2 \ + --version-note 'Merge `docs/vpto-spec.md` with `docs/isa/*.md`; add TOC; move `Quick Reference by Category` to the end; remove update, appendix, and correspondence content' +``` + +3. Review the generated file before finalizing. 
+ +Check these invariants: +- exactly one `#` level title in the whole file +- `[toc]` is present near the top +- the top version bullet for the requested version was added +- `## Quick Reference by Category` is the final top-level section +- no `Updated:` / review-status boilerplate remains at the beginning +- no appendix sections remain +- no `## Correspondence Categories` section remains +- no `CCE correspondence` / builtin-mapping blocks remain + +4. If the user wants extra release-note wording, patch only the version bullets or other small wording around the generated content. Prefer rerunning the script over hand-merging large sections. + +## Source Mapping + +Use `docs/vpto-spec.md` for: +- `Part I: Architecture Overview` +- `Part II: Notation Convention` +- `C-Style Semantics Convention` +- `Template Placeholder Conventions` +- `Instruction Groups` +- `Supported Data Types` +- `Common Patterns` +- `Quick Reference by Category` + +Use `docs/isa/*.md` for: +- the inlined `Detailed ISA Group Reference` + +## Merge Rules + +The merged release document should: +- keep the release-doc title and version-bullet style +- preserve the `Instruction Groups` summary table +- inline `docs/isa/*.md` under `Detailed ISA Group Reference` +- convert `docs/isa/*.md` links into in-document anchors like `#isa-03-vector-load-store` +- demote the inlined ISA headings by two levels so the merged TOC stays stable +- place `Quick Reference by Category` at the end + +The merged release document must remove: +- beginning-of-file update/review metadata from `docs/vpto-spec.md` +- `## Correspondence Categories` +- all `CCE correspondence` blocks and related builtin/token mapping lines +- the sentence `For detailed semantics, C-style pseudocode, and CCE mappings, see the individual group documentation files.` +- appendix sections + +## Notes + +- The script assumes the source headings in `docs/vpto-spec.md` keep their current names. 
If extraction fails, inspect the heading names there before patching the script. +- The script is deterministic and is the preferred path for regenerating large merged release docs. diff --git a/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py new file mode 100644 index 000000000..f03e7d592 --- /dev/null +++ b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""Generate merged VPTO release spec from docs/vpto-spec.md and docs/isa/*.md.""" + +from __future__ import annotations + +import argparse +import re +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[4] +DOCS_DIR = ROOT / "docs" +SOURCE_SPEC = DOCS_DIR / "vpto-spec.md" +ISA_DIR = DOCS_DIR / "isa" +RELEASE_DIR = DOCS_DIR / "release" + +TITLE = "# PTO micro Instruction Spec \u2014 Draft (A5)" +DEFAULT_VERSION_NOTES = { + "0.1": "Doc Init", + "0.2": "Update micro Instruction latency and throughput", + "0.3": "Refresh VPTO ISA specification", +} + +KEEP_SECTIONS = [ + "## Part I: Architecture Overview", + "## Part II: Notation Convention", + "## Instruction Groups", + "## Supported Data Types", + "## Common Patterns", + "## Quick Reference by Category", +] + +ISA_LINK_RE = re.compile(r"\[([^\]]+)\]\((?:\.\./)?(?:isa/)?([0-9]{2}-[A-Za-z0-9-]+)\.md\)") + + +def extract_sections(markdown: str) -> dict[str, str]: + headings = list(re.finditer(r"^## .*$", markdown, flags=re.MULTILINE)) + sections: dict[str, str] = {} + for index, match in enumerate(headings): + heading = match.group(0).strip() + start = match.start() + end = headings[index + 1].start() if index + 1 < len(headings) else len(markdown) + sections[heading] = markdown[start:end].strip() + "\n" + return sections + + +def rewrite_isa_links(text: str) -> str: + return ISA_LINK_RE.sub(lambda m: f"[{m.group(1)}](#isa-{m.group(2).lower()})", text) + + +def 
trim_trailing_rule(text: str) -> str: + return re.sub(r"\n---\s*\Z", "\n", text.strip() + "\n").rstrip() + + +def strip_unwanted_lines(text: str) -> str: + lines = text.splitlines() + kept: list[str] = [] + skip_correspondence = False + for line in lines: + if re.match(r"^## Correspondence Categories\b", line): + skip_correspondence = True + continue + if skip_correspondence: + if re.match(r"^## ", line): + skip_correspondence = False + else: + continue + if line.startswith("> **Status:**") or line.startswith("> **Base:**") or line.startswith("> **Additions from:**") or line.startswith("> **Updated:**"): + continue + if "For detailed semantics, C-style pseudocode, and CCE mappings" in line: + continue + if "CCE correspondence" in line or "builtin mapping" in line.lower(): + continue + kept.append(line) + text = "\n".join(kept).strip() + "\n" + text = re.sub(r"\n## Appendix [A-Z]:.*\Z", "\n", text, flags=re.DOTALL) + return text + + +def demote_headings(text: str, levels: int = 2) -> str: + def replace(match: re.Match[str]) -> str: + hashes = match.group(1) + heading = match.group(2) + new_level = min(6, len(hashes) + levels) + return f"{'#' * new_level} {heading}" + + return re.sub(r"^(#{1,6})\s+(.*)$", replace, text, flags=re.MULTILINE) + + +def render_version_bullets(version: str, version_note: str | None) -> str: + notes = dict(DEFAULT_VERSION_NOTES) + if version_note: + notes[version] = version_note + elif version not in notes: + notes[version] = "Release refresh" + + def key_fn(item: str) -> tuple[int, ...]: + return tuple(int(part) for part in item.split(".")) + + lines = [f"- v{ver}: {notes[ver]}" for ver in sorted(notes, key=key_fn, reverse=True)] + return "\n".join(lines) + + +def build_release_doc(version: str, version_note: str | None) -> str: + source_text = strip_unwanted_lines(SOURCE_SPEC.read_text()) + sections = extract_sections(source_text) + + missing = [name for name in KEEP_SECTIONS if name not in sections] + if missing: + raise 
SystemExit(f"missing expected headings in docs/vpto-spec.md: {missing}") + + content_sections = [trim_trailing_rule(rewrite_isa_links(sections[name])) for name in KEEP_SECTIONS[:-1]] + + isa_blocks: list[str] = ["## Detailed ISA Group Reference"] + for isa_path in sorted(ISA_DIR.glob("*.md")): + isa_text = rewrite_isa_links(isa_path.read_text().strip() + "\n") + isa_blocks.append(trim_trailing_rule(demote_headings(isa_text))) + + quick_reference = trim_trailing_rule(rewrite_isa_links(sections["## Quick Reference by Category"])) + + parts = [ + TITLE, + "", + render_version_bullets(version, version_note), + "", + "[toc]", + "", + "---", + "", + "\n\n".join(content_sections), + "\n\n".join(isa_blocks), + quick_reference, + "", + ] + return "\n".join(part for part in parts if part is not None) + + +def validate_release_doc(text: str) -> None: + if text.count("# PTO micro Instruction Spec") != 1: + raise SystemExit("expected exactly one top-level title") + if "\n[toc]\n" not in text: + raise SystemExit("missing [toc] near top") + if re.search(r"^## Quick Reference by Category\b", text, flags=re.MULTILINE) is None: + raise SystemExit("missing Quick Reference by Category") + if re.search(r"^## Quick Reference by Category\b(?:(?!^## )[\s\S])*\Z", text, flags=re.MULTILINE) is None: + raise SystemExit("Quick Reference by Category must be present at end") + if re.search(r"^## Appendix\b", text, flags=re.MULTILINE): + raise SystemExit("appendix content must not remain") + if "Updated:" in text or any("review" in line for line in text.splitlines()[:8]): + raise SystemExit("beginning metadata must not remain") + if "## Correspondence Categories" in text or "CCE correspondence" in text: + raise SystemExit("correspondence content must not remain") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--version", required=True, help="Release version, e.g. 
0.2") + parser.add_argument("--version-note", help="Version bullet text for the requested version") + parser.add_argument("--output", help="Explicit output path") + args = parser.parse_args() + + output = Path(args.output) if args.output else RELEASE_DIR / f"vpto-spec-v{args.version}.md" + output.parent.mkdir(parents=True, exist_ok=True) + + text = build_release_doc(args.version, args.version_note) + validate_release_doc(text) + output.write_text(text) + + +if __name__ == "__main__": + main() diff --git a/.codex/skills/llvm-test-tool-fallback/SKILL.md b/.codex/skills/llvm-test-tool-fallback/SKILL.md new file mode 100644 index 000000000..e71bf178d --- /dev/null +++ b/.codex/skills/llvm-test-tool-fallback/SKILL.md @@ -0,0 +1,16 @@ +--- +name: llvm-test-tool-fallback +description: When `lit` or `FileCheck` is missing from the current shell, look for the corresponding LLVM test tools in the environment or existing LLVM workspace before treating it as a repo issue. +--- + +# LLVM Test Tool Fallback + +Use this skill when: +- `python3 -m lit` fails because `lit` is missing +- `FileCheck` is not in `PATH` +- a test command fails only because LLVM test tools are not available in the current shell + +Rule: +- do not stop at `command not found` +- first try to find `lit` / `FileCheck` from the environment's LLVM toolchain or an existing LLVM workspace +- treat missing `lit` / `FileCheck` as an environment-tool issue, not as a PTOAS regression diff --git a/.codex/skills/pto-a5-installed-impl-trace/SKILL.md b/.codex/skills/pto-a5-installed-impl-trace/SKILL.md new file mode 100644 index 000000000..b67cb0128 --- /dev/null +++ b/.codex/skills/pto-a5-installed-impl-trace/SKILL.md @@ -0,0 +1,198 @@ +--- +name: pto-a5-installed-impl-trace +description: Guide LLVM IR discovery for A5 VPTO lowering from the installed CANN/PTO implementation under ASCEND_HOME_PATH. 
Use when the user does not yet know which `llvm.hivm.*` intrinsic, builtin wrapper, or operand contract a VPTO/A5 op should lower to. +--- + +# PTO A5 Installed Implementation Trace + +Use this skill when the task is specifically about: +- checking what an A5 PTO op really does on the installed machine +- mapping PTO/A5 behavior to builtins or LLVM/HIVM intrinsics +- tracing PTO wrappers down to CCE builtin wrappers such as `__builtin_cce_*` +- deciding whether repo-local lowering is correct or only a guess +- resolving conflicts between generated repo IR and installed PTO headers +- tracing `Cmp`, `Cmps`, predicate, pack, store, or typed vector behavior + +This skill answers: +- what LLVM IR a VPTO op should lower to +- what the authoritative intrinsic name is +- what operand list or mask form the installed toolchain expects +- whether repo-local lowering or emission diverges from installed behavior + +This skill does not answer: +- how to build or link a finished LLVM-path artifact end to end +- how to package `.o`, `fatobj`, or `.so` +- how to run board validation + +## Strong Rule + +If you are about to change repo code for an A5 op, stop and inspect the +installed PTO implementation first. Treat the installed PTO library under +`ASCEND_HOME_PATH` as the semantic source of truth. + +Only make a repo-local substitution after you have confirmed one of: +- the installed PTO headers already express that replacement relationship +- the frontend/compiler intrinsic contract proves two forms are equivalent at + the intrinsic layer + +Do not guess behavior from repo-local lowering, emitter code, or from what +"seems plausible" for an intrinsic sequence. + +Do not start from repo-local lowering when the question is about real A5 +behavior. The installed PTO implementation under `ASCEND_HOME_PATH` is the +first source of truth. + +## Required Search Order + +Always follow this order: + +1. `source /usr/local/Ascend/cann/set_env.sh` +2. confirm `ASCEND_HOME_PATH` +3. 
inspect installed PTO dispatch headers: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/common/pto_instr_impl.hpp` +4. inspect the matching A5 implementation: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/npu/a5/T*.hpp` +5. inspect typed helpers: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/npu/a5/utils.hpp` +6. inspect builtin wrapper headers when the question is about the real compiler-facing builtin: + - `$ASCEND_HOME_PATH/tools/bisheng_compiler/lib/clang/*/include/__clang_cce_vector_intrinsics.h` + - `$ASCEND_HOME_PATH/tools/bisheng_compiler/lib/clang/*/include/npu_arch_*/__clang_cce_vector_intrinsics.h` +7. inspect intrinsic name availability directly from the installed compiler binary before guessing LLVM/HIVM spellings: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.'` + - narrow to the op under investigation, for example: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.(vneg|vrsqrt|vnot|vmov)'` +8. only then compare against repo-local code under `lib/PTO/Transforms/` + +## Practical Fast Path + +For VPTO LLVM emission work, prefer this concrete order instead of jumping +straight to ad hoc compiler probes: + +1. confirm the op exists in installed PTO/A5 headers +2. confirm the builtin wrapper shape in installed Clang headers +3. confirm the intrinsic name family with: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.'` +4. patch repo-local emitter/lowering as little as possible +5. generate real repo-driven LLVM IR through the existing VPTO validation path: + - `source scripts/ptoas_env.sh` + - `WORK_SPACE=/tmp/ CASE_NAME= DEVICE=SIM COMPILE_ONLY=1 test/vpto/scripts/run_host_vpto_validation.sh` +6. inspect: + - `//*.ll` + - `//validation.log` +7. only after seeing the real generated `.ll` and Bisheng failure should you + refine the call shape + +This route is preferred because it preserves the real PTOAS lowering context, +the real case structure, and the exact driver invocation used by the repo. 
+ +## Probe Strategy + +Use probes in this order: + +1. installed headers +2. `strings bisheng` +3. repo-generated VPTO LLVM IR from `run_host_vpto_validation.sh` +4. only then minimal handwritten `.ll` probes +5. handwritten `.cce` frontend probes are last resort + +Handwritten `.ll` probes are acceptable for quick ABI sanity checks such as: +- whether Bisheng recognizes a specific `llvm.hivm.*` name +- whether a guessed argument count immediately crashes or verifies + +But they are not the primary source of truth for semantic or frontend wrapper +behavior. + +## Avoid These Traps + +Do not default to handwritten `.cce` probes when repo-driven IR is available. +On this machine, bare `.cce` probes often fail before reaching the real +question because they are missing the exact frontend driver mode, target +features, wrapper setup, or host/device compilation context used by the repo. + +In particular, treat these as warning signs that you have started too low in +the stack: +- errors around `[aicore]` +- errors around `__cce_half` +- builtin alias attribute failures +- missing target feature or wrapper environment failures + +When these happen, step back to the repo-driven compile-only flow instead of +trying to repair the ad hoc frontend invocation from scratch. + +## Trace By The Real Type Split + +Do not infer the active implementation from the final storage type alone. +Follow the source element type and the installed dispatch branch. 
+ +Example: +- for `Cmp` with `f32 -> ui8`, inspect the `sizeof(src) == 4` branch, not the + `ui8` destination branch +- for scalar or packed outputs, treat pack/store ops separately from compare + predicate generation + +Typical A5 compare split: +- 32-bit source elements -> `TCmp_32B` / `TCmps_32B` +- 16-bit source elements -> 16-bit branch +- 8-bit source elements -> 8-bit branch + +## What To Extract + +When tracing an op, capture: +- the installed PTO entrypoint that handles it +- the exact typed branch that matches the user case +- the builtins used in order +- any typed helper that explains `pset/plt` or store packing selection +- the compiler builtin wrapper if it is visible in installed Clang headers + +For compare-family questions, separate: +- predicate generation +- compare builtin +- predicate pack/interleave +- predicate store + +Stop at the builtin wrapper layer if the lower compiler implementation is not +available. That is still enough to answer questions such as: +- `pset_b32 -> __builtin_cce_pset_b32` +- `plt_b32 -> __builtin_cce_plt_b32_v300` + +## When The Builtin Name Is Still Not Enough + +If the installed PTO headers tell you the wrapper builtin but that still does +not answer the LLVM/HIVM operand contract, do not guess from repo-local +lowering. Extend the trace using the generated repo testcase first, and only +after that the real compiler frontend: + +1. run an existing repo case with: + - `WORK_SPACE=/tmp/ CASE_NAME= DEVICE=SIM COMPILE_ONLY=1 test/vpto/scripts/run_host_vpto_validation.sh` +2. inspect the generated `.ll` and `validation.log` +3. if the repo-generated LLVM IR still leaves the contract ambiguous, inspect + the testcase build flags from: + - `/build/CMakeFiles/.dir/flags.make` + - `/build/CMakeFiles/.dir/build.make` +4. rerun the same `bisheng` compile with `-v` and `-save-temps` +5. 
inspect: + - `*.ccei` for the exact installed PTO wrapper call sequence + - `strings *.bc | rg 'llvm.hivm\\.'` to see which HIVM intrinsics survived +6. if needed, rerun the same frontend compile with `-S`, `-emit-llvm`, or the + equivalent `cc1` invocation from `-v` to inspect the real LLVM IR emitted by + the compiler frontend before instruction selection + +This is the required fallback when the question is really: +- what exact `llvm.hivm.*` intrinsic shape the compiler expects +- whether a hand-written LLVM IR call shape is valid +- whether a selector failure is caused by a guessed mask/value form + +Prefer this real-frontend route over inventing mask constants or argument +shapes from memory. + +## Reporting Back + +When you use this skill, report: +- the exact installed header paths inspected +- whether `strings $ASCEND_HOME_PATH/bin/bisheng` confirmed the intrinsic name +- which typed branch was the authoritative one +- the builtin sequence observed there +- the builtin wrapper name if you found one in the installed Clang headers +- whether repo-generated `.ll` matched the guessed call shape +- whether repo-local lowering matches or diverges +- the first concrete mismatch, if any diff --git a/.codex/skills/pto-gym-vpto-validation/SKILL.md b/.codex/skills/pto-gym-vpto-validation/SKILL.md new file mode 100644 index 000000000..0e1451a61 --- /dev/null +++ b/.codex/skills/pto-gym-vpto-validation/SKILL.md @@ -0,0 +1,85 @@ +--- +name: pto-gym-vpto-validation +description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks to run PTO-Gym SIM or board validation from the current source tree. Always force PTOAS onto the VPTO LLVM path instead of relying on the repo default backend. 
+--- + +# PTO-Gym VPTO Validation + +Use this skill when the task is specifically about: +- running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh` +- running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh` +- validating PTO-Gym cases from this PTOAS source tree + +## Required Rule + +When PTO-Gym is run from this repo, do not rely on the default PTOAS backend. + +Always pass PTOAS flags that force the VPTO LLVM path. +The current `ptoas` CLI spellings in this repo are `--pto-backend=vpto` and +`--vpto-emit-hivm-llvm`; do not shorten `--pto-backend` to `--backend`. + +Use: + +```bash +PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +``` + +If the caller already provides `PTOAS_FLAGS`, make sure these options are still +present. Do not silently fall back to the repo default backend. + +## Canonical Environment + +Use `.work/` under the repo for all scratch output and temp files: + +```bash +mkdir -p .work/tmp .work/runs +export TMPDIR=$PWD/.work/tmp +export TMP=$TMPDIR +export TEMP=$TMPDIR +``` + +Typical simulator environment: + +```bash +source /home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2/set_env.sh +export ASCEND_HOME_PATH=/home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2 +export PTOAS_BIN=$PWD/build/tools/ptoas/ptoas +export PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +``` + +## Canonical Commands + +Single case: + +```bash +WORK_SPACE=$PWD/.work/runs/pto-gym-single \ +ASCEND_HOME_PATH=$ASCEND_HOME_PATH \ +PTOAS_BIN=$PTOAS_BIN \ +PTOAS_FLAGS="$PTOAS_FLAGS" \ +CASE_NAME=micro-op/binary-vector/vadd \ +DEVICE=SIM \ +bash 3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh +``` + +Parallel micro-op sweep: + +```bash +WORK_SPACE=$PWD/.work/runs/pto-gym-microop \ +ASCEND_HOME_PATH=$ASCEND_HOME_PATH \ +PTOAS_BIN=$PTOAS_BIN \ +PTOAS_FLAGS="$PTOAS_FLAGS" \ +CASE_PREFIX=micro-op \ +DEVICE=SIM \ +JOBS=64 \ +bash 
3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh +``` + +## Reporting Back + +Report: +- the exact `PTOAS_FLAGS` used +- the final `PASS/FAIL` counts +- the summary file path under `.work/runs/...` + +If a run fails, identify the first failing case from `parallel-summary.tsv` and +then inspect that case directory under `WORK_SPACE`. diff --git a/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md b/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md new file mode 100644 index 000000000..b50bb36fc --- /dev/null +++ b/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md @@ -0,0 +1,37 @@ +--- +name: ptoas-bisheng-asm-from-object-cmd +description: Use when you need assembly for a PTOAS VPTO case that already compiles to a device object. First find the exact command that produced the `.o`, then derive the `.s` command by replacing `-c` with `-S`. Do not guess a fresh Bisheng command line. +metadata: + short-description: Derive `.s` from real `.o` command +--- + +# PTOAS Bisheng ASM From Object Command + +Use this skill when the task is to inspect generated assembly for a VPTO case and the case already has a known `.o` build path. + +## Rule + +- Do not invent a new `bisheng` command. +- First find the exact command that built the `.o`. +- Then derive the `.s` command from that exact command by changing `-c` to `-S`. +- Keep the rest of the arguments unchanged unless the original command already wrote to a conflicting output path. + +## Preferred Sources + +- Validation script logs +- Build scripts such as `test/vpto/scripts/run_host_vpto_validation.sh` +- Saved shell history or generated compile traces in the case workspace + +## Procedure + +1. Locate the real `.o` compile command for the target case. +2. Copy that command exactly. +3. Replace `-c` with `-S`. +4. Point `-o` to a `.s` path. +5. Run the derived command. +6. Inspect the generated assembly instead of guessing from LLVM IR. 
+ +## Anti-Pattern + +- Do not hand-write a new `bisheng -S ...` command from memory. +- Do not drop flags such as `--target`, `-march`, `--cce-aicore-arch`, `--cce-aicore-only`, `-O2`, include paths, or wrapper options that were present in the real `.o` command. diff --git a/.codex/skills/ptoas-build-and-abs/SKILL.md b/.codex/skills/ptoas-build-and-abs/SKILL.md new file mode 100644 index 000000000..bbfd993e3 --- /dev/null +++ b/.codex/skills/ptoas-build-and-abs/SKILL.md @@ -0,0 +1,101 @@ +--- +name: ptoas-build-and-abs +description: Rebuild PTOAS in the repo build directory and compile the Abs sample to inspect generated VPTO output. Use when the user asks to build ptoas, rebuild the current build tree, or run/check the Abs sample output. +--- + +# PTOAS Build And Abs + +Use this skill when the task is specifically about: +- rebuilding `ptoas` in this repo +- doing a full repo build in the repo-local `build/` directory +- compiling `test/samples/Abs/abs.py` +- inspecting the generated VPTO text for `Abs` + +## Canonical Commands + +### 1. Configure the repo-local build directory + +`do_cmake.sh` is the canonical entrypoint. It always targets `./build`. + +```bash +./do_cmake.sh --llvm /data/mouliangyu/projects/github.com/llvm/llvm-project/install +``` + +If `do_cmake.sh` fails because `build/` has a generator mismatch between old Makefiles/Ninja metadata, do not guess. State that `build/` is inconsistent and ask before cleaning the generated build metadata in `build/`. + +### 2. Build + +For just the CLI: + +```bash +CCACHE_DISABLE=1 ninja -C build ptoas +``` + +For a full repo build: + +```bash +CCACHE_DISABLE=1 ninja -C build +``` + +If the user asked for "full build", prefer the full command above. If they only want to run `Abs`, building `ptoas` is usually enough. + +### 3. 
Prepare runtime environment + +Before running `runop.sh`, always: + +```bash +source env.sh +``` + +This sets `PYTHONPATH`, `LD_LIBRARY_PATH`, and the MLIR/PTO python roots needed by the samples. + +### 4. Compile `Abs` to VPTO text + +Use `runop.sh` with explicit `PTOAS_BIN`, explicit output directory, and A5 backend flags: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-abs-vpto \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-print-ir' \ +./test/samples/runop.sh -t Abs +``` + +Expected outputs: +- `/tmp/ptoas-abs-vpto/Abs/abs-pto-ir.pto` +- `/tmp/ptoas-abs-vpto/Abs/abs-pto.cpp` + +Despite the `.cpp` suffix, on the VPTO backend this file contains the emitted VPTO textual IR. + +## Inspection + +The main file to show the user is: + +```bash +sed -n '1,260p' /tmp/ptoas-abs-vpto/Abs/abs-pto.cpp +``` + +For quick sanity checks, look for: +- `vpto.copy_gm_to_ubuf` +- `src_strides = [32, 1]` +- `trace_offsets = [0, 0]` +- `trace_sizes = [32, 32]` +- `cce_aiv_loop_hint` +- `llvm.loop.aivector_scope` +- `vpto.vlds` +- `vpto.vabs` +- `vpto.vsts` +- `vpto.copy_ubuf_to_gm` + +## Reporting Back + +When you ran `Abs`, report: +- whether `ptoas` had to be rebuilt +- the exact generated file path for the VPTO text +- whether the output contains the expected copy-family metadata and vec-scope carrier attrs + +If the build fails, include the first concrete blocker: +- generator mismatch in `build/` +- link failure in `ptoas` +- missing runtime env because `env.sh` was not sourced +- missing sample output file diff --git a/.codex/skills/ptoas-npu-validation-a5/SKILL.md b/.codex/skills/ptoas-npu-validation-a5/SKILL.md new file mode 100644 index 000000000..735cde327 --- /dev/null +++ b/.codex/skills/ptoas-npu-validation-a5/SKILL.md @@ -0,0 +1,335 @@ +--- +name: ptoas-npu-validation-a5 +description: Generate and run PTOAS-based A5 test/npu_validation or test/vpto validations, build the testcase binaries, and validate runtime 
output on NPU or simulator. Use when the user wants NPU run validation, golden/compare checks, or runtime troubleshooting for A5. +--- + +# PTOAS NPU Validation A5 + +Use this skill when the task is specifically about: +- generating `test/npu_validation` projects from PTOAS output +- running `test/vpto/scripts/run_host_vpto_validation.sh` +- running `test/vpto` board validation or simulator validation +- building testcase binaries for A5 +- running NPU or simulator validation +- generating golden inputs and checking results with `compare.py` +- diagnosing runtime blockers such as missing device access or `aclrtSetDevice` + +This skill is the main entry point for runtime validation. + +Do not use this skill as the primary entry point when the task is only about: +- exporting LLVM IR or LLVM bitcode +- validating the `bisheng` handoff +- assembling a fat object or replacement kernel library from the LLVM path + +When this validation flow needs a custom LLVM IR or LLVM BC artifact, use +`ptoas-vpto-llvm-artifacts` first to build that artifact, then return here to +run the testcase. + +## Important Constraint + +The `npu_validation` flow still depends on an EmitC-generated sample export to +materialize the host-side testcase skeleton. + +For the existing automation, this EmitC export step is not something the user +must run manually first. The provided host-validation scripts already do it for +you. + +Specifically: +- `run_host_npu_validation.sh` automatically invokes `test/samples/runop.sh` + first +- that export is written under `WORK_SPACE/emitc/...` +- `run_host_npu_validation_case.sh` then uses that generated EmitC `*-pto.cpp` + as the input to `generate_testcase.py` + +Even when the final kernel under validation comes from the VPTO/LLVM path, the +current scripts do not generate a standalone host runner from VPTO MLIR or +LLVM IR directly. The canonical automated flow is: + +1. 
`run_host_npu_validation.sh` automatically exports the sample through the + default EmitC path to get `*-pto.cpp` +2. `run_host_npu_validation_case.sh` runs `generate_testcase.py` on that + generated EmitC kernel to create the testcase directory, host `main.cpp`, + kernel wrapper source, `launch.cpp`, and build system +3. if LLVM/VPTO validation is desired, `run_host_npu_validation_case.sh` + optionally calls `build_llvm_ir_kernel_so.sh` to rebuild and replace only + the final `lib_kernel.so` +4. the generated testcase binary is then run against that replacement kernel + library + +In other words: +- the scripts automatically do the EmitC export step before testcase + generation +- EmitC is still required to produce the host/testcase scaffolding +- LLVM/VPTO replaces the device kernel library, not the host testcase +- feeding raw VPTO textual MLIR directly into `generate_testcase.py` is not a + supported path + +## Automation Entry Points + +Use these scripts as the default automation entry points instead of rebuilding +the flow by hand: + +- `test/vpto/scripts/run_host_vpto_validation.sh` + - top-level driver for curated VPTO `kernel.pto` board/simulator validation + - consumes hand-authored VPTO cases under `test/vpto/cases/...` + - handles lowering, LLVM-path device object build, host build, golden, and compare + - is the default entry point when the user asks to run VPTO board validation directly + - when it fails at runtime, follow this skill's troubleshooting guidance instead of treating the first `aclrtSetDevice` failure as a final product regression + +- `test/npu_validation/scripts/run_host_npu_validation.sh` + - top-level driver for host/NPU validation + - automatically runs `test/samples/runop.sh` first + - automatically writes the EmitC export under `WORK_SPACE/emitc/...` + - discovers testcase names from `test/samples//npu_validation/...` + - dispatches each testcase to `run_host_npu_validation_case.sh` + +- 
`test/npu_validation/scripts/run_host_npu_validation_case.sh` + - per-testcase execution driver + - consumes the already-generated EmitC kernel from `WORK_SPACE/emitc/...` + - runs `generate_testcase.py` + - configures and builds the testcase + - when `KERNEL_MODE=llvm`, calls `build_llvm_ir_kernel_so.sh` to replace the + device kernel shared library + - runs the testcase binary and then `compare.py` + +- `test/npu_validation/scripts/build_llvm_ir_kernel_so.sh` + - helper used by the case runner for LLVM/VPTO validation + - assumes the EmitC-derived testcase and host wrapper already exist + - rebuilds only the replacement `lib_kernel.so` + - its internal `runop.sh` export may return non-zero because another sample + in the same family failed, but the script intentionally continues if the + requested testcase's LLVM IR artifact was still produced + +## Preconditions + +Before running `npu_validation` or `test/vpto`, make sure: +- `ptoas` is already built in `./build` +- `bisheng` is in `PATH` or available through CANN `set_env.sh` +- `PTO_ISA_ROOT` points to a `pto-isa` checkout with: + - `include/` + - `tests/common/` +- the shell can read `/dev/davinci*` if you intend to execute on real hardware + +Example: + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +``` + +Useful runtime check: + +```bash +source /usr/local/Ascend/cann/set_env.sh +python3 - <<'PY' +import ctypes +lib = ctypes.cdll.LoadLibrary('libascendcl.so') +aclInit = lib.aclInit; aclInit.argtypes=[ctypes.c_char_p]; aclInit.restype=ctypes.c_int +aclrtGetDeviceCount = lib.aclrtGetDeviceCount; aclrtGetDeviceCount.argtypes=[ctypes.c_void_p]; aclrtGetDeviceCount.restype=ctypes.c_int +aclrtSetDevice = lib.aclrtSetDevice; aclrtSetDevice.argtypes=[ctypes.c_int]; aclrtSetDevice.restype=ctypes.c_int +cnt = ctypes.c_uint(0) +print('aclInit', aclInit(None)) +print('aclrtGetDeviceCount', aclrtGetDeviceCount(ctypes.byref(cnt)), cnt.value) +print('aclrtSetDevice', aclrtSetDevice(0)) +PY +``` + +Interpretation: +- 
`aclInit` succeeds +- `aclrtGetDeviceCount` should report at least one device if the runtime can enumerate hardware +- if `aclrtSetDevice(0)` fails with `507033` (`ACL_ERROR_RT_DEV_SETUP_ERROR`), the user context can see a device but cannot open a usable runtime context + +This interpretation applies equally to: + +- `test/npu_validation` +- `test/vpto` + +When `test/vpto/scripts/run_host_vpto_validation.sh` hits `aclrtSetDevice`, do not immediately report a testcase regression. First treat it as a runtime-environment blocker and follow the checks in this skill. + +## Canonical Flow + +### 1. Generate the PTOAS kernel + +Use the default EmitC-style output, because `npu_validation` consumes `*-pto.cpp`. + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-abs-emitc \ +./test/samples/runop.sh -t Abs +``` + +Expected output: +- `/tmp/ptoas-abs-emitc/Abs/abs-pto.cpp` +- this EmitC kernel is also the required host/testcase input for the later + LLVM/VPTO replacement flow + +### 2. Generate the `npu_validation` testcase + +```bash +python3 test/npu_validation/scripts/generate_testcase.py \ + --input /tmp/ptoas-abs-emitc/Abs/abs-pto.cpp \ + --testcase abs \ + --output-root /tmp/ptoas-npu-validation-run \ + --run-mode sim \ + --soc-version dav_3102 \ + --aicore-arch dav-c310-vec +``` + +Expected output directory: +- `/tmp/ptoas-npu-validation-run/Abs/abs` + +### 3. 
Configure and build + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cmake -S /tmp/ptoas-npu-validation-run/Abs/abs \ + -B /tmp/ptoas-npu-validation-run/Abs/abs/build \ + -DSOC_VERSION=dav_3102 \ + -DENABLE_SIM_GOLDEN=ON +cmake --build /tmp/ptoas-npu-validation-run/Abs/abs/build --parallel +``` + +Typical build expectations: +- `libabs_kernel.so` builds +- `abs` builds +- `abs_sim` may also build if the simulator runtime is available + +If you need to replace the default `libabs_kernel.so` with one assembled from +an LLVM IR or LLVM BC path, build that artifact with +`ptoas-vpto-llvm-artifacts` and place it first in `LD_LIBRARY_PATH` when +running `./build/abs`. + +Important: +- the LLVM/VPTO path does not bypass EmitC testcase generation +- `build_llvm_ir_kernel_so.sh` assumes the testcase was already generated from + the EmitC export and reuses its host wrapper/build artifacts + +### 4. Generate golden inputs + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 ./golden.py +``` + +Expected files: +- `v1.bin` +- `v2.bin` + +For the generated `Abs` testcase, `golden.py` does not emit `golden_v2.bin`, +but `compare.py` expects it. Build the oracle explicitly from the input: + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 - <<'PY' +import numpy as np +v1 = np.fromfile('v1.bin', dtype=np.float32) +np.abs(v1).astype(np.float32).tofile('golden_v2.bin') +PY +``` + +Expected additional file: +- `golden_v2.bin` + +## Running + +### NPU run + +Only attempt this on a shell that can actually see `/dev/davinci*`. 
+ +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cd /tmp/ptoas-npu-validation-run/Abs/abs +LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs +``` + +For the repo's automated host-validation flow, prefer the script's default +remote runner: + +```bash +HOST_RUNNER='ssh root@localhost' +``` + +This is already the default in `run_host_npu_validation.sh` / +`run_host_npu_validation_case.sh`, and it is the preferred way to reach a root +context on the local machine when passwordless root SSH is already configured. + +Use that path first instead of assuming `sudo` is available or passwordless. + +If you are not using the repo scripts and your environment explicitly supports +`sudo`, you may still retry manually with: + +```bash +sudo bash -lc ' + cd /tmp/ptoas-npu-validation-run/Abs/abs + source /usr/local/Ascend/cann/set_env.sh >/dev/null 2>&1 + LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs +' +``` + +Observed runtime result on this machine for the `Abs` testcase: +- normal user run failed at `aclrtSetDevice(0)` with `507033` +- root-context execution is expected to go through the script default + `ssh root@localhost` path when available +- `python3 ./compare.py` then reported `[INFO] compare passed` + +Observed runtime result on this machine for the VPTO LLVM-path host validation +of `PyPTOIRParser/paged_attention_example_kernel_online_update`: +- `test/npu_validation/scripts/run_host_npu_validation.sh` passed end-to-end +- the replacement kernel library from `build_llvm_ir_kernel_so.sh` was loaded + successfully +- `compare.py` reported `[INFO] compare passed` +- during the LLVM artifact export step, `runop.sh` returned non-zero because + `paged_attention_example_kernel_softmax_prepare` failed in the same sample + batch, but the requested `online_update` LLVM IR was still generated and the + validation flow remained valid + +### Simulator run + +If `abs_sim` 
links successfully, run it with simulator libraries in `LD_LIBRARY_PATH`. + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cd /tmp/ptoas-npu-validation-run/Abs/abs +LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/aarch64-linux/simulator/dav_3510/lib:${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs_sim +``` + +Treat simulator execution as optional. Depending on the local CANN install, the +simulator binary may link successfully but still fail at runtime due to missing +simulator services or runtime symbols. + +## Compare + +After generating `golden_v2.bin` and running the NPU binary, compare with: + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 ./compare.py +``` + +Expected success output: +- `[INFO] compare passed` + +## Known Failure Modes + +- `generate_testcase.py` fails because the input is not a PTOAS EmitC `*-pto.cpp` kernel +- configure fails because `PTO_ISA_ROOT` is unset or points to the wrong checkout +- `abs_sim` fails to link because simulator runtime symbols are missing +- `./build/abs` fails at `aclInit(nullptr)` because the shell does not have usable Ascend runtime access +- non-`sudo` `./build/abs` fails at `aclrtSetDevice(0)` with `507033`, meaning the user context sees the device but cannot open a usable runtime context +- `compare.py` reports `golden_v2.bin` missing because the testcase generation did not create it automatically + +## Reporting Back + +When you use this skill, report: +- the generated testcase directory +- whether `libabs_kernel.so`, `abs`, and `abs_sim` built +- whether `golden.py` generated input bins and whether `golden_v2.bin` had to be created explicitly +- whether NPU execution worked directly or required elevated privileges +- whether `compare.py` passed +- the first concrete blocker for NPU or simulator execution diff --git a/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md b/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md new file mode 100644 index 
000000000..d24d19c85 --- /dev/null +++ b/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md @@ -0,0 +1,319 @@ +--- +name: "ptoas-vpto-llvm-artifacts" +description: "Guide the PTOAS VPTO compile-and-link workflow: inspect VPTO MLIR, export LLVM IR or LLVM bitcode, validate the Bisheng handoff, and assemble device objects, fat objects, or shared kernel libraries. Use when the user asks how to build, export, compile, or link VPTO LLVM-path artifacts for A5." +--- + +# PTOAS VPTO LLVM Artifacts + +Use this skill when the task is specifically about: +- printing or inspecting VPTO intermediate MLIR +- exporting PTOAS A5 kernels as LLVM IR or LLVM bitcode through the VPTO backend +- checking whether the export is textual LLVM IR or real LLVM bitcode +- compiling the exported artifact with `bisheng` +- assembling a device object, fat relocatable object, or shared kernel library from the LLVM path +- helping with an "LLVM IR path build", "LLVM IR path compile", or "VPTO MLIR" request + +This skill answers: +- how to build or export the artifact +- how to hand the artifact to Bisheng +- how to continue from `.ll` / `.bc` to `.o` / `fatobj` / `.so` +- where each stage output is written + +This skill does not answer: +- which `llvm.hivm.*` intrinsic a VPTO op should lower to +- what the authoritative intrinsic name or operand contract is +- whether the repo-local emitter guessed the wrong LLVM IR form + +Those questions belong to `pto-a5-installed-impl-trace`. + +## Strong Rule + +Treat this skill as a compile-and-link workflow guide, not as the authority for +discovering intrinsic mappings. If the task turns into "what should this VPTO +op lower to" or "is this `llvm.hivm.*` form correct", switch to +`pto-a5-installed-impl-trace`. 
+ +This is not the primary entry point for: +- generating `test/npu_validation` testcases +- running on hardware, handling `aclrtSetDevice`, or deciding whether `sudo` is needed +- `golden.py` / `compare.py` result checks +- discovering the authoritative LLVM IR shape for a VPTO op + +If the end goal is runtime validation, use `ptoas-npu-validation-a5` as the main +skill and call this skill only when that flow needs a custom LLVM IR or LLVM BC +kernel artifact. + +## Preconditions + +Before using this path, make sure: +- `ptoas` is already built in `./build` +- `bisheng` is available through CANN `set_env.sh` +- `env.sh` can be sourced from the repo root +- for the fatobj path, you already have a generated testcase directory that + contains a wrapper source such as `abs_kernel.cpp` and a built `launch.cpp.o` + +Load the repo environment before running examples: + +```bash +set +u +source env.sh +set -u +``` + +Use the `set +u` form when the caller shell has `set -u`, because `env.sh` +appends to variables such as `PYTHONPATH` and `LD_LIBRARY_PATH`. + +## Inspect VPTO MLIR + +Use this when you need to look at the VPTO-stage IR before deciding whether to +continue to textual LLVM IR, LLVM bitcode, or the full artifact assembly flow. 
+ +Canonical flag: + +```bash +--vpto-print-ir +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-ir \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-print-ir' \ +./test/samples/runop.sh -t Abs +``` + +Use this output to: +- confirm the lowering has reached the VPTO dialect you expect +- inspect whether a transformation issue appears before LLVM export +- compare the VPTO MLIR path against the later LLVM IR or bitcode output + +## Export Paths + +### LLVM bitcode export + +Use: + +```bash +--pto-backend=vpto --vpto-emit-hivm-bc +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-hivm-bc \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-emit-hivm-bc' \ +./test/samples/runop.sh -t Abs +``` + +Typical outputs: +- `/tmp/ptoas-vpto-hivm-bc/Abs/abs-pto-ir.pto` +- `/tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp` + +Important: +- the payload is written to `*-pto.cpp` even in bitcode mode +- that file is LLVM bitcode, not C++ source + +Bitcode checks: + +```bash +file /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp +xxd -l 16 /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp +"$LLVM_ROOT/bin/llvm-dis" /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp -o - | sed -n '1,80p' +``` + +Expected signs: +- `file` reports `LLVM IR bitcode` +- the header starts with `42 43 c0 de` +- `llvm-dis` shows HiVM/LLVM content + +### Textual LLVM IR export + +Use: + +```bash +--pto-backend=vpto --vpto-emit-hivm-llvm +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-hivm-llvm \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-emit-hivm-llvm' \ +./test/samples/runop.sh -t Abs +``` + +Typical output: +- `/tmp/ptoas-vpto-hivm-llvm/Abs/abs-pto.cpp` + +Important: +- despite the `.cpp` suffix, this file is textual LLVM IR +- compile it with `-x ir` + +Suggested progression: +- start with `--vpto-print-ir` when the user 
wants the intermediate VPTO form +- use `--vpto-emit-hivm-llvm` when the user wants textual LLVM IR +- use `--vpto-emit-hivm-bc` when the user wants real LLVM bitcode + +## Compile The Export With Bisheng + +Load the CANN environment first: + +```bash +source /usr/local/Ascend/cann/set_env.sh +``` + +### Compile bitcode to a device object + +Preferred: + +```bash +bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -O2 \ + -c -x ir /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp \ + -o /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.o +``` + +Alternative: +- copy or rename the payload to `.bc` +- compile without relying on the misleading `.cpp` suffix + +### Compile textual LLVM IR to a device object + +```bash +bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -O2 \ + -c -x ir /tmp/ptoas-vpto-hivm-llvm/Abs/abs-pto.cpp \ + -o /tmp/abs_ir_path_artifacts/kernel_from_llvm_ir.o +``` + +Checks: +- keep `-march` and `--cce-aicore-arch` aligned with the intended testcase arch +- for the LLVM IR path, the resulting object should not retain unresolved + `llvm.hivm.*` symbols + +## If You Need The Real Compiler-Expected Intrinsic Shape + +This is outside the main purpose of this skill. + +When a hand-written LLVM IR path fails in instruction selection or appears to +miscompile, use this trace order: + +1. confirm the installed PTO wrapper path first with `pto-a5-installed-impl-trace` +2. generate the normal testcase kernel source through the working emitc path +3. inspect testcase compile flags from: + - `/build/CMakeFiles/.dir/flags.make` + - `/build/CMakeFiles/.dir/build.make` +4. rerun that same `bisheng` compile with `-v` and `-save-temps` +5. inspect: + - `*.ccei` to confirm the wrapper builtin sequence + - `strings *.bc | rg 'llvm.hivm\\.'` to see which HIVM intrinsics survive +6. 
if builtin names still are not enough, extract the exact frontend-produced + LLVM IR by replaying the `cc1` invocation from `-v` with `-emit-llvm -S` + +Use this when you need to answer questions such as: +- is the intrinsic name correct but the mask form wrong +- did the compiler expect a `plt/pset` result instead of a literal mask +- is the LLVM IR path missing hidden frontend-generated structure or attrs + +This is the preferred way to align repo-local LLVM emission with the real +compiler contract. + +## Assemble Fat Objects And Shared Libraries + +Use this only when the validation flow needs a replacement kernel library built +from the LLVM path. The canonical example below uses the generated `Abs` +testcase, but the pattern is the same for other testcases: take the testcase +wrapper source, embed the device object, pack it with `cce-ld`, then link the +shared kernel library. + +Required testcase artifacts: +- a wrapper source such as `/tmp/ptoas-npu-validation-run/Abs/abs/abs_kernel.cpp` +- a built launch object such as + `/tmp/ptoas-npu-validation-run/Abs/abs/build/CMakeFiles/abs_kernel.dir/launch.cpp.o` + +### 1. Build the host stub object + +```bash +/usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/bin/bisheng -cc1 \ + -triple aarch64-unknown-linux-gnu \ + -fcce-is-host \ + -fcce-fatobj-compile \ + -fcce-include-aibinary /tmp/abs_ir_path_artifacts/kernel_from_llvm_ir.o \ + -fcce-device-module-id a55ab1efc0defeed \ + -fcce-aicore-arch dav-c310-vec \ + -x cce /tmp/ptoas-npu-validation-run/Abs/abs/abs_kernel.cpp \ + -o /tmp/abs_ir_path_artifacts/kernel_host_stub.o +``` + +### 2. 
Pack the fat relocatable object + +```bash +/usr/local/Ascend/cann-9.0.0/bin/cce-ld \ + /usr/local/Ascend/cann-9.0.0/bin/ld.lld \ + -x \ + -cce-lite-bin-module-id a55ab1efc0defeed \ + -cce-aicore-arch=dav-c310-vec \ + -r \ + -o /tmp/abs_ir_path_artifacts/kernel_fat.o \ + -cce-stub-dir /usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/lib/clang/15.0.5/include/cce_stub \ + -cce-install-dir /usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/bin \ + -cce-inputs-number 1 \ + /tmp/abs_ir_path_artifacts/kernel_host_stub.o +``` + +The module id must match between: +- `-fcce-device-module-id` +- `-cce-lite-bin-module-id` + +### 3. Link the shared kernel library + +```bash +mkdir -p /tmp/abs_ir_path_artifacts/link_try +cd /tmp/abs_ir_path_artifacts/link_try +/usr/local/Ascend/cann-9.0.0/bin/bisheng \ + -fPIC -s -Wl,-z,relro -Wl,-z,now --cce-fatobj-link \ + -shared -Wl,-soname,libabs_kernel.so \ + -o libabs_kernel.so \ + /tmp/abs_ir_path_artifacts/kernel_fat.o \ + /tmp/ptoas-npu-validation-run/Abs/abs/build/CMakeFiles/abs_kernel.dir/launch.cpp.o +``` + +This skill stops at producing the replacement artifact. To run the testcase +with that library and validate outputs, switch back to `ptoas-npu-validation-a5`. 
+ +## Failure Modes + +Report the first concrete blocker: +- `--vpto-print-ir`, `--vpto-emit-hivm-bc`, or `--vpto-emit-hivm-llvm` used without `--pto-backend=vpto` +- `env.sh` was not sourced, or failed under `set -u` +- `bisheng` was not found or CANN environment was not loaded +- a bitcode payload was treated as source because it kept a misleading suffix +- the testcase wrapper or `launch.cpp.o` is missing for the fatobj path +- the module ids used for stub creation and `cce-ld` packing do not match + +## Reporting Back + +When you use this skill, report: +- whether the exported artifact of interest was VPTO MLIR, textual LLVM IR, or LLVM bitcode +- the exact `ptoas` flags used +- the exact output path that contains the exported payload +- whether `llvm-dis`, `file`, or direct inspection confirmed the payload type +- whether `bisheng` produced a device object +- whether the flow also produced a fat relocatable object or shared kernel library +- which step was the first blocker, if the full artifact chain did not complete diff --git a/.gitignore b/.gitignore index 44c61b02a..4fbeb8f5e 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,11 @@ dist/ /remote_npu_validation_results*.tsv /npu_validation/ test/samples/**/npu_validation/ +!test/samples/**/npu_validation/ +test/samples/**/npu_validation/* +!test/samples/**/npu_validation/golden.py +!test/samples/**/npu_validation/*/ +!test/samples/**/npu_validation/*/golden.py /tmp_gen* # IDE/editor diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..9ae183956 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rdparty/PTO-Gym"] + path = 3rdparty/PTO-Gym + url = git@github.com:PTO-ISA/PTO-Gym.git diff --git a/3rdparty/PTO-Gym b/3rdparty/PTO-Gym new file mode 160000 index 000000000..8a186eae3 --- /dev/null +++ b/3rdparty/PTO-Gym
@@ -0,0 +1 @@ +Subproject commit 8a186eae3befc4f1417f4618addbd9e942339acd diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index e19bb6229..53d5c6c88 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -71,15 +71,20 @@ Element type constraints are operation-specific: In addition, memory layout and address space do not change the element type semantics; they only affect placement and access patterns. -### 2.2 `!pto.ptr<elementType>` +### 2.2 `!pto.ptr<elementType, memorySpace>` -A pointer to global memory. +A typed pointer. `memorySpace` is optional and defaults to `gm`. | Parameter | Type | Description | |-----------|------|-------------| | `elementType` | `element-type(i1/i8/i16/i32/f16/f32/bf16...)` | Element type pointed to | +| `memorySpace` | `gm` or `ub` | Pointer address space alias (`gm` -> global memory, `ub` -> vector/UB memory) | -**Syntax:** `!pto.ptr<elementType>` +**Syntax:** `!pto.ptr<elementType>` or `!pto.ptr<elementType, memorySpace>` + +Pointer conversions are modeled explicitly with [`pto.castptr`](#ptocastptr). +Between two `!pto.ptr` types, casts are only legal when both pointers stay in +the same PTO memory space. --- @@ -302,6 +307,38 @@ result = ptr + offset // offset is in elements, not bytes %ptr_off = pto.addptr %base, %offset : !pto.ptr<f32> -> !pto.ptr<f32> ``` +##### `pto.castptr` - Explicit Pointer Cast + +**Summary:** Performs an explicit cast between integer addresses and `!pto.ptr`, +or between two `!pto.ptr` types.
+ +**Semantics:** + +```mlir +%p0 = pto.castptr %addr : i64 -> !pto.ptr<f32, gm> +%p1 = pto.castptr %p0 : !pto.ptr<f32, gm> -> !pto.ptr<i8, gm> +%addr2 = pto.castptr %p1 : !pto.ptr<i8, gm> -> i64 +``` + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `input` | `integer` or `!pto.ptr<...>` | Source value to cast | + +**Results:** `integer` or `!pto.ptr<...>` + +**Constraints & Verification:** + +- Integer-to-integer casts are rejected; use normal integer cast ops instead +- Pointer-to-pointer casts are only legal when source and destination stay in + the same PTO memory space (`gm` or `ub`) +- The operation is pure (no side effects) + +**Hardware Mapping:** + +- No hardware pipeline (representation conversion only) + ##### `pto.make_tensor_view` - Create Tensor View **Summary:** Constructs a global tensor view from a pointer, declaring the physical base and strides (no allocation, no data movement). diff --git a/docs/isa/01-pipeline-sync.md b/docs/isa/01-pipeline-sync.md new file mode 100644 index 000000000..c7b32b677 --- /dev/null +++ b/docs/isa/01-pipeline-sync.md @@ -0,0 +1,591 @@ +# 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro-instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +## Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe.
+ +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. 
+ +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. 
Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +## Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +### 2. 
No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +### 3. 
Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. 
Dependencies resolved by buffer ID + program order. +``` + +--- + +## Intra-Core Sync Patterns & Examples + +### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. 
+ +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). 
+ +--- + +### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +#### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +#### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. 
+pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... 
+ // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +#### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... 
+ pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. 
+``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +## Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +## Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. 
Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` 
+
+- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3.
+
+### `pto.wait_flag_dev`
+
+- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64`
+- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls.
+
+### A5: `set_intra_block` / `wait_intra_core`
+
+```c
+set_intra_block(trigger_pipe, semaphore_id);
+wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls
+```
+
+**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock:
+
+| ID Range | Target |
+|----------|--------|
+| 0–15 | AIV0 (subblock 0) |
+| 16–31 (+16 offset) | AIV1 (subblock 1) |
+
+This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce).
+
+```
+C→V on A5 (1:1, no broadcast — need two sets):
+  AIC ──set_intra_block(pipe, sema_id)────> AIV0
+  AIC ──set_intra_block(pipe, sema_id+16)──> AIV1
+
+V→C on A5 (1:1, no reduce — need two waits):
+  AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id)
+  AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+16) // extra wait
+```
+
+### `pto.set_intra_block`
+
+- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64`
+- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock.
+
+### `pto.wait_intra_core`
+
+- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64`
+- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. 
+ +### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` diff --git a/docs/isa/02-dma-copy.md b/docs/isa/02-dma-copy.md new file mode 100644 index 000000000..8d867af08 --- /dev/null +++ b/docs/isa/02-dma-copy.md @@ -0,0 +1,602 @@ +# 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](01-pipeline-sync.md)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +## Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +## Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +## DMA Transfer Execution + +### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +## Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. 
+- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. 
+``` + +--- + +## Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +## Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... 
+ +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... 
+ +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... 
+ +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... 
+ +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). 
+``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → 
ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` diff --git a/docs/isa/03-vector-load-store.md b/docs/isa/03-vector-load-store.md new file mode 100644 index 000000000..bb840e44b --- /dev/null +++ b/docs/isa/03-vector-load-store.md @@ -0,0 +1,595 @@ +# 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +## Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +## Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). 
+ +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. 
+ +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). 
+ +### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +## Contiguous Loads + +### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. 
+ +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. 
+- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. 
+- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +## Dual Loads (Deinterleave) + +### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. 
`%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +## Gather (Indexed) Loads + +### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. 
+- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +## Contiguous Stores + +### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. 
+- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +## Dual Stores (Interleave) + +### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). 
+- **inputs:**
+  `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer,
+  `%offset` is the displacement, `DIST` selects the interleave layout, and
+  `%mask` gates the participating elements.
+- **outputs:**
+  This op has no SSA result; it writes an interleaved stream to UB.
+- **constraints and limitations:**
+  This family is only legal for interleave distributions. The two source
+  vectors form an ordered pair, and the interleave semantics of that pair MUST
+  be preserved. PTO surface accepts the `INTLV` family, which only supports the
+  element widths listed below.
+- **latency:** `INTLV` is **12** cycles.
+
+**Distribution families:**
+
+| Family | Allowed element widths | C semantics | Latency |
+|------|-------------|-------------|-------------|
+| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles |
+
+```c
+// INTLV family on 32-bit elements:
+for (int i = 0; i < 64; i++) {
+  UB[base + 8*i] = low[i];
+  UB[base + 8*i + 4] = high[i];
+}
+```
+
+### `pto.vsstb`
+
+- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask`
+- **semantics:** Block-strided store for 2D tile access.
+- **inputs:**
+  `%value` is the source vector, `%dest` is the UB base pointer,
+  `%block_stride` and `%repeat_stride` are the two 16-bit fields of the
+  hardware control word, and `%mask` controls block participation.
+- **outputs:**
+  This op writes UB memory and returns no SSA value.
+- **constraints and limitations:**
+  PTO surface does not expose the packed control word directly. Masked-off
+  blocks MUST NOT issue memory writes.
+- **Latency:** **9** cycles.
+
+```c
+// Block-strided store on 32-bit elements: one 32B block = 8 lanes.
+for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +## Scatter (Indexed) Stores + +### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +## Alignment State Stores + +### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. 
`%dest` and `%offset` together must identify the
+  same logical flush point produced by the immediately preceding stateful
+  unaligned-store step on that stream; using an unrelated base/offset pair is
+  invalid even if `%value` itself came from the same stream.
+- **Latency:** **9** cycles.
+
+---
+
+### `pto.vstar`
+
+- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr`
+- **semantics:** Flush remaining alignment state using the register-update form.
+- **inputs:**
+  `%value` is the pending store-alignment state that still needs to be emitted,
+  and `%dest` is the UB destination base pointer.
+- **outputs:**
+  This op writes the remaining buffered tail bytes to UB and returns no SSA
+  value.
+- **constraints and limitations:**
+  This op terminates an unaligned-store sequence. The implicit update state
+  consumed by this flush MUST correspond to the same store stream that produced
+  `%value`, so it MUST be paired with a compatible prior state-producing store
+  sequence that leaves the pending tail state well-defined. The first
+  store-side state in a stream should be created by `pto.init_align`.
+- **Latency:** **9** cycles.
+
+---
+
+## Stateful Store Ops
+
+These ops make reference-updated state explicit as SSA results.
+
+### `pto.vstus`
+
+- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align`
+- **semantics:** No-post unaligned store with scalar offset.
+- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. 
A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. diff --git a/docs/isa/04-predicate-load-store.md b/docs/isa/04-predicate-load-store.md new file mode 100644 index 000000000..9c3bed11d --- /dev/null +++ b/docs/isa/04-predicate-load-store.md @@ -0,0 +1,135 @@ +# 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +## Predicate Loads + +### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. 
+ - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +## Predicate Stores + +### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. 
+
+`pto.psts` stores the packed predicate payload represented by `!pto.mask`.
+It only models the explicit `base[offset]` form.
+
+**Example:**
+```mlir
+pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index
+```
+
+---
+
+### `pto.psti`
+
+- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index`
+- **offset:** must be a constant `index` immediate in PTO surface form.
+- **semantics:** Store predicate register with immediate offset.
+- **DIST:** mandatory string token, one of `NORM`, `PK`.
+  - `NORM`: store the packed predicate payload into a normal destination space
+    of size `VL/8`.
+  - `PK`: store the packed predicate payload into a destination space of size
+    `VL/16`, keeping one bit out of every two bits.
+
+`pto.psti` and `pto.psts` store the packed predicate payload represented by
+`!pto.mask`. The surface distinction is only immediate-offset versus
+dynamic-offset.
+
+---
+
+### `pto.pstu`
+
+- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr`
+- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`.
+- **outputs:**
+  `%align_out` and `%base_out` are the updated unaligned-store state and are
+  intended to be used by a later `pto.pstu` call.
+- **constraints and limitations:**
+  The first `%align_in` in a predicate unaligned-store stream should come from
+  `pto.init_align`.
+
+---
+
+## Typical Usage Pattern
+
+```mlir
+// Generate comparison mask
+%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask
+
+// Store mask to UB for later use
+pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index
+
+// ... later in another kernel ...
+ +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/05-materialization-predicate.md b/docs/isa/05-materialization-predicate.md new file mode 100644 index 000000000..e6ee34975 --- /dev/null +++ b/docs/isa/05-materialization-predicate.md @@ -0,0 +1,322 @@ +# 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +## Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +## Scalar Materialization + +### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. 
+- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +## Predicate Generation + +### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. 
+- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +## Predicate Pack/Unpack + +### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. 
+ +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +## Predicate Logical Ops + +### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. 
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? (src0[i] | src1[i]) : 0;
+```
+
+---
+
+### `pto.pxor`
+
+- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate bitwise XOR gated by a governing predicate.
+
+Inactive lanes selected out by `%mask` are zeroed.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0;
+```
+
+---
+
+### `pto.pnot`
+
+- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate bitwise NOT gated by a governing predicate.
+
+Inactive lanes selected out by `%mask` are zeroed.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? (~src[i]) : 0;
+```
+
+---
+
+### `pto.psel`
+
+- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate select (mux). `%sel` is the governing predicate that
+  chooses lanes from `%src0` or `%src1`.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = sel[i] ? src0[i] : src1[i];
+```
+
+---
+
+### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32`

+- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask`
+- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask`
+- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask`
+- **semantics:** De-interleave two predicate sources and return the two
+  de-interleaved predicate images in the same predicate element family.
+
+---
+
+### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32`
+
+- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask`
+- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask`
+- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask`
+- **semantics:** Interleave two predicate sources and return the two
+  resulting predicate images in the same predicate element family.
+
+---
+
+## Typical Usage
+
+```mlir
+// Generate all-active mask for f32 (64 lanes)
+%all = pto.pset_b32 "PAT_ALL" : !pto.mask
+
+// Generate tail mask for an 8-element remainder (first 8 lanes active)
+%tail = pto.pge_b32 "PAT_VL8" : !pto.mask
+
+// Compare and generate mask
+%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask
+
+// Combine masks: only process tail elements that passed comparison
+%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask
+
+// Use for predicated operation
+%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
diff --git a/docs/isa/06-unary-vector-ops.md b/docs/isa/06-unary-vector-ops.md
new file mode 100644
index 000000000..2706ac39b
--- /dev/null
+++ b/docs/isa/06-unary-vector-ops.md
@@ -0,0 +1,172 @@
+# 6. Unary Vector Ops
+
+> **Category:** Single-input vector operations
+> **Pipeline:** PIPE_V (Vector Core)
+
+Element-wise operations that take one vector input and produce one vector output.
+
+## Common Operand Model
+
+- `%input` is the source vector register value.
+- `%mask` is the predicate operand. For this family, inactive lanes follow the
+  predication behavior of the selected instruction form: zeroing forms
+  zero-fill inactive lanes, while merging forms preserve the destination value.
+- `%result` is the destination vector register value.
Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +## Arithmetic + +### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. 
+ +--- + +## Transcendental + +### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +## Activation + +### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. 
+- **constraints and limitations:** Only floating-point element types are legal
+  on the current A5 surface described here.
+
+---
+
+## Bitwise
+
+### `pto.vnot`
+
+- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **A5 types:** all integer types
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = ~src[i];
+```
+
+- **inputs:** `%input` is the source vector and `%mask` selects active lanes.
+- **outputs:** `%result` holds the lane-wise bitwise inversion.
+- **constraints and limitations:** Integer element types only.
+
+---
+
+## Movement
+
+> NOTE(review): `pto.vmov` is listed in the CA latency table above (as an
+> `RV_VLD` proxy) but has no detailed entry in this section yet — confirm its
+> semantics and fill in.
+
+## Typical Usage
+
+```mlir
+// Softmax numerator: exp(x - max)
+%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+
+// ReLU activation
+%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
diff --git a/docs/isa/07-binary-vector-ops.md b/docs/isa/07-binary-vector-ops.md
new file mode 100644
index 000000000..0ab4ae554
--- /dev/null
+++ b/docs/isa/07-binary-vector-ops.md
@@ -0,0 +1,293 @@
+# 7. Binary Vector Ops
+
+> **Category:** Two-input vector operations
+> **Pipeline:** PIPE_V (Vector Core)
+
+Element-wise operations that take two vector inputs and produce one vector output.
+
+## Common Operand Model
+
+- `%lhs` and `%rhs` are the two source vector register values.
+- `%mask` is the predicate operand `Pg` that gates which lanes participate.
+- `%result` is the destination vector register value. Unless explicitly noted,
+  it has the same lane count and element type as the inputs.
+- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST
+  have matching vector shapes and element types.
+
+## CA latency (A5, Ascend910_9599 CA)
+
+Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces.
**bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +## Arithmetic + +### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. 
+ +--- + +### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +## Bitwise + +### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. 
+ +--- + +### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +## Shift + +### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. 
Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +## Carry Operations + +### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. 
+ +--- + +## Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/08-vec-scalar-ops.md b/docs/isa/08-vec-scalar-ops.md new file mode 100644 index 000000000..9ef60d3cb --- /dev/null +++ b/docs/isa/08-vec-scalar-ops.md @@ -0,0 +1,236 @@ +# 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. 
+- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +## Arithmetic + +### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. 
+- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +## Shift + +### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. 
+ +--- + +### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +## Carry Operations + +### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. 
+ +--- + +### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. 
+ +--- + +## Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/09-conversion-ops.md b/docs/isa/09-conversion-ops.md new file mode 100644 index 000000000..efb3a9ed4 --- /dev/null +++ b/docs/isa/09-conversion-ops.md @@ -0,0 +1,252 @@ +# 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. 
+
+| PTO op | RV (CA) | Note | Latency |
+|--------|---------|------|---------|
+| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** |
+| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — |
+
+---
+
+## `pto.vci`
+
+- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg`
+- **semantics:** Generate a lane-index vector from a scalar seed/index value.
+- **inputs:**
+  `%index` is the scalar seed or base index.
+- **outputs:**
+  `%result` is the generated index vector.
+- **constraints and limitations:**
+  This is an index-generation family, not a numeric conversion. `ORDER` and the
+  result element type together determine how indices are generated. `%result`
+  uses an integer element type, and the scalar `%index` type matches that
+  result element type.
+
+---
+
+## `pto.vcvt`
+
+- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Type conversion between float/int types with rounding control.
+
+```c
+for (int i = 0; i < min(N, M); i++)
+  if (mask[i])
+    dst[i] = convert(src[i], T0, T1, rnd);
+```
+
+- **inputs:**
+  `%input` is the source vector, `%mask` selects active lanes, and attributes
+  select rounding, saturation, and output placement when the conversion changes
+  width or packs into sub-lane positions.
+- **outputs:**
+  `%result` is the converted vector.
+- **constraints and limitations:**
+  Only documented source/destination type pairs are legal. All three
+  attributes are optional at the surface level, but only the subset meaningful
+  to the selected conversion kind should be provided. The execution mask must
+  use the typed-mask granularity that matches the source vector family on the
+  current surface; there is no unmasked `pto.vcvt` form in VPTO. 
+ +--- + +### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. 
+ +#### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +#### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +#### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +#### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : 
!pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. 
+ +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +## `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. 
+ +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +## Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/10-reduction-ops.md b/docs/isa/10-reduction-ops.md new file mode 100644 index 000000000..b2fb20894 --- /dev/null +++ b/docs/isa/10-reduction-ops.md @@ -0,0 +1,244 @@ +# 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +## Full Vector Reductions + +### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. 
+ +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. 
The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +## Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). 
+ +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. 
+- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +## Prefix Operations + +### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +## Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/11-compare-select.md b/docs/isa/11-compare-select.md new file mode 100644 index 000000000..bc28f2fd1 --- /dev/null +++ b/docs/isa/11-compare-select.md @@ -0,0 +1,182 @@ +# 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. 
+ +## Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +## Comparison Operations + +### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 
1 : 0;
+```
+
+**Example:**
+```mlir
+%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt"
+    : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask
+// positive_mask[i] = 1 if values[i] > 0
+```
+
+- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison
+  value, and `%seed` is the incoming predicate.
+- **outputs:** `%result` is the generated predicate mask.
+- **constraints and limitations:** For 32-bit scalar forms, the scalar source
+  MUST satisfy the backend's legal scalar-source constraints for this family.
+  `%seed` and `%result` keep the typed-mask granularity that matches `%src`.
+
+---
+
+## Selection Operations
+
+### `pto.vsel`
+
+- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Per-lane select based on mask.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? src0[i] : src1[i];
+```
+
+**Example — Conditional assignment:**
+```mlir
+// dst = mask ? true_vals : false_vals
+%result = pto.vsel %true_vals, %false_vals, %condition
+    : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
+
+- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector,
+  and `%mask` selects between them.
+- **outputs:** `%result` is the selected vector.
+- **constraints and limitations:** Source vectors and result MUST have matching
+  vector shapes and element types. `%mask` keeps the typed-mask granularity
+  that matches the selected vector family.
+
+---
+
+### `pto.vselr`
+
+- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg -> !pto.vreg`
+- **semantics:** Lane-select by index vector.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = src[idx[i]];
+```
+
+- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector.
+- **outputs:** `%result` is the reordered vector.
+- **constraints and limitations:** `%idx` must use integer elements. 
`%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +## Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +## Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. 
Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/12-data-rearrangement.md b/docs/isa/12-data-rearrangement.md new file mode 100644 index 000000000..359e7c306 --- /dev/null +++ b/docs/isa/12-data-rearrangement.md @@ -0,0 +1,230 @@ +# 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +## Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +## Interleave / Deinterleave + +### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. 
+ +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +## Compress / Expand + +### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. 
The final predicate lane
+  does not contribute to a later output lane because there is no `dst[N]`.
+
+---
+
+
+
+### `pto.vselr`
+
+- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg -> !pto.vreg`
+- **semantics:** Register lane-select with an explicit index vector.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = src[idx[i]];
+```
+
+- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector.
+- **outputs:** `%result` is the reordered vector.
+- **constraints and limitations:** This page records the rearrangement use of
+  the family; the compare/select page documents the same name from the predicate
+  selection perspective.
+
+---
+
+## Pack / Unpack
+
+### `pto.vpack`
+
+- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>`
+- **semantics:** Narrow one wide vector and place the narrowed payload into the
+  selected half of the result. The other half is filled with zero.
+
+```c
+// e.g., vreg<64xi32> → vreg<128xui16>
+for (int i = 0; i < N; i++)
+  dst[i] = 0;
+
+if (part == LOWER) {
+  for (int i = 0; i < N; i++)
+    dst[i] = truncate(src[i]);
+} else { // HIGHER
+  for (int i = 0; i < N; i++)
+    dst[N + i] = truncate(src[i]);
+}
+```
+
+- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"`
+  select whether the narrowed payload lands in the lower or upper half.
+- **outputs:** `%result` is the packed narrow vector.
+- **constraints and limitations:** Packing is a narrowing conversion with
+  truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and
+  `i16/ui16 -> ui8`.
+
+---
+
+### `pto.vsunpack`
+
+- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg`
+- **semantics:** Sign-extending unpack — narrow to wide (half). 
+ +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +## Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +## V2 Interleave Forms + +### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. 
+ +### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. diff --git a/docs/isa/13-dsa-sfu-ops.md b/docs/isa/13-dsa-sfu-ops.md new file mode 100644 index 000000000..731fa71b2 --- /dev/null +++ b/docs/isa/13-dsa-sfu-ops.md @@ -0,0 +1,229 @@ +# 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +## Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +## Fused Activation Ops (vreg→vreg) + +### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. 
+ +--- + +### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +## Fused Compute+Convert Ops + +### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. 
+ +--- + + +## Extended Arithmetic + +### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +## Index Generation + +### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. 
+- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +## Sorting Operations + +### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. 
+- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +## Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +## Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/14-shared-arith.md b/docs/isa/14-shared-arith.md new file mode 100644 index 000000000..6c703dc55 --- /dev/null +++ b/docs/isa/14-shared-arith.md @@ -0,0 +1,99 @@ +# 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. 
These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +## Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +## Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. 
+ +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +## Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. 
+ +--- + +## Typical Patterns + +### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +## Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` diff --git a/docs/isa/15-shared-scf.md b/docs/isa/15-shared-scf.md new file mode 100644 index 000000000..12a637fd7 --- /dev/null +++ b/docs/isa/15-shared-scf.md @@ -0,0 +1,97 @@ +# 15. 
SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +## Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. 
+ +--- + +## Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +## Typical Patterns + +### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +## Authoring Guidance + +- use `scf.for` for regular counted loops 
and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value diff --git a/docs/release/vpto-spec-v0.1.md b/docs/release/vpto-spec-v0.1.md new file mode 100644 index 000000000..a2949a485 --- /dev/null +++ b/docs/release/vpto-spec-v0.1.md @@ -0,0 +1,4885 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. 
+ +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. 
+ +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. 
The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement:
+
+**Address Space Isolation**: The IR uses `!pto.ptr<T, space>` to distinguish between GM (`!pto.ptr<T, gm>`) and UB (`!pto.ptr<T, ub>`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB.
+
+**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile").
+
+**Data Flow**:
+
+```
+┌─────────────────────────────────────────────┐
+│             Global Memory (GM)              │
+│             (Off-chip HBM/DDR)              │
+└─────────────────────┬───────────────────────┘
+                      │ DMA (MTE2 inbound / MTE3 outbound)
+┌─────────────────────▼───────────────────────┐
+│             Unified Buffer (UB)             │
+│           (On-chip SRAM, 256KB)             │
+└─────────────────────┬───────────────────────┘
+                      │ Vector Load/Store (PIPE_V)
+┌─────────────────────▼───────────────────────┐
+│         Vector Register File (VRF)          │
+│   vreg (256B each) + mask (256-bit each)    │
+└─────────────────────────────────────────────┘
+```
+
+1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`)
+2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.)
+3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.)
+4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.)
+5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`)
+
+**Load/Store Access Patterns**:
+
+For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification.
+ +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. 
+ +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +It is not a dedicated `pto` op. In the PTO micro Instruction, this scope is modeled as a specialized `scf.for` loop annotated with `llvm.loop.aivector_scope`. This gives the compiler a natural structural boundary for identifying the code block that must be lowered into a discrete VF hardware instruction sequence. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. 
+ +**MLIR Representation:** + +```mlir +scf.for %dummy = %c0 to %c1 step %c1 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} +``` + +### Example: Abs + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +scf.for %dummy = %c0 to %c1 step %c1 { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} {llvm.loop.aivector_scope} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. 
+
+It only describes:
+
+- operation names
+- operand and result lists
+- operand and result types
+- important attributes
+- C-style semantics for each operation
+
+It does not describe lowering strategy.
+
+PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations.
+
+### Shared MLIR Dialects
+
+- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects.
+- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops.
+- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions.
+
+### Core Types
+
+#### Element Types
+`vreg`: `!pto.vreg<NxT>` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`.
+ +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +#### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +#### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. 
+ +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +#### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +scf.for %arg2 = %c0 to %c1 step %c1 { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} {llvm.loop.aivector_scope} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +#### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit), not an integer vector. 
+ +**Mask Granularity:** + +The mask is 256 bits in length, where each bit controls 1 byte of data. This means mask granularity varies by element type: + +| Element Type | Bits/Element | Mask Bits per Element | +|--------------|--------------|----------------------| +| `f32`/`i32` | 32 | 4 bits | +| `f16`/`bf16`/`i16` | 16 | 2 bits | +| `f8`/`i8` | 8 | 1 bit | + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. 
%lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. 
+ +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. 
The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. 
+ +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. 
|
+| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` |
+| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` |
+| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` |
+| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` |
+| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` |
+| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` |
+| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` |
+| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` |
+| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc.
| +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. 
+ +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. 
+ +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. 
Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +scf.for %dummy = %c0 to %c1 step %c1 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, 
i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. 
+pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... 
+ // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... 
+ pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. 
+```
+
+**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies:
+- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]`
+- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer)
+- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed
+
+---
+
+#### Comparison Summary
+
+| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` |
+|--------|--------------------------|------------------------|
+| Dependency model | Explicit event signals | Implicit via buffer acquire/release |
+| IDs per pipe-pair | **4** = 2 buffers × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) |
+| Total HW IDs | 8 = 4 × 2 pipe-pairs, grows with buffers | **32 global** across all pipes |
+| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically |
+| Pre-loop setup | `set_flag` to prime each reverse dep | None |
+| Post-loop teardown | `wait_flag` to drain all primed signals | None |
+| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) |
+| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead |
+| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops |
+
+---
+
+#### Inter-Core Sync
+
+> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.**
+
+These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams.
+ +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. 
+ +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+16 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+16)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+16) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. 
+ +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + +--- + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. 
+- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. 
+``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... 
+ +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... 
+ +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... 
+ +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... 
+ +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). 
+``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → 
ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + +--- + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. 
Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | +|------|-------------|-------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | +| `BRC_B8/B16/B32` | Broadcast single element | `dst[i] = UB[base]` for all i | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | +| `BLK` | Block load | Blocked access pattern | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. 
+ +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. 
+ +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. 
+- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. 
+ +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | +|------|-------------|-------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. 
The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. 
+- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. 
+- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. 
+- **constraints and limitations:**
+  This op updates only the residual alignment state. A matching flush op is
+  still required to emit the trailing bytes.
+
+---
+
+##### `pto.vstar`
+
+- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr`
+- **semantics:** Flush remaining alignment state.
+- **inputs:**
+  `%value` is the pending alignment/buffer state that still needs to be emitted,
+  and `%dest` is the UB destination base pointer.
+- **outputs:**
+  No SSA result. The effect is a memory-side flush that writes the remaining
+  buffered bytes to memory.
+- **constraints and limitations:**
+  This op terminates an unaligned-store sequence. It MUST be paired with a
+  compatible prior state-producing store sequence so that the pending tail state
+  is well-defined.
+
+---
+
+#### Stateful Store Ops
+
+These ops make reference-updated state explicit as SSA results.
+
+##### `pto.vstu`
+
+- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index`
+- **semantics:** Unaligned store with align + offset state update.
+- **inputs:**
+  `%align_in` is the incoming store-alignment state, `%offset_in` is the current
+  logical byte/element displacement, `%value` is the vector being stored, and
+  `%base` is the UB base pointer.
+- **outputs:**
+  `%align_out` is the updated alignment/tail state and `%offset_out` is the
+  next offset after applying the selected post-update rule.
+- **constraints and limitations:**
+  The alignment state MUST be threaded in program order. A terminating flush
+  form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered
+  tail bytes.
+ +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. + +--- + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. 
+ +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... 
+ +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. 
The
+  current PTO micro Instruction representation models that selector as an attribute rather than a
+  separate operand.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = input_scalar_or_element;
+```
+
+---
+
+#### Predicate Generation
+
+##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32`
+
+- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask`
+- **semantics:** Generate predicate from pattern.
+
+**Patterns:**
+
+| Pattern | Description |
+|---------|-------------|
+| `PAT_ALL` | All lanes active |
+| `PAT_ALLF` | All lanes inactive |
+| `PAT_H` | High half active |
+| `PAT_Q` | Upper quarter active |
+| `PAT_VL1`...`PAT_VL128` | First N lanes active |
+| `PAT_M3`, `PAT_M4` | Modular patterns |
+
+**Example — All 64 f32 lanes active:**
+```mlir
+%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask
+```
+
+**Example — First 16 lanes active:**
+```mlir
+%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask
+```
+
+---
+
+##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32`
+
+- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask`
+- **semantics:** Generate tail mask — first N lanes active.
+
+```c
+for (int i = 0; i < TOTAL_LANES; i++)
+  mask[i] = (i < len);
+```
+
+**Example — Tail mask for remainder loop:**
+```mlir
+%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask
+```
+
+---
+
+##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32`
+
+- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32`
+- **semantics:** Generate predicate state together with updated scalar state.
+
+---
+
+#### Predicate Pack/Unpack
+
+##### `pto.ppack`
+
+- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask`
+- **semantics:** Narrowing pack of predicate register.
+
+**Part tokens:** `LOWER`, `HIGHER`
+
+---
+
+##### `pto.punpack`
+
+- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask`
+- **semantics:** Widening unpack of predicate register.
+ +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. 
+ +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. 
+- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. 
+- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. 
+ +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. 
+- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. 
+ +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. 
+ +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. 
+ +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. 
+ +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. 
+ +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. 
+ +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. 
+- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. 
+ +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. 
+- The single `pto.vcvt` surface covers float-int, float-float, int-float, and
+  int-int conversion families.
+
+---
+
+#### `pto.vci`
+
+- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg`
+- **semantics:** Generate a lane-index vector from a scalar seed/index value.
+- **inputs:**
+  `%index` is the scalar seed or base index.
+- **outputs:**
+  `%result` is the generated index vector.
+- **constraints and limitations:**
+  This is an index-generation family, not a numeric conversion. `ORDER` and the
+  result element type together determine how indices are generated.
+
+---
+
+#### `pto.vcvt`
+
+- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg`
+- **semantics:** Type conversion between float/int types with rounding control.
+
+```c
+// N = source lane count (element type T0), M = destination lane count (T1);
+// they differ when the conversion changes element width.
+for (int i = 0; i < min(N, M); i++)
+  dst[i] = convert(src[i], T0, T1, round_mode);
+```
+
+- **inputs:**
+  `%input` is the source vector; attributes select rounding, saturation, and
+  even/odd placement when the conversion changes width.
+- **outputs:**
+  `%result` is the converted vector.
+- **constraints and limitations:**
+  Only documented source/destination type pairs are legal. `PART_EVEN` /
+  `PART_ODD` is only meaningful for width-changing forms that pack two source
+  streams into one destination register.
+ +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "ROUND_MODE" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). 
+ +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `ROUND_MODE` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `ROUND_MODE` must be one of `ROUND_R`, `ROUND_A`, `ROUND_F`, + `ROUND_C`, or `ROUND_Z`. `BW` must match the element width: `b16` for + `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "ROUND_R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "ROUND_F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + +--- + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. 
+- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. 
Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. 
+ +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. 
+ +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 
1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? 
true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. 
+ +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. 
+- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. 
+
+```c
+// Conceptually: tmp[0..2N-1] = {src1, src0}
+// dst[i] = tmp[(N - amt) + i]   (so amt = 0 returns src0 unchanged)
+if (amt >= 0)
+  for (int i = 0; i < N; i++)
+    dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i];
+```
+
+**Use case:** Sliding window operations, shift register patterns.
+
+- **inputs:** `%src0` and `%src1` provide the concatenated source window and
+  `%amt` selects the extraction offset.
+- **outputs:** `%result` is the extracted destination window.
+- **constraints and limitations:** `pto.vslide` operates on the logical
+  concatenation of `%src1` and `%src0`. The source order and extraction offset
+  MUST be preserved exactly.
+
+---
+
+##### `pto.vshift`
+
+- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg`
+- **semantics:** Single-source slide (shift with zero fill).
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = (i >= amt) ? src[i - amt] : 0;
+```
+
+- **inputs:** `%src` is the source vector and `%amt` is the slide amount.
+- **outputs:** `%result` is the shifted vector.
+- **constraints and limitations:** This surface represents the single-source
+  slide/shift family. Zero-fill versus other fill behavior MUST match the
+  selected form.
+
+---
+
+#### Compress / Expand
+
+##### `pto.vsqz`
+
+- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Compress — pack active lanes to front.
+
+```c
+int j = 0;
+for (int i = 0; i < N; i++)
+  if (mask[i]) dst[j++] = src[i];
+while (j < N) dst[j++] = 0;
+```
+
+**Use case:** Sparse data compaction, filtering.
+
+- **inputs:** `%src` is the source vector and `%mask` selects which elements are
+  kept.
+- **outputs:** `%result` is the compacted vector.
+- **constraints and limitations:** This is a reduction-style compaction family.
+  Preserved element order MUST match source lane order.
+
+---
+
+##### `pto.vusqz`
+
+- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg`
+- **semantics:** Expand — scatter front elements to active positions.
+ +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. 
+ +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. 
+ +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +--- + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. 
+ +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. 
+ +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. 
+ +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. 
+- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. 
+ +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. 
+ +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + +--- + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. 
+ +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. 
+ +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. 
+ +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + +--- + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. 
In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | also used for `__VEC_SCOPE__` dummy-loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. 
+ +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- the `__VEC_SCOPE__` contract in PTO micro Instruction is modeled as a specialized `scf.for` annotated with `llvm.loop.aivector_scope` +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = 
arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +--- + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. 
Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. 
|
+| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` |
+| Comparison | 11 | `pto.vcmp`, `pto.vcmps` |
+| Selection | 11 | `pto.vsel`, `pto.vselr` |
+
+### Type & Data Manipulation
+
+| Operation | Group | Description |
+|-----------|-------|-------------|
+| Type Conversion | 9 | `pto.vcvt` |
+| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` |
+
+### Synchronization
+
+| Operation | Group | Description |
+|-----------|-------|-------------|
+| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` |
+| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` |
+
+### Scalar & Control Operations
+
+Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops.
+
+| Operation | Group | Description |
+|-----------|-------|-------------|
+| Scalar Constants | 14 | `arith.constant` |
+| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. |
+| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. |
+| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` |
+| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. |
+| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. 
| +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/release/vpto-spec-v0.2.md b/docs/release/vpto-spec-v0.2.md new file mode 100644 index 000000000..90b632a14 --- /dev/null +++ b/docs/release/vpto-spec-v0.2.md @@ -0,0 +1,5074 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. 
+ +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. 
+ +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. 
+ +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. 
+ +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). 
+ +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. 
+ +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. 
+ +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. 
+- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : 
!pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +- `vreg`: `!pto.vreg` + Fixed-width VPTO vector type with total width exactly 256 bytes. +- `mask`: `!pto.mask` + Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. +- `align`: `!pto.align` +- `buf`: buffer-like LLVM pointer type accepted by the dialect +- `buf_like`: `memref<...>` or `!llvm.ptr` for stateless/predicate + `vld*/vst*` families +- `idx`: `index` +- `i32`: `i32` +- `i64`: `i64` + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. 
In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. 
+
+- `T` is the element type associated with the pointed-to storage.
+- `space` is the memory domain, typically `gm` or `ub` in this specification.
+- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself.
+- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`.
+- Pointer arithmetic is element-based rather than byte-based.
+
+Typical examples:
+
+- `!pto.ptr<f32, gm>`
+- `!pto.ptr<f32, ub>`
+- `!pto.ptr<i32, ub>`
+
+### Pointer Operations
+
+#### `pto.castptr`
+
+- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr<T, space>`
+- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space.
+
+```c
+result = (ptr)addr;
+```
+
+`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect.
+
+#### `pto.addptr`
+
+- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr<T, space> -> !pto.ptr<T, space>`
+- **semantics:** Compute a new pointer by advancing the base pointer by an element offset.
+
+```c
+result = ptr + offset; // offset counted in elements, not bytes
+```
+
+`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. 
+ +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. 
`G` therefore describes how many bytes form one logical element slot:

+| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes |
+|-----------|----------------------|------------------------|-----------------------|
+| `!pto.mask<b32>` | 4 | `f32` / `i32` | 64 |
+| `!pto.mask<b16>` | 2 | `f16` / `bf16` / `i16` | 128 |
+| `!pto.mask<b8>` | 1 | 8-bit element family | 256 |
+
+This is intentionally different from a lane-vector model such as `mask<64xi1>`:
+
+- `!pto.mask<b32>` still denotes a 256-bit predicate register;
+- `64` is only the derived logical lane count for the `b32` view;
+- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type.
+
+**Predication Behavior (Zero-Merge):**
+
+The native hardware predication mode is **ZEROING** — inactive lanes produce zero:
+
+```c
+dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode
+```
+
+```mlir
+// Predicated add: inactive lanes produce zero
+%mask = pto.pset_b32 "PAT_VL32" : !pto.mask<b32> // first 32 logical b32 lanes active
+%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.mask<b32>
+```
+
+```mlir
+// Compare and select: generate mask from comparison, use for conditional select
+%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.mask<b32>
+%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.vreg<64xf32>
+```
+
+#### `!pto.align`
+
+`!pto.align` models the A5 vector-align carrier state. It is not payload data.
+
+```mlir
+%align = pto.vldas %ub : !pto.ptr<f32, ub> -> !pto.align
+%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr<f32, ub>, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr<f32, ub>
+```
+
+---
+
+## Part II: Notation Convention
+
+This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). 
+ +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with 
standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) 
| +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. 
| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` |
+| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` |
+| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` |
+| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` |
+| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` |
+| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` |
+| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` |
+| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` |
+| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. 
| +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. 
+ +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. 
+ +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). 
+ +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... 
+// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. 
+pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... 
+ // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... 
+ pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. 
+```
+
+**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies:
+- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]`
+- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer)
+- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed
+
+---
+
+#### Comparison Summary
+
+| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` |
+|--------|--------------------------|------------------------|
+| Dependency model | Explicit event signals | Implicit via buffer acquire/release |
+| IDs per pipe-pair | **4** = 2 buffers × 2 directions (fwd + rev) | 1 fwd + 1 rev per buffer (shared global pool) |
+| Total HW IDs | 8 total (4 per pipe-pair), grows with buffers | **32 global** across all pipes |
+| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically |
+| Pre-loop setup | `set_flag` to prime each reverse dep | None |
+| Post-loop teardown | `wait_flag` to drain all primed signals | None |
+| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) |
+| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead |
+| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops |
+
+---
+
+#### Inter-Core Sync
+
+> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.**
+
+These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams.
+ +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. 
+
+##### `pto.wait_flag_dev`
+
+- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64`
+- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls.
+
+##### A5: `set_intra_block` / `wait_intra_core`
+
+```c
+set_intra_block(trigger_pipe, semaphore_id);
+wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls
+```
+
+**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock:
+
+| ID Range | Target |
+|----------|--------|
+| 0–15 | AIV0 (subblock 0) |
+| 16–31 (+16 offset) | AIV1 (subblock 1) |
+
+This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce).
+
+```
+C→V on A5 (1:1, no broadcast — need two sets):
+  AIC ──set_intra_block(pipe, sema_id)────> AIV0
+  AIC ──set_intra_block(pipe, sema_id+16)──> AIV1
+
+V→C on A5 (1:1, no reduce — need two waits):
+  AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id)
+  AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+16) // extra wait
+```
+
+##### `pto.set_intra_block`
+
+- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64`
+- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock.
+
+##### `pto.wait_intra_core`
+
+- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64`
+- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue.
+ +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. 
+- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. 
+``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... 
+ +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... 
+ +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... 
+ +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... 
+ +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). 
+``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → 
ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For these + families, inactive lanes or inactive blocks MUST NOT issue memory requests + unless the instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro-instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for the vector-side instructions covered in this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs misspell the contiguous-load token as `dist:NORAML`; treat it as **`NORMAL`**. 
+ +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV_B32` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_*`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B32` | `RV_VLDI` | **9** | +| `DINTLV_B16` | `RV_VLDI` | **9** | +| `DINTLV_B8` | `RV_VLDI` | **9** | +| `BRC_B32` | `RV_VLD` | **9** | +| `BRC_B8` | `RV_VLD` | **9** | +| `BRC_B16` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B32` | `RV_VSTI` | **12** | +| `INTLV_B16` | `RV_VSTI` | **12** | +| `INTLV_B8` | `RV_VSTI` | **12** | +| `UNPK_B8` | `RV_VLD` | **9** | +| `UNPK_B16` | `RV_VLD` | **9** | +| `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B32` | `RV_VSTI` | **9** | +| `NORM_B16` | `RV_VSTI` | **9** | +| `NORM_B8` | `RV_VSTI` | **9** | +| `PK_B32` | `RV_VSTI` | **9** | +| `PK_B16` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. 
Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8`, `UNPK_B16`, `UNPK_B32` | **9** cycles | +| `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `DINTLV_B16`, `DINTLV_B8` | **9** cycles (same `RV_VLDI` + `dist:DINTLV_*` path as `DINTLV_B32`) | +| `BRC_B32` | **9** cycles | +| `BRC_B8`, `BRC_B16` | **9** cycles (`RV_VLD`) | +| `BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_*`, `DS_*`, `SPLT*` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8`, `NORM_B16`, `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16`, `PK_B32` | **9** cycles | +| `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `INTLV_B16`, `INTLV_B8` | **12** cycles (same interleave store path as `INTLV_B32`) | +| `MRG4CHN_B8`, `MRG2CHN_*` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for 
**`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see the DMA Copy group (`#isa-02-dma-copy`) in this document for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. 
+ +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B32` | Broadcast single element | `dst[i] = UB[base]` for all i | **9** cycles | +| `BRC_B8`, `BRC_B16` | Broadcast first lane element | Same idea at B8/B16 width | **9** cycles | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | **9** cycles | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | **9** cycles | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | **9** cycles | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | **9** cycles | +| `DINTLV_B16`, `DINTLV_B8` | Deinterleave 16-bit / 8-bit | Pair lanes from interleaved UB | **9** cycles | +| `BDINTLV` | Block deinterleave | (see PTO headers for exact tiling) | **9** cycles | +| `BLK` | Block load | Blocked / tiled access pattern (see PTO headers) | **9** cycles (`dist:BRC_BLK` on `RV_VLD`) | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. 
+- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. 
+ The two outputs form an ordered pair, and that pairing MUST be preserved. +- **Latency:** **`DINTLV_B32` → 9** cycles on `RV_VLDI`. **`DINTLV_B16` / `DINTLV_B8` → 9** cycles on `RV_VLDI`. **`BDINTLV` → 9** cycles on `RV_VLDI`. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. +- **Latency:** **9** cycles. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. 
If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. +- **Latency:** **9** cycles. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. +- **Latency:** **~21** cycles issue→retire. 
 + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed (simplified: each active offset selects a 32-byte-aligned block, not a single lane) +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. 
+ +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | **9** cycles | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | **9** cycles | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | **9** cycles | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. +- **Latency:** **`INTLV_B32` / `INTLV_B16` / `INTLV_B8` → 12** cycles on `RV_VSTI`. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. 
+- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. +- **Latency:** **9** cycles. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. +- **Latency:** **9** cycles. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. 
+ +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. +- **Latency:** **9** cycles. 
+ +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. 
A matching flush op is + still required to emit the trailing bytes. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. +- **Latency:** **9** cycles. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. +- **Latency:** **9** cycles. 
+ +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is paired with `f32` +vector compares or selects. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. 
+ +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. 
+- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. 
+ +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. 
+ +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. 
+ +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. 
+ +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrsqrt` | `RV_VSQRT` / `RV_VDIV` | **17** / **17** | **22** / **22** | — | +| `pto.vrec` | `RV_VDIV` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. 
+- **outputs:** `%result` holds `exp(input[i])` per active lane.
+- **constraints and limitations:** Only floating-point element types are legal.
+
+---
+
+##### `pto.vln`
+
+- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **A5 types:** f16, f32
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = logf(src[i]);
+```
+
+- **inputs:** `%input` is the source vector and `%mask` selects active lanes.
+- **outputs:** `%result` holds the natural logarithm per active lane.
+- **constraints and limitations:** Only floating-point element types are legal.
+  For real-number semantics, active inputs SHOULD be strictly positive; non-positive
+  inputs follow the target's exception/NaN rules.
+
+---
+
+##### `pto.vsqrt`
+
+- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **A5 types:** f16, f32
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = sqrtf(src[i]);
+```
+
+- **inputs:** `%input` is the source vector and `%mask` selects active lanes.
+- **outputs:** `%result` holds the square root per active lane.
+- **constraints and limitations:** Only floating-point element types are legal.
+  Negative active inputs follow the target's exception/NaN rules.
+
+---
+
+##### `pto.vrsqrt`
+
+- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **A5 types:** f16, f32
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = 1.0f / sqrtf(src[i]);
+```
+
+- **inputs:** `%input` is the source vector and `%mask` selects active lanes.
+- **outputs:** `%result` holds reciprocal-square-root values per active lane.
+- **constraints and limitations:** Only floating-point element types are legal.
+  Active inputs containing `+0` or `-0` follow the target's divide-style
+  exceptional behavior.
+ +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. 
The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. 
+- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. 
+ +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. 
+ +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. 
+ +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. 
+ +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. 
+- **outputs:** `%result` is the lane-wise sum.
+- **constraints and limitations:** Inactive lanes follow the predication
+  behavior defined for this family. On the current surface, inactive lanes are
+  treated as zeroing lanes.
+
+---
+
+##### `pto.vsubs`
+
+- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg`
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = src[i] - scalar;
+```
+
+- **inputs:** `%input`, `%scalar`, and `%mask` as above.
+- **outputs:** `%result` is the lane-wise difference.
+- **constraints and limitations:** Integer or floating-point legality depends on
+  the selected type family in lowering.
+
+---
+
+##### `pto.vmuls`
+
+- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg`
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = src[i] * scalar;
+```
+
+- **inputs:** `%input`, `%scalar`, and `%mask` as above.
+- **outputs:** `%result` is the lane-wise product.
+- **constraints and limitations:** Supported element types are hardware-family
+  specific; the current PTO micro-instruction documentation covers the common numeric cases.
+
+---
+
+##### `pto.vmaxs`
+
+- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg`
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = (src[i] > scalar) ? src[i] : scalar;
+```
+
+- **inputs:** `%input`, `%scalar`, and `%mask` as above.
+- **outputs:** `%result` is the lane-wise maximum.
+- **constraints and limitations:** Input and result types MUST match.
+
+---
+
+##### `pto.vmins`
+
+- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg`
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = (src[i] < scalar) ? src[i] : scalar;
+```
+
+- **inputs:** `%input`, `%scalar`, and `%mask` as above.
+- **outputs:** `%result` is the lane-wise minimum.
+- **constraints and limitations:** Input and result types MUST match.
+ +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. 
+- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. 
+- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. 
+ +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. 
+ +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "ROUND_MODE" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). 
+
+```c
+for (int i = 0; i < N; i++)
+    dst[i] = round_to_int_valued_float(src[i], round_mode);
+```
+
+- **inputs:**
+  `%input` is the floating-point source vector, `%mask` selects active lanes,
+  and `ROUND_MODE` selects the truncation/rounding rule.
+- **outputs:**
+  `%result` is still a floating-point vector, but each active lane now carries
+  an integer-valued floating-point result.
+- **constraints and limitations:**
+  This op does not change the element type. `T` must be `f16`, `f32`, or
+  `bf16`. `ROUND_MODE` must be one of `ROUND_R`, `ROUND_A`, `ROUND_F`,
+  `ROUND_C`, or `ROUND_Z`. `BW` (the bit-width selector of the underlying
+  instruction form; it does not appear in the PTO-level syntax above) must match the element width: `b16` for
+  `f16`/`bf16`, `b32` for `f32`.
+
+**Example:**
+```mlir
+// Round to nearest integer, keep as float
+%rounded = pto.vtrc %input, %mask, "ROUND_R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+// input: [1.4, 2.6, -1.5, 3.0]
+// output: [1.0, 3.0, -2.0, 3.0]
+```
+
+---
+
+#### Typical Usage
+
+```mlir
+// Quantization: f32 → i8 with saturation
+%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32>
+%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"}
+    : !pto.vreg<64xf32> -> !pto.vreg<64xi32>
+// Then narrow i32 → i8 via pack ops
+
+// Mixed precision: bf16 → f32 for accumulation
+%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"}
+    : !pto.vreg<128xbf16> -> !pto.vreg<64xf32>
+
+// Floor for integer division
+%floored = pto.vtrc %ratio, %mask, "ROUND_F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"}
+    : !pto.vreg<64xf32> -> !pto.vreg<64xi32>
+```
+
+
+
+### 10. Reduction Ops
+
+> **Category:** Vector reduction operations
+> **Pipeline:** PIPE_V (Vector Core)
+
+Operations that reduce a vector to a scalar or per-group result.
+
+#### Common Operand Model
+
+- `%input` is the source vector register value.
+- `%mask` is the predicate operand `Pg`; inactive lanes do not participate.
+- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. 
Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. 
+ +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. 
+ +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 
1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? 
true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. 
+ +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. 
+
+- `%result` is the destination vector register value unless an op explicitly
+  returns multiple vectors.
+- These families do not access UB directly; they only rearrange register
+  contents.
+
+---
+
+#### Interleave / Deinterleave
+
+##### `pto.vintlv`
+
+- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg`
+- **semantics:** Interleave elements from two sources.
+
+```c
+// Interleave: merge even/odd elements from two sources
+// low = {src0[0], src1[0], src0[1], src1[1], ...}
+// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...}
+```
+
+- **inputs:** `%lhs` and `%rhs` are the two source vectors.
+- **outputs:** `%low` and `%high` are the two destination vectors.
+- **constraints and limitations:** The two outputs form a paired interleave
+  result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST
+  be preserved.
+
+---
+
+##### `pto.vdintlv`
+
+- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg`
+- **semantics:** Deinterleave elements into even/odd.
+
+```c
+// Deinterleave: separate even/odd elements of the combined {lhs, rhs} stream
+// low = {src[0], src[2], src[4], ...} // even-indexed elements
+// high = {src[1], src[3], src[5], ...} // odd-indexed elements
+```
+
+- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the
+  current PTO micro Instruction representation.
+- **outputs:** `%low` and `%high` are the separated destination vectors.
+- **constraints and limitations:** The two outputs form the even/odd
+  deinterleave result pair, and their ordering MUST be preserved.
+
+---
+
+#### Slide / Shift
+
+##### `pto.vslide`
+
+- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg`
+- **semantics:** Concatenate two vectors and extract N-element window at offset. 
+
+```c
+// Conceptually: tmp[0..2N-1] = {src1, src0}
+// dst[i] = tmp[(N - amt) + i]  — the window starts amt elements before the
+// src1/src0 boundary, matching the explicit code below
+if (amt >= 0)
+  for (int i = 0; i < N; i++)
+    dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i];
+```
+
+**Use case:** Sliding window operations, shift register patterns.
+
+- **inputs:** `%src0` and `%src1` provide the concatenated source window and
+  `%amt` selects the extraction offset.
+- **outputs:** `%result` is the extracted destination window.
+- **constraints and limitations:** `pto.vslide` operates on the logical
+  concatenation of `%src1` and `%src0`. The source order and extraction offset
+  MUST be preserved exactly.
+
+---
+
+##### `pto.vshift`
+
+- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg`
+- **semantics:** Single-source slide (shift with zero fill).
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = (i >= amt) ? src[i - amt] : 0;
+```
+
+- **inputs:** `%src` is the source vector and `%amt` is the slide amount.
+- **outputs:** `%result` is the shifted vector.
+- **constraints and limitations:** This surface represents the single-source
+  slide/shift family. Zero-fill versus other fill behavior MUST match the
+  selected form.
+
+---
+
+#### Compress / Expand
+
+##### `pto.vsqz`
+
+- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Compress — pack active lanes to front.
+
+```c
+int j = 0;
+for (int i = 0; i < N; i++)
+  if (mask[i]) dst[j++] = src[i];
+while (j < N) dst[j++] = 0;
+```
+
+**Use case:** Sparse data compaction, filtering.
+
+- **inputs:** `%src` is the source vector and `%mask` selects which elements are
+  kept.
+- **outputs:** `%result` is the compacted vector.
+- **constraints and limitations:** This is a reduction-style compaction family.
+  Preserved element order MUST match source lane order.
+
+---
+
+##### `pto.vusqz`
+
+- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg`
+- **semantics:** Expand — scatter front elements to active positions. 
+ +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. 
+ +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. 
+ +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. 
+ +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. 
+ +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. 
+ +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. 
+- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. 
+ +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. 
+ +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. 
+ +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. 
+ +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. 
+ +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. 
In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. 
+ +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular 
counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. 
Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. 
|
+| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` |
+| Comparison | 11 | `pto.vcmp`, `pto.vcmps` |
+| Selection | 11 | `pto.vsel`, `pto.vselr` |
+
+### Type & Data Manipulation
+
+| Operation | Group | Description |
+|-----------|-------|-------------|
+| Type Conversion | 9 | `pto.vcvt` |
+| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` |
+
+### Synchronization
+
+| Operation | Group | Description |
+|-----------|-------|-------------|
+| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` |
+| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` |
+
+### Scalar & Control Operations
+
+Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops.
+
+| Operation | Group | Description |
+|-----------|-------|-------------|
+| Scalar Constants | 14 | `arith.constant` |
+| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. |
+| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. |
+| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` |
+| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. |
+| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. 
| +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/release/vpto-spec-v0.3.md b/docs/release/vpto-spec-v0.3.md new file mode 100644 index 000000000..8de281795 --- /dev/null +++ b/docs/release/vpto-spec-v0.3.md @@ -0,0 +1,5349 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. 
+ +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. 
+ +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. 
+ +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. 
+
+**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM.
+
+#### Architecture Detail: Vector Lane (VLane)
+
+The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations.
+
+```
+vreg (256 bytes total):
+┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐
+│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │
+│ 32B │ 32B │ 32B │ │ 32B │ 32B │
+└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘
+```
+
+Elements per VLane by data type:
+
+| Data Type | Elements/VLane | Total Elements/vreg |
+|-----------|---------------|-------------------|
+| i8/si8/ui8 | 32 | 256 |
+| i16/si16/ui16/f16/bf16 | 16 | 128 |
+| i32/si32/ui32/f32 | 8 | 64 |
+| i64/si64/ui64 | 4 | 32 |
+
+#### Memory and Synchronization Model
+
+The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement:
+
+**Address Space Isolation**: The IR uses the memory-space attribute on `!pto.ptr` to distinguish between GM (`!pto.ptr<..., gm>`) and UB (`!pto.ptr<..., ub>`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB.
+
+**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). 
+ +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. 
+ +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. 
+ +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. 
+- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, 
i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. 
In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. 
+ +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. 
+ +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. 
+
+- `space` is the memory domain, typically `gm` or `ub` in this specification.
+- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself.
+- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`.
+- Pointer arithmetic is element-based rather than byte-based.
+
+Typical examples:
+
+- `!pto.ptr<f32, gm>`
+- `!pto.ptr<f32, ub>`
+- `!pto.ptr<i16, ub>`
+
+### Pointer Operations
+
+#### `pto.castptr`
+
+- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr<T, space>`
+- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space.
+
+```c
+result = (ptr)addr;
+```
+
+`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect.
+
+#### `pto.addptr`
+
+- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr<T, space> -> !pto.ptr<T, space>`
+- **semantics:** Compute a new pointer by advancing the base pointer by an element offset.
+
+```c
+result = ptr + offset; // offset counted in elements, not bytes
+```
+
+`pto.addptr` preserves both the element type `T` and the memory-space tag `space`.
+
+#### `pto.load_scalar`
+
+- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr<T, space> -> T`
+- **semantics:** Load one scalar element from a pointer-like operand.
+
+```c
+value = ptr[offset];
+```
+
+- **inputs:**
+  `%ptr` is a typed PTO pointer `!pto.ptr<T, space>`, and `%offset` is an
+  `index` displacement counted in elements.
+- **outputs:**
+  `%value` is the loaded scalar element.
+- **constraints and limitations:**
+  The result type MUST match the element type of `%ptr`. This op is a scalar
+  memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and
+  does not participate in vector load `dist` families.
+ +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. 
+
+`G` is part of the type and MUST be one of:
+
+- `b32`
+- `b16`
+- `b8`
+
+All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules.
+
+In the ISA chapters below, this document uses `!pto.mask<G>` as shorthand when a
+family is generic over granularity. For op families whose names already encode
+the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`,
+`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed
+mask.
+
+**Mask Granularity:**
+
+The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot:
+
+| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes |
+|-----------|----------------------|------------------------|-----------------------|
+| `!pto.mask<b32>` | 4 | `f32` / `i32` | 64 |
+| `!pto.mask<b16>` | 2 | `f16` / `bf16` / `i16` | 128 |
+| `!pto.mask<b8>` | 1 | 8-bit element family | 256 |
+
+This is intentionally different from a lane-vector model such as `mask<64xi1>`:
+
+- `!pto.mask<b32>` still denotes a 256-bit predicate register;
+- `64` is only the derived logical lane count for the `b32` view;
+- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type.
+- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask<G>`
+  to carry their per-lane carry results, interpreted with this same
+  granularity.
+
+**Predication Behavior (Zero-Merge):**
+
+The native hardware predication mode is **ZEROING** — inactive lanes produce zero:
+
+```c
+dst[i] = mask[i] ?
op(src0[i], src1[i]) : 0 // ZEROING mode
+```
+
+```mlir
+// Predicated compare: inactive lanes produce zero
+%mask = pto.pset_b32 "PAT_VL32" : !pto.mask<b32> // first 32 logical b32 lanes active
+%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.mask<b32>
+```
+
+```mlir
+// Compare and select: generate mask from comparison, use for conditional select
+%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.mask<b32>
+%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.vreg<64xf32>
+```
+
+#### `!pto.align`
+
+`!pto.align` models the A5 vector-align carrier state. It is not payload data.
+
+```mlir
+%align = pto.vldas %ub : !pto.ptr<f32, ub> -> !pto.align
+%vec, %align_out = pto.vldus %ub, %align : !pto.ptr<f32, ub>, !pto.align -> !pto.vreg<64xf32>, !pto.align
+
+%store_align = pto.init_align : !pto.align
+%next_align = pto.vstus %store_align, %offset, %vec, %ub
+    : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr<f32, ub> -> !pto.align
+```
+
+---
+
+## Part II: Notation Convention
+
+This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III).
+
+### MLIR Op Syntax Patterns
+
+All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are:
+
+**Unary (one vector in, one vector out):**
+
+```mlir
+%result = pto.<op> %input : !pto.vreg<NxT> -> !pto.vreg<NxT>
+```
+
+**Binary (two vectors in, one vector out):**
+
+```mlir
+%result = pto.<op> %lhs, %rhs : !pto.vreg<NxT>, !pto.vreg<NxT> -> !pto.vreg<NxT>
+```
+
+**Vec-Scalar (one vector + one scalar in, one vector out):**
+
+```mlir
+%result = pto.<op>
%input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. 
+ +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. 
The convention:
+
+```c
+// Vector register contents as arrays:
+T dst[N]; // destination
+T src0[N]; // first source
+T src1[N]; // second source (binary ops)
+T scalar; // scalar operand (vec-scalar ops)
+int mask[N]; // per-lane predicate (0 or 1)
+
+// N = lane count determined by type:
+// N = 256 for i8/si8/ui8
+// N = 128 for i16/si16/ui16/f16/bf16
+// N = 64 for i32/si32/ui32/f32
+// N = 32 for i64/si64/ui64
+```
+
+**Example — pto.vadd semantics:**
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = src0[i] + src1[i];
+```
+
+**Example — pto.vcgadd (group reduction per VLane) semantics:**
+
+```c
+int K = N / 8; // elements per VLane
+for (int g = 0; g < 8; g++) {
+  T sum = 0;
+  for (int i = 0; i < K; i++)
+    sum += src[g*K + i];
+  dst[g*K] = sum;
+  for (int i = 1; i < K; i++)
+    dst[g*K + i] = 0;
+}
+```
+
+### Template Placeholder Conventions
+
+| Placeholder | Meaning |
+|-------------|---------|
+| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` |
+| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. |
+| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) |
+| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` |
+| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` |
+| `"SAT"` | Saturation: `SAT \| NOSAT` |
+| `"PART"` | Half selector: `EVEN \| ODD` |
+| `"PAT_*"` | Predicate pattern literal |
+| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) |
+| `N` | Lane count (`N * bitwidth(T) = 2048`) |
+
+---
+
+## Part III: ISA Instruction Reference — Summary
+
+This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document.
+ +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. 
| +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. 
| +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. 
+ +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. 
+ +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. 
Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. 
No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. 
Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. 
Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. 
+ +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). 
+ +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. 
+pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... 
+  // WAR: tell Vector "done reading buf_out[i%2]"
+  pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"]
+}
+
+// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══
+// Each priming set_flag must be paired. The reverse-dep set_flags from
+// the final loop iterations would be consumed by wait_flags in
+// iterations that never execute (there is no iteration i+2). Drain
+// them here.
+pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"]   // ◀ DRAIN
+pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"]   // ◀ DRAIN
+pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"]  // ◀ DRAIN
+pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"]  // ◀ DRAIN
+```
+
+**What `set_flag`/`wait_flag` requires outside the loop:**
+- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent).
+- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly.
+
+###### 3b. `get_buf` / `rls_buf` version
+
+Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything.
+
+```mlir
+scf.for %i = %c0 to %N step %c1 {
+  // %pp = i % 2 (ping/pong selector)
+
+  // ── MTE2: load tile[i] into buf[i%2] ──
+  // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately.
+  // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic).
+  pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64  // mode=0
+  pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ...
+ pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. 
+``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. 
Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, 
i64`
+- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3.
+
+##### `pto.wait_flag_dev`
+
+- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64`
+- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls.
+
+##### A5: `set_intra_block` / `wait_intra_core`
+
+```c
+set_intra_block(trigger_pipe, semaphore_id);
+wait_intra_core(wait_pipe, semaphore_id);  // only named pipe stalls
+```
+
+**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock:
+
+| ID Range | Target |
+|----------|--------|
+| 0–15 | AIV0 (subblock 0) |
+| 16–31 (+16 offset) | AIV1 (subblock 1) |
+
+This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce).
+
+```
+C→V on A5 (1:1, no broadcast — need two sets):
+  AIC ──set_intra_block(pipe, sema_id)────> AIV0
+  AIC ──set_intra_block(pipe, sema_id+16)──> AIV1
+
+V→C on A5 (1:1, no reduce — need two waits):
+  AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id)
+  AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+16)  // extra wait
+```
+
+##### `pto.set_intra_block`
+
+- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64`
+- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock.
+
+##### `pto.wait_intra_core`
+
+- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64`
+- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. 
+ +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. 
+ +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). 
+ +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. 
+- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. 
+``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... 
+ +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... 
+ +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... 
+ +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... 
+ +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). 
+``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → 
ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. 
+ +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. 
+- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). 
+ +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only 
single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. 
VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. 
+ +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. 
+- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. 
+ +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. 
+
+**Distribution families:**
+
+| Family | Allowed element widths | C semantics | Latency |
+|------|-------------|-------------|-------------|
+| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles |
+| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles |
+| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles |
+| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles |
+| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles |
+| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles |
+
+**Example — Contiguous store:**
+```mlir
+pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask
+```
+
+---
+
+#### Dual Stores (Interleave)
+
+##### `pto.vstsx2`
+
+- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask`
+- **semantics:** Dual interleaved store (SoA → AoS conversion).
+- **inputs:**
+  `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer,
+  `%offset` is the displacement, `DIST` selects the interleave layout, and
+  `%mask` gates the participating elements.
+- **outputs:**
+  This op has no SSA result; it writes an interleaved stream to UB.
+- **constraints and limitations:**
+  This family is only legal for interleave distributions. The two source
+  vectors form an ordered pair, and the interleave semantics of that pair MUST
+  be preserved. PTO surface accepts the `INTLV` family, which only supports the
+  element widths listed below.
+- **latency:** `INTLV` is **12** cycles.
+
+**Distribution families:**
+
+| Family | Allowed element widths | C semantics | Latency |
+|------|-------------|-------------|-------------|
+| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles |
+
+```c
+// INTLV family on 32-bit elements:
+for (int i = 0; i < 64; i++) {
+  UB[base + 8*i] = low[i];
+  UB[base + 8*i + 4] = high[i];
+}
+```
+
+##### `pto.vsstb`
+
+- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask`
+- **semantics:** Block-strided store for 2D tile access.
+- **inputs:**
+  `%value` is the source vector, `%dest` is the UB base pointer,
+  `%block_stride` and `%repeat_stride` are the two 16-bit fields of the
+  hardware control word, and `%mask` controls block participation.
+- **outputs:**
+  This op writes UB memory and returns no SSA value.
+- **constraints and limitations:**
+  PTO surface does not expose the packed control word directly. Masked-off
+  blocks MUST NOT issue memory writes.
+- **Latency:** **9** cycles.
+
+```c
+// Block-strided store on 32-bit elements: one 32B block = 8 lanes.
+for (int blk = 0; blk < 8; ++blk) {
+  if (pg_b32[blk])
+    UB_block[base + repeat_stride + blk * block_stride] = src_block[blk];
+}
+```
+
+---
+
+#### Scatter (Indexed) Stores
+
+##### `pto.vscatter`
+
+- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index`
+- **semantics:** Indexed scatter to UB.
+- **inputs:**
+  `%value` is the source vector, `%dest` is the UB base pointer, `%offsets`
+  provides per-lane or per-block indices, and `%active_lanes` bounds the active
+  requests.
+- **outputs:**
+  This op writes UB memory and returns no SSA value.
+- **constraints and limitations:**
+  Only `b8`, `b16`, and `b32` element sizes are supported.
The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. 
The first store-side state in a stream
+  should be created by `pto.init_align`.
+- **Latency:** **9** cycles.
+
+---
+
+##### `pto.vsta`
+
+- **syntax:** `pto.vsta %value, %dest : !pto.align, !pto.ptr`
+- **semantics:** Flush remaining alignment state.
+- **inputs:**
+  `%value` is the pending alignment/buffer state that still needs to be emitted,
+  and `%dest` is the UB destination base pointer.
+- **outputs:**
+  No SSA result. The effect is a memory-side flush that writes the remaining
+  buffered bytes to memory.
+- **constraints and limitations:**
+  This op terminates an unaligned-store sequence. It MUST be paired with a
+  compatible prior state-producing store sequence so that the pending tail state
+  is well-defined.
+- **Latency:** **9** cycles.
+
+---
+
+#### Stateful Store Ops
+
+These ops make reference-updated state explicit as SSA results.
+
+##### `pto.vstus`
+
+- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align`
+- **semantics:** No-post unaligned store with scalar offset.
+- **inputs:**
+  `%align_in` is the incoming store-alignment state, `%offset` is the scalar
+  displacement, `%value` is the vector being stored, and `%base` is the UB base
+  pointer.
+- **outputs:**
+  `%align_out` is the updated buffered-tail state.
+- **constraints and limitations:**
+  This is the scalar-offset stateful form of the unaligned store family. The
+  scalar offset width MUST match the selected form, and a later flush op is
+  still required. The first `%align_in` in the stream should come from
+  `pto.init_align`. This op does not mean "store a full vector starting at
+  `%base + %offset`". Instead, `%offset` describes how far the store stream
+  advances at this step, and `%align_out` carries any residual tail that could
+  not be committed yet. The no-post surface does not expose an updated base
+  pointer.
A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. 
Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. 
+- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. 
The surface distinction is only immediate-offset versus
+dynamic-offset.
+
+---
+
+##### `pto.pstu`
+
+- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr`
+- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`.
+- **outputs:**
+  `%align_out` and `%base_out` are the updated unaligned-store state and are
+  intended to be used by a later `pto.pstu` call.
+- **constraints and limitations:**
+  The first `%align_in` in a predicate unaligned-store stream should come from
+  `pto.init_align`.
+
+---
+
+#### Typical Usage Pattern
+
+```mlir
+// Generate comparison mask
+%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask
+
+// Store mask to UB for later use
+pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index
+
+// ... later in another kernel ...
+
+// Load mask from UB
+%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask
+
+// Use for predicated select
+%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
+
+
+
+### 5. Materialization & Predicate Ops
+
+> **Category:** Scalar broadcast, predicate generation and manipulation
+> **Pipeline:** PIPE_V (Vector Core)
+
+These ops create vectors from scalar values and manipulate predicate registers.
+
+#### Common Operand Model
+
+- `%value` is the scalar source value in SSA form.
+- `%input` is either a source scalar or a source vector depending on the op.
+- `%result` is the destination vector register value.
+- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal
+  scalar-source constraints for this family.
+ +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. 
+ +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. 
+ On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. 
+
+Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends
+each 1-bit predicate element into a 2-bit group in `%result`, and leaves the
+expanded image in the full destination predicate register.
+
+```c
+// Let VL be the logical lane count of the destination predicate.
+// LOWER
+for (int i = 0; i < VL / 2; ++i) {
+  result[2 * i] = input[i];
+  result[2 * i + 1] = 0;
+}
+
+// HIGHER
+for (int i = 0; i < VL / 2; ++i) {
+  result[2 * i] = input[VL / 2 + i];
+  result[2 * i + 1] = 0;
+}
+```
+
+---
+
+#### Predicate Logical Ops
+
+##### `pto.pand`
+
+- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate bitwise AND gated by a governing predicate.
+
+Inactive lanes selected out by `%mask` are zeroed.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? (src0[i] & src1[i]) : 0;
+```
+
+---
+
+##### `pto.por`
+
+- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate bitwise OR gated by a governing predicate.
+
+Inactive lanes selected out by `%mask` are zeroed.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? (src0[i] | src1[i]) : 0;
+```
+
+---
+
+##### `pto.pxor`
+
+- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate bitwise XOR gated by a governing predicate.
+
+Inactive lanes selected out by `%mask` are zeroed.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0;
+```
+
+---
+
+##### `pto.pnot`
+
+- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask`
+- **semantics:** Predicate bitwise NOT gated by a governing predicate.
+
+Inactive lanes selected out by `%mask` are zeroed.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? 
(~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. 
+
+---
+
+#### Typical Usage
+
+```mlir
+// Generate all-active mask for f32 (64 lanes)
+%all = pto.pset_b32 "PAT_ALL" : !pto.mask
+
+// Generate tail mask for an 8-element remainder (first 8 lanes active)
+%tail = pto.pge_b32 "PAT_VL8" : !pto.mask
+
+// Compare and generate mask
+%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask
+
+// Combine masks: only process tail elements that passed comparison
+%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask
+
+// Use for predicated operation
+%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
+
+
+
+### 6. Unary Vector Ops
+
+> **Category:** Single-input vector operations
+> **Pipeline:** PIPE_V (Vector Core)
+
+Element-wise operations that take one vector input and produce one vector output.
+
+#### Common Operand Model
+
+- `%input` is the source vector register value.
+- `%mask` is the predicate operand. For this family, inactive lanes follow the
+  predication behavior of the selected instruction form: zeroing forms
+  zero-fill inactive lanes, while merging forms preserve the destination value.
+- `%result` is the destination vector register value. Unless stated otherwise,
+  `%result` has the same lane count and element type as `%input`.
+
+#### CA latency (A5, Ascend910_9599 CA)
+
+Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**.
+ +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. 
+- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. 
+- **outputs:** `%result` holds the lane-wise bitwise inversion.
+- **constraints and limitations:** Integer element types only.
+
+---
+
+#### Movement
+
+(`pto.vmov` is the movement op in this family — see the CA latency table above;
+its detailed surface description is not included in this section.)
+
+#### Typical Usage
+
+```mlir
+// Softmax numerator: exp(x - max)
+%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+
+// ReLU activation
+%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
+
+
+
+### 7. Binary Vector Ops
+
+> **Category:** Two-input vector operations
+> **Pipeline:** PIPE_V (Vector Core)
+
+Element-wise operations that take two vector inputs and produce one vector output.
+
+#### Common Operand Model
+
+- `%lhs` and `%rhs` are the two source vector register values.
+- `%mask` is the predicate operand `Pg` that gates which lanes participate.
+- `%result` is the destination vector register value. Unless explicitly noted,
+  it has the same lane count and element type as the inputs.
+- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST
+  have matching vector shapes and element types.
+
+#### CA latency (A5, Ascend910_9599 CA)
+
+Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface).
+ +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. 
+ +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. 
+ +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. 
Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. 
+ +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. 
+ +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. 
+- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? 
src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. 
For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. 
+ +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. 
The execution mask must
+  use the typed-mask granularity that matches the source vector family on the
+  current surface; VPTO does not provide a separate untyped mask form.
+
+---
+
+##### Rounding Modes
+
+| Mode | Description |
+|------|-------------|
+| `R` | Round to nearest, ties to even (default) |
+| `A` | Round away from zero |
+| `F` | Round toward negative infinity (floor) |
+| `C` | Round toward positive infinity (ceil) |
+| `Z` | Round toward zero (truncate) |
+| `O` | Round to odd |
+
+---
+
+##### Saturation Modes
+
+| Mode | Description |
+|------|-------------|
+| `SAT` | Saturate on overflow |
+| `NOSAT` | No saturation (wrap/undefined on overflow) |
+
+---
+
+##### Part Modes
+
+Use `part` when a width-changing conversion writes only one half of each wider
+destination lane group. This is typically used in even/odd placement forms such
+as `32 -> 16` or `16 -> 32` style conversions.
+
+| Mode | Description |
+|------|-------------|
+| `EVEN` | Output to even-indexed lanes |
+| `ODD` | Output to odd-indexed lanes |
+
+---
+
+##### Attribute Guidance
+
+- `rnd`
+  - Use when the conversion needs an explicit rounding rule, especially for
+    float-to-int, float-to-float narrowing, or integer-to-float forms that do
+    not map exactly.
+- `mask`
+  - Use to select which source lanes participate in the conversion. In
+    width-changing conversions, `mask` works together with `part` / `pp` to
+    determine which logical lane positions are produced.
+- `sat`
+  - Use when the conversion may overflow the destination range and hardware
+    exposes a saturating form.
+- `part`
+  - Use for width-changing conversions that select the even or odd half of the
+    destination packing layout.
+ +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask 
{part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. 
+ +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. 
+ +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. 
+- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. 
+ +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. 
+- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] 
+// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 
1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. 
+
+---
+
+#### Selection Operations
+
+##### `pto.vsel`
+
+- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Per-lane select based on mask.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = mask[i] ? src0[i] : src1[i];
+```
+
+**Example — Conditional assignment:**
+```mlir
+// dst = mask ? true_vals : false_vals
+%result = pto.vsel %true_vals, %false_vals, %condition
+  : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32>
+```
+
+- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector,
+  and `%mask` selects between them.
+- **outputs:** `%result` is the selected vector.
+- **constraints and limitations:** Source vectors and result MUST have matching
+  vector shapes and element types. `%mask` keeps the typed-mask granularity
+  that matches the selected vector family.
+
+---
+
+##### `pto.vselr`
+
+- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg -> !pto.vreg`
+- **semantics:** Lane-select by index vector.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = src[idx[i]];
+```
+
+- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector.
+- **outputs:** `%result` is the reordered vector.
+- **constraints and limitations:** `%idx` must use integer elements. `%idx`
+  must have the same lane count as `%src`, and its integer element width must
+  match the bit width of `%src` element type.
+
+---
+
+##### `pto.vselrv2`
+
+- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg`
+- **semantics:** Variant select form with the same current two-vector operand shape.
+- **inputs:** `%src0` and `%src1` are the source vectors.
+- **outputs:** `%result` is the selected vector.
+- **constraints and limitations:** This page records the surface shape only.
+  Lowering MUST preserve the exact A5 variant semantics selected for this form.
+ +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. 
+- `%result` is the destination vector register value unless an op explicitly
+  returns multiple vectors.
+- These families do not access UB directly; they only rearrange register
+  contents.
+
+---
+
+#### Interleave / Deinterleave
+
+##### `pto.vintlv`
+
+- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg`
+- **semantics:** Interleave elements from two sources.
+
+```c
+// Interleave: merge even/odd elements from two sources
+// low = {src0[0], src1[0], src0[1], src1[1], ...}
+// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...}
+```
+
+- **inputs:** `%lhs` and `%rhs` are the two source vectors.
+- **outputs:** `%low` and `%high` are the two destination vectors.
+- **constraints and limitations:** The two outputs form a paired interleave
+  result. The PTO micro-instruction representation exposes that pair as two SSA results, and the pair ordering MUST
+  be preserved.
+
+---
+
+##### `pto.vdintlv`
+
+- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg`
+- **semantics:** Deinterleave elements into even/odd.
+
+```c
+// Deinterleave: separate even/odd elements
+// low = {src0[0], src0[2], src0[4], ...}   // even
+// high = {src0[1], src0[3], src0[5], ...}  // odd
+```
+
+- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the
+  current PTO micro-instruction representation.
+- **outputs:** `%low` and `%high` are the separated destination vectors.
+- **constraints and limitations:** The two outputs form the even/odd
+  deinterleave result pair, and their ordering MUST be preserved.
+
+---
+
+#### Compress / Expand
+
+##### `pto.vsqz`
+
+- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Compress — pack active lanes to front.
+
+```c
+int j = 0;
+for (int i = 0; i < N; i++)
+  if (mask[i]) dst[j++] = src[i];
+while (j < N) dst[j++] = 0;
+```
+
+**Use case:** Sparse data compaction, filtering.
+
+- **inputs:** `%src` is the source vector and `%mask` selects which elements are
+  kept.
+- **outputs:** `%result` is the compacted vector.
+- **constraints and limitations:** This is a reduction-style compaction family.
+  Preserved element order MUST match source lane order.
+
+---
+
+##### `pto.vusqz`
+
+- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg`
+- **semantics:** Generate per-lane prefix counts from the governing predicate.
+
+```c
+dst[0] = 0;
+for (int i = 1; i < N; i++)
+  dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1];
+```
+
+- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps
+  `%src` in the operand list for interface compatibility, but the observable
+  result semantics are determined by `%mask`.
+- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`,
+  with `%result[0] = 0`.
+- **constraints and limitations:** `T` is currently limited to `si8`, `si16`,
+  or `si32`. This operation is a predicate-derived counting/rearrangement
+  primitive rather than a value-placement primitive. The final predicate lane
+  does not contribute to a later output lane because there is no `dst[N]`.
+
+---
+
+##### `pto.vselr`
+
+- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg -> !pto.vreg`
+- **semantics:** Register lane-select with an explicit index vector.
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = src[idx[i]];
+```
+
+- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector.
+- **outputs:** `%result` is the reordered vector.
+- **constraints and limitations:** This page records the rearrangement use of
+  the family; the compare/select page documents the same name from the predicate
+  selection perspective.
+ +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. 
+ +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. 
+- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. 
+- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. 
+ +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. 
+- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. 
+
+---
+
+#### Current Implementation Surface Summary
+
+- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg`
+- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg`
+- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg`
+- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index`
+- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64`
+
+---
+
+#### Typical Usage
+
+```mlir
+// Softmax with fused expdiff
+%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32>
+%exp_stable = pto.vexpdiff %logits, %max_broadcast, "EVEN" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>
+
+// Leaky ReLU activation
+%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32>
+
+// Generate indices for argsort
+%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32>
+```
+
+
+
+### 14. Arith (Shared MLIR Dialect)
+
+> **Category:** Shared full scalar `arith` surface used around PTO ops
+> **Dialect:** `arith`
+> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/
+
+The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow.
+
+These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions.
+ +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, 
`arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values 
scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. 
+ +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + 
scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns 
+ +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 
| `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. 
| +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/sample.pto b/docs/sample.pto new file mode 100644 index 000000000..956b7ba4c --- /dev/null +++ b/docs/sample.pto @@ -0,0 +1,57 @@ +module attributes {pto.target_arch = "a5"} { + func.func @abs_kernel_2d(%arg0: memref, %arg1: memref) { + %c4096_i64 = arith.constant 4096 : i64 + %c0_i64 = arith.constant 0 : i64 + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [%c32, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref to memref> + %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [%c32, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref to memref> + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %0 = builtin.unrealized_conversion_cast %memspacecast : memref to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.extractvalue %0[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.inttoptr %c0_i64 : i64 to !llvm.ptr<6> + %3 = arith.index_castui %c32 : index to i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %4 = arith.muli %3, %c32_i64 : i64 + %5 = arith.muli %c1_i64, %4 : i64 + %6 = arith.muli %5, %c4_i64 : i64 + %7 = arith.muli %4, %c4_i64 : i64 + %8 = arith.muli %3, %c4_i64 : i64 + %c128_i64 = arith.constant 128 : i64 + %9 = llvm.getelementptr %1[%c0_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i8 + a5vm.set_loop2_stride_outtoub %6, %c4096_i64 : i64, i64 + a5vm.set_loop1_stride_outtoub %7, %c4096_i64 : i64, i64 + a5vm.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + 
a5vm.copy_gm_to_ubuf %9, %2, %3, %3, %c0_i64, %3, %8, %c0_i64, %c0_i64, %c0_i64, %c128_i64, %c128_i64 {a5vm.element_type = "u32", data_select_bit = false, layout = "nd", ub_pad = false} : !llvm.ptr<1>, !llvm.ptr<6>, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64 + a5vm.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + a5vm.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + %10 = llvm.inttoptr %c4096_i64 : i64 to !llvm.ptr<6> + %c0 = arith.constant 0 : index + %11 = arith.muli %c32, %c32 : index + %c64 = arith.constant 64 : index + %12 = arith.index_castui %11 : index to i32 + pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = a5vm.plt_b32 %arg4 : i32 -> !a5vm.mask, i32 + %17 = a5vm.vlds %2[%arg3] : !llvm.ptr<6> -> !a5vm.vreg<64xf32> + %18 = a5vm.vabs %17, %mask {mode = "MODE_ZEROING"} : !a5vm.vreg<64xf32>, !a5vm.mask -> !a5vm.vreg<64xf32> + a5vm.vsts %18, %10[%arg3], %mask : !a5vm.vreg<64xf32>, !llvm.ptr<6>, !a5vm.mask + scf.yield %scalar_out : i32 + } + } + a5vm.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + a5vm.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + %memspacecast_1 = memref.memory_space_cast %arg1 : memref to memref + %13 = builtin.unrealized_conversion_cast %memspacecast_1 : memref to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %15 = llvm.getelementptr %14[%c0_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i8 + a5vm.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + a5vm.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + a5vm.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + a5vm.copy_ubuf_to_gm %10, %15, %3, %3, %c0_i64, %c32_i64, %8, %c0_i64, %c128_i64, %c128_i64 {a5vm.element_type = "u32", layout = "nd"} : !llvm.ptr<6>, !llvm.ptr<1>, i64, i64, i64, i64, i64, i64, i64, i64 + a5vm.pipe_barrier "PIPE_ALL" + return + } +} diff --git 
a/docs/tilelang-dsl-guide.md b/docs/tilelang-dsl-guide.md new file mode 100644 index 000000000..dfdd41263 --- /dev/null +++ b/docs/tilelang-dsl-guide.md @@ -0,0 +1,2923 @@ +# TileLang Python DSL Guide + +The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. + +The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. + +## Quick Start + +**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. 
+ +Here's a minimal example of a tile scaling kernel using the new Tile type: + +```python +import pto + +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tile_scale(input_tensor: pto.TensorView, # Input tensor view (shape: 256x128, f32, GM) + output_tensor: pto.TensorView, # Output tensor view (same shape and type) + scale_factor: pto.f32): # Scaling factor + # Access tensor properties + rows, cols = input_tensor.shape # (256, 128) + dtype = input_tensor.element_type # pto.f32 + + # Create a temporary tile in UB for computation + ub_tile = pto.tile((rows, cols), dtype, pto.MemorySpace.UB) + + # Load input tensor from GM to UB using high-level DMA operation + pto.dma_load(input_tensor, ub_tile) + + # Vector computation: scale all elements in the tile + all_mask = pto.make_mask(dtype, PAT.ALL) + + # Process tile in row-major order + for row in range(0, rows): + # Process each row in vector chunks + # Vector width is hardware-defined: 256 bytes / element size + # For f32: 256/4 = 64 lanes, for f16: 256/2 = 128 lanes + vector_lanes = pto.get_lanes(dtype) # Compute vector lanes based on element type (e.g., 64 for f32, 128 for f16) + for col_start in range(0, cols, vector_lanes): + # Load vector using element-indexing syntax (no manual byte calculation) + vec = pto.vlds(ub_tile[row, col_start:]) + + # Scale vector + scaled = pto.vmuls(vec, scale_factor, all_mask) + + # Store result back using element-indexing syntax + pto.vsts(scaled, ub_tile[row, col_start:], all_mask) + + # Store result from UB back to GM output tensor using high-level DMA operation + pto.dma_store(ub_tile, output_tensor) +``` + +This example demonstrates: +1. **TensorView parameters** in kernel declaration +2. **TensorView property access** (shape, element_type) +3. **Tile creation** for temporary buffers +4. **High-level DMA operations** (`dma_load`/`dma_store`) for data movement +5. **Implicit tile→UBRef conversion** in vector load/store operations +6. 
**Automatic DMA parameter inference** from tensor slices and tile properties
+
+For an even more concise example showing pure computation on UB tiles (assuming data is already in UB):
+
+```python
+@pto.vkernel(target="a5", op="elementwise", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10)
+def ub_tile_computation(a: pto.Tile,  # UB tile
+                        b: pto.Tile,  # UB tile
+                        c: pto.Tile): # UB tile (output)
+    dtype = a.element_type
+
+    # All tiles are in UB memory space
+    all_mask = pto.make_mask(dtype, PAT.ALL)
+    rows, cols = a.shape
+
+    # Element-wise: c = a + b * 2.0
+    # Vector width depends on the element type (e.g., 64 for f32, 128 for f16)
+    vector_lanes = pto.get_lanes(dtype)
+    for i in range(0, rows * cols, vector_lanes):
+        # Load vectors from UB tiles using element-indexing syntax
+        vec_a = pto.vlds(a[i:])  # Implicit tile→UBRef with automatic offset calculation
+        vec_b = pto.vlds(b[i:])
+
+        # Compute: b * 2.0
+        scaled_b = pto.vmuls(vec_b, 2.0, all_mask)
+
+        # Compute: a + scaled_b
+        result = pto.vadd(vec_a, scaled_b, all_mask)
+
+        # Store result to output tile using element-indexing syntax
+        pto.vsts(result, c[i:], all_mask)
+```
+
+## Core Concepts
+
+### Kernel Declaration
+
+Kernels are defined using the `@pto.vkernel` decorator with enhanced matching capabilities for PTO operations. The decorator specifies matching criteria for target architecture, operation type, data types, and additional constraints, along with a priority for disambiguation when multiple kernels match.
+
+#### Basic Syntax
+
+```python
+@pto.vkernel(
+    target="a5",                            # Target architecture
+    op="matmul",                            # PTO operation name to match
+    dtypes=[(pto.f16, pto.f16, pto.f32)],   # Type signatures
+    constraints=[                           # Additional constraints
+        AnyOf(k_dim_aligned_64, continuous_memory),
+        Not(requires_ub_memory)
+    ],
+    priority=100                            # Priority for selection
+)
+def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None:
+    # kernel implementation
+    pass
+```
+
+#### Decorator Parameters
+
+| Parameter | Type | Required | Description |
+|-----------|------|----------|-------------|
+| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). |
+| `op` | `str` | Yes | Name of the PTO operation to match (e.g., `"matmul"`, `"conv2d"`, `"add"`). |
+| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. |
+| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for the kernel to be selected. Can include logical combinations (`AnyOf`, `AllOf`, `Not`). Default: empty list. |
+| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. |
+| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. |
+
+#### Type Matching Rules
+
+The `dtypes` parameter supports flexible type matching:
+
+1. **Concrete Types**: Exact type matches using DSL scalar types:
+   - `pto.f16`, `pto.f32`, `pto.bf16`
+   - `pto.i8`, `pto.i16`, `pto.i32`, `pto.i64`
+   - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32`
+
+2. 
**Type Wildcards**: Generic type patterns: + - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) + - `pto.AnyInt`: Matches any integer type (`i8`, `i16`, `i32`, `i64`) + - `pto.AnyType`: Matches any scalar type + - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) + +3. **Type Variables**: Named type variables that enforce consistency within a signature: + ```python + T = pto.TypeVar('T') # Define a type variable + + @pto.vkernel( + target="a5", + op="elementwise", + dtypes=[(T, T, T)], # All three operands must have the same type + constraints=[] + ) + def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: + # x, y, and out must have identical element types + pass + ``` + +4. **Mixed Signatures**: Multiple type signatures for the same operation: + ```python + @pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition + (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition + ] + ) + def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Supports both float and integer types + pass + ``` + +#### Constraint System + +Constraints are compile-time predicates that refine kernel selection. The system supports logical combinations of constraints. + +##### Predefined Constraints + +| Constraint | Description | +|------------|-------------| +| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | +| `continuous_memory` | Operands reside in contiguous memory regions. | +| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | +| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | +| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | +| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). 
| + +##### Logical Constraint Combinators + +| Combinator | Description | Example | +|------------|-------------|---------| +| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | +| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | +| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | + +##### Custom Constraints + +Users can define custom constraints using predicate functions: + +```python +# Define a custom constraint +def large_batch(batch_size: pto.i32) -> pto.Constraint: + """Batch size must be ≥ 1024.""" + return pto.Constraint(lambda op: op.batch_size >= batch_size) + +@pto.vkernel( + target="a5", + op="matmul", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[large_batch(1024)] +) +def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized for large batch sizes + pass +``` + +#### Kernel Selection Mechanism + +When a PTO operation needs implementation, the system performs the following matching process: + +1. **Target Filtering**: Select kernels with matching `target` architecture. +2. **Operation Filtering**: Select kernels with matching `op` name. +3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: + - Concrete types must match exactly. + - Wildcard types match according to their category. + - Type variables must be consistent within the signature. +4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. +5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. +6. **Fallback**: If no kernel matches, compilation fails with an error. 
+ +#### Examples + +##### Matmul with Multiple Implementations + +```python +# High-performance kernel for aligned K dimension +@pto.vkernel( + target="a5", + op="matmul", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[k_dim_aligned_64], + priority=200 +) +def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized implementation for aligned K + pass + +# General-purpose fallback +@pto.vkernel( + target="a5", + op="matmul", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=100 +) +def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Generic implementation + pass +``` + +##### Elementwise Operation with Type Polymorphism + +```python +@pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt, pto.AnyInt) + ], + constraints=[broadcastable] +) +def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: + # Single implementation handles both float and integer types + dtype = a.element_type + all_mask = pto.make_mask(dtype, PAT.ALL) + # ... 
implementation using generic vector operations + pass +``` + +##### Constrained Convolution Kernel + +```python +@pto.vkernel( + target="a5", + op="conv2d", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[ + AllOf( + tensor_rank(4), # NHWC format + static_shape, # No dynamic dimensions + Not(requires_ub_memory) # GM memory preferred + ) + ], + priority=150 +) +def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: + # Optimized for NHWC layout with static shapes + pass +``` + +### Value Model + +The DSL operates on symbolic values, not Python runtime values: +- **Constants**: Python literals that are typed to machine types +- **Operation results**: Values produced by DSL operations +- **Block arguments**: Values introduced by control flow structures + +### Memory Spaces + +The DSL supports different memory spaces: +- `MemorySpace.GM`: Global Memory +- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) + +## Type System + +### Scalar Types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit integer | 8 | +| `pto.i16` | 16-bit integer | 16 | +| `pto.i32` | 32-bit integer | 32 | +| `pto.i64` | 64-bit integer | 64 | +| `pto.f16` | Half precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single precision float | 32 | + +Python literals are automatically typed: +- `bool` → `pto.i1` +- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) +- `float` → `pto.f32` + +For explicit typing, use type constructors: +```python +x = pto.i32(1024) # Explicit i32 constant +y: pto.i32 = 1024 # Type annotation +``` + +### Vector Types + +Vector registers have fixed 256-byte width: + +```python +v64_f32 = pto.vreg(64, pto.f32) # 64 lanes of f32 (64 * 32b = 2048b) +v128_f16 = pto.vreg(128, pto.f16) # 128 lanes of f16 (128 * 16b = 2048b) +``` + +Constraint: `lanes × bitwidth(element_type) = 2048` + +### Typed Masks + +Masks 
are typed by their bit granularity: + +| DSL Type | VPTO Type | Description | +|----------|-----------|-------------| +| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | +| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | +| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | + +Mask operations must match the vector element family: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +```python +# Correct: f32 vector with b32 mask +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) + +# Error: mismatched mask granularity +mask16 = pto.make_mask(pto.f16, PAT.ALL) +out = pto.vabs(vec_f32, mask16) # Type error! +``` + +### Pointer Types + +Pointers combine element type and memory space: + +```python +from pto import MemorySpace + +ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 +ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 +``` + +The `MemorySpace` enum provides type-safe memory space specification: + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | +| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | + +This replaces string literals (`"GM"`/`"UB"`) with compile-time checked enums (`MemorySpace.GM`/`MemorySpace.UB`). 
+ +### Pointer Type Aliases + +For clarity in API documentation, the following type aliases are used: + +| Alias | Equivalent Type | Description | +|-------|----------------|-------------| +| `GMPtr` | `ptr(..., MemorySpace.GM)` | Pointer to Global Memory | +| `UBPtr` | `ptr(..., MemorySpace.UB)` | Pointer to Unified Buffer | +| `UBRef` | `Union[MemRefType, UBPtr]` | UB buffer or pointer (accepted by load/store ops) | +| `Tile` | `pto.tile(...)` | Tile buffer with layout and configuration | + +### MemRef Types + +For buffer-like authoring, use memref types: + +```python +buf1d = pto.memref(256, pto.f32, MemorySpace.UB) # 1D: 256-element f32 buffer in UB +buf2d = pto.memref((256, 128), pto.f32, MemorySpace.UB) # 2D: 256x128 f32 buffer in UB +``` + +- **1D shapes**: Use a scalar integer (e.g., `256`) +- **Multi-dimensional shapes**: Use a tuple (e.g., `(256, 128)`) + +MemRefs are used for stateless load/store operations that accept `buf_like` operands in VPTO. + + +### TensorView Types + +TensorView types represent views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. 
+ +### TensorView Type Definition + +TensorView types are parameterized by shape and element type: + +```python +# Kernel parameter using TensorView +@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tensor: pto.TensorView, # GM tensor view + output_tensor: pto.TensorView, # GM tensor view + tile_buf: pto.Tile # UB tile +): + # Access tensor view properties + rows, cols = input_tensor.shape # (dynamic or static) + dtype = input_tensor.element_type # e.g., pto.f32 + strides = input_tensor.strides # stride in elements +``` + +**Important Notes:** +- TensorView is a **read-only descriptor** for GM data (though DMA store operations can write to it) +- Shape can be **static** (compile-time constants) or **dynamic** (determined at runtime) +- Strides are expressed in elements, not bytes +- Memory space is always GM (Global Memory) + +### TensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Tensor dimensions (2D only in current profile) | +| `element_type` | `Type` | Element data type (e.g., `pto.f32`, `pto.f16`) | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from base pointer (internal) | + +### Padding Mode Enum + +Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: + +| Enum Value | Description | +|------------|-------------| +| `PadMode.PadNull` | No padding (out-of-bounds access is invalid) | +| `PadMode.PadFirstElem` | Pad using the first element of the source | +| `PadMode.PadValue` | Pad using a specified value (requires `pad_value` parameter) | + +**Usage:** +```python +from pto import PadMode + +# Load with zero padding +pto.dma_load(src_partition, dst_tile, + pad_mode=PadMode.PadValue, + pad_value=pto.f32(0.0)) + +# Load with first-element padding +pto.dma_load(src_partition, dst_tile, 
pad_mode=PadMode.PadFirstElem) + +# Load without padding (default) +pto.dma_load(src_partition, dst_tile) # pad_mode=PadMode.PadNull +``` + +### Slicing Syntax + +TensorView supports Python slicing syntax to create logical partitions: + +```python +# Create a partition from a tensor view +partition = tensor_view[row_start:row_end, col_start:col_end] + +# Example: extract a 16x16 tile from a larger tensor +tile_view = large_tensor[0:16, 0:16] + +# Dynamic offsets and sizes +start_row = pto.i32(0) +start_col = pto.i32(0) +dynamic_partition = tensor_view[start_row:start_row+16, start_col:start_col+16] +``` + +**Constraints:** +- Slicing returns a new TensorView representing the logical partition +- The partition must be within the original tensor bounds +- Slices can be static (constant bounds) or dynamic (runtime values) + +### Alignment Type + +The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. + +### Tile Types + +Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. + +#### Tile Type Definition + +```python +# Create a tile with shape, element type, and memory space +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB) + +# With explicit configuration +config = pto.tile_config( + b_layout=pto.BLayout.ROW_MAJOR, + s_layout=pto.SLayout.NONE_BOX, + s_fractal_size=pto.i32(16), + pad_value=pto.PadValue.ZERO +) +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, config=config) + +# With valid shape (actual data dimensions within tile) +tile = pto.tile((256, 128), pto.f32, MemorySpace.UB, valid_shape=(240, 120)) +``` + +**Important Notes on Shape and Valid Shape:** +- **Static Shape Requirement**: The `shape` parameter must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. 
+- **Valid Shape Constraints**: The `valid_shape` parameter can be either static (compile-time constant) or dynamic (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. This allows for variable-sized data within a fixed tile allocation. +- **Default Behavior**: When `valid_shape` is not specified, it defaults to the full `shape`. + +#### Tile Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | **Static** full tile dimensions (compile-time constant) | +| `element_type` | `Type` | Element data type (e.g., `pto.f32`) | +| `memory_space` | `MemorySpace` | Memory space (GM, UB, etc.) | +| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within tile (can be static/compile-time or dynamic/runtime). Must be ≤ shape in each dimension. | +| `config` | `TileConfig` | Layout and padding configuration | + +#### Tile Configuration + +The tile configuration includes layout and padding information: + +```python +# Layout enums +pto.BLayout.ROW_MAJOR # 0: row-major base layout +pto.BLayout.COL_MAJOR # 1: column-major base layout + +pto.SLayout.NONE_BOX # 0: no secondary layout +pto.SLayout.ROW_MAJOR # 1: row-major secondary layout +pto.SLayout.COL_MAJOR # 2: column-major secondary layout + +pto.PadValue.NULL # 0: no padding +pto.PadValue.ZERO # 1: zero padding +pto.PadValue.MAX # 2: maximum value padding +pto.PadValue.MIN # 3: minimum value padding +``` + +#### Tile Shape Concepts + +- **Static Physical Shape**: The `shape` parameter represents the **static physical dimensions** of the tile allocated in memory. This must be a **compile-time constant** because tile memory allocation is fixed during compilation. The shape determines the total memory footprint and cannot change at runtime. + +- **Valid Shape**: The `valid_shape` parameter represents the logical dimensions of actual data within the tile. 
It can be either **static** (compile-time constant) or **dynamic** (determined at runtime). It must be less than or equal to the physical `shape` in each dimension. When `valid_shape` is not specified, it defaults to the full `shape`. + +- **Key Distinction**: + - `shape`: **Static, compile-time** - Fixed tile allocation + - `valid_shape`: **Static or Dynamic** - Actual data region (must be ≤ shape) + +- **Constraints**: + - `valid_shape[i] ≤ shape[i]` for each dimension i + - `shape` must be compile-time constants + - `valid_shape` can be compile-time constants or runtime values + +- **Use Cases**: + - Fixed-size tile buffers with variable data (e.g., batch processing with different input sizes) + - Padding scenarios where physical allocation is larger than actual data + - Partial tile utilization in tiled algorithms + +- **Fractal Layout**: The `s_fractal_size` in tile configuration specifies the size of fractal blocks for secondary layout. This is used for optimized memory access patterns in matrix operations. + +- **Padding Behavior**: The `pad_value` determines how out-of-bounds accesses are handled when reading beyond `valid_shape` but within `shape`. Padding values are used for accesses in the padded region (between valid_shape and shape). + +> **⚠️ Important: Shape Constraints** +> +> The tile `shape` must be **compile-time constants**. `valid_shape` can be compile-time constants or determined at runtime, but must satisfy `valid_shape[i] ≤ shape[i]` for all dimensions i. 
+ +### Tile Operations + +#### Basic Access Operations + +```python +# Get tile properties +shape = tile.shape # (256, 128) +elem_type = tile.element_type # pto.f32 +mem_space = tile.memory_space # MemorySpace.UB +valid_shape = tile.valid_shape # (240, 120) or same as shape + +# Get configuration properties +config = tile.config +b_layout = config.b_layout # pto.BLayout.ROW_MAJOR +s_layout = config.s_layout # pto.SLayout.NONE_BOX +s_fractal = config.s_fractal_size # pto.i32(16) +pad = config.pad_value # pto.PadValue.ZERO + +# Dynamic properties +rank = tile.rank # 2 +num_elements = tile.num_elements # 32768 (256 * 128) +valid_elements = tile.valid_elements # 28800 (240 * 120) +``` + +#### Layout and Stride Queries + +```python +# Get layout descriptors +layout_desc = tile.layout_descriptor # Returns layout description object + +# Get strides (in elements) +strides = tile.strides # (128, 1) for row-major 256x128 + +# Get byte strides +byte_strides = tile.byte_strides # (512, 4) for f32 row-major + +# Get base offset (in bytes) +offset = tile.offset # pto.i64(0) or specified offset +``` + +#### Conversion Operations + +Tiles support both explicit and implicit conversion to UBRef. When a tile is used in operations expecting a UBRef (e.g., `pto.vlds`, `pto.vsts`), it is automatically converted. 
+ +```python +# Convert to UBRef (implicit in vector operations) +ub_ref = tile.to_ubref() # Explicit conversion +# or use tile as UBRef directly in vector ops +vec = pto.vlds(tile, offset) # Implicit conversion + +# Convert to typed pointer +ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) + +# Convert to MemRef (for compatibility) +memref = tile.to_memref() # Returns pto.memref((256, 128), pto.f32, MemorySpace.UB) + +# Extract slice of tile +slice_tile = tile.slice((0, 0), (64, 128)) # 64x128 slice from top-left corner + +# Reshape tile (logical reshape, no data movement) +reshaped = tile.reshape((32768,)) # 1D reshape of 256x128 tile +``` + +#### Kernel Parameter Usage + +```python +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tile: pto.Tile, # Tile parameter + output_tile: pto.Tile, # Another tile parameter + scale: pto.f32 +): + # Convert tiles to UBRef for vector operations + ub_in = input_tile.to_ubref() + ub_out = output_tile.to_ubref() + + # Or use tiles directly (implicit conversion) + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for i in range(0, 256, 64): + # tile implicitly converts to UBRef in vlds with element-indexing syntax + vec = pto.vlds(input_tile[i, 0:]) # Load from row i, columns 0 to vector_lanes-1 + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, output_tile[i, 0:], all_mask) # Store to same position +``` + +#### Tile Creation from Existing Buffers + +```python +# Create tile from existing pointer with shape +ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +tile = pto.tile_from_ptr(ptr, (256, 128), pto.f32) + +# Create tile from memref +memref = pto.memref((256, 128), pto.f32, MemorySpace.UB) +tile = pto.tile_from_memref(memref) + +# Create tile with explicit stride +tile = pto.tile_with_strides((256, 128), pto.f32, MemorySpace.UB, + strides=(256, 1)) # Column-major strides +``` + +## Control Flow + +### Vector Scopes + +The TileLang DSL 
supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. + +#### Implicit Scope Inference + +**Note:** The explicit `pto.vecscope()` construct is deprecated. Vector operations are automatically grouped into implicit scopes by the compiler's Scope Inference Pass. + +When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: + +```python +# No explicit vecscope needed - compiler infers scope boundaries +vec = pto.vlds(outer_ptr, offset) +result = pto.vadd(vec, vec, all_mask) +pto.vsts(result, dst_ptr, offset, all_mask) +``` + +The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain. + +**Scope boundary rules:** +1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries +2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries +3. **Explicit strict_vecscope**: User-defined `strict_vecscope` blocks create hard boundaries + +#### Explicit Scope Boundaries with `strict_vecscope` + +For precise control over scope boundaries, use explicit `strict_vecscope` blocks. 
These create hard boundaries that prevent the compiler from merging operations across the block boundary: + +```python +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Operations inside this block are isolated from outside + # Compiler will not merge operations across this boundary + for i in range(lb, ub, 64): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, all_mask) +``` + +**Use cases for strict_vecscope:** +- Performance optimization: Isolate critical vector computation regions +- Debugging: Create explicit boundaries to isolate vector operations +- Resource management: Control vector register allocation boundaries +- Compatibility: Ensure deterministic scope placement for hardware constraints + +### Loops + +Counted loops use Python's `range` syntax: + +```python +for i in range(lb, ub, step): + # Loop body + mask, rem = pto.make_mask(pto.f32, remaining) + # ... +``` + +Loop-carried state is automatically handled through variable updates within the loop. + +### Conditionals + +`if` statements support value merging: + +```python +flag: pto.i1 = some_condition +step: pto.i32 = 0 + +if flag: + step = pto.i32(64) +else: + step = pto.i32(128) + +# 'step' here is the merged result from both branches +``` + +Variables defined in only one branch are local to that branch. + +## Operations + +The DSL provides operations grouped by functionality. All operations use the `pto.` prefix. Operations are organized by functional families following the VPTO instruction set architecture. + +### Pointer Construction + +Operations for creating and manipulating typed pointers. + +#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` + +**Description**: Creates a pointer with the specified offset and type. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `offset` | `pto.i64` | Byte offset from base address | +| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +**Example**: +```python +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +``` + +#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` + +**Description**: Adds an offset to an existing pointer. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `pto.i64` | Byte offset to add | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer with offset applied | + +**Example**: +```python +next_ptr = pto.addptr(ub_ptr, 4096) +``` + +### Synchronization & Buffer Control + +Operations for pipeline synchronization and buffer management. + +#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Sets a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Waits for a synchronization flag between hardware pipelines. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.pipe_barrier(pipes: PIPE) -> None` + +**Description**: Executes a barrier across specified pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE + +pto.pipe_barrier(PIPE.ALL) +``` + +#### `pto.get_buf(op_type: SyncOpType, buf_id: pto.i32, mode: pto.i32 = 0) -> None` + +**Description**: Acquires a buffer for producer-consumer synchronization. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `op_type` | `SyncOpType` | Operation type (e.g., `SyncOpType.TLOAD`) | +| `buf_id` | `pto.i32` | Buffer identifier | +| `mode` | `pto.i32` | Acquisition mode (default: 0) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import SyncOpType + +# Acquire buffer for DMA load operation +pto.get_buf(SyncOpType.TLOAD, 0) +``` + +#### `pto.rls_buf(op_type: SyncOpType, buf_id: pto.i32, mode: pto.i32 = 0) -> None` + +**Description**: Releases a previously acquired buffer. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `op_type` | `SyncOpType` | Operation type (e.g., `SyncOpType.TLOAD`) | +| `buf_id` | `pto.i32` | Buffer identifier | +| `mode` | `pto.i32` | Release mode (default: 0) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import SyncOpType + +# Release buffer for DMA load operation +pto.rls_buf(SyncOpType.TLOAD, 0) +``` + +### Low-level DMA Programming (Legacy) + +**Note**: These low-level DMA programming operations are automatically handled by `pto.dma_load` and `pto.dma_store` in most cases. They expose hardware DMA engine parameters directly and should only be used when the automatic inference provided by the high-level API is insufficient for specific optimization needs. + +This section contains both DMA configuration operations (setting loop strides and sizes) and DMA execution operations (copying data). Prefer the high-level `pto.dma_load` and `pto.dma_store` operations which automatically infer all parameters from TensorView slices and Tile properties. + +#### When to Use Low-level DMA Programming + +Consider using these low-level operations only in the following scenarios: + +1. **Performance micro-optimization**: When specific DMA parameter tuning is required for performance-critical code +2. **Non-standard access patterns**: When TensorView slicing syntax cannot express the desired memory access pattern +3. **Hardware-specific optimizations**: When targeting specific DMA engine characteristics not captured by the high-level API + +For 99% of use cases, `pto.dma_load` and `pto.dma_store` with TensorView slicing provide sufficient control and are much easier to use correctly. 
+ +#### Manual Configuration Example + +```python +# Manual DMA configuration (discouraged for normal use) +pto.set_loop2_stride_outtoub(32, 128) # Outer loop strides +pto.set_loop1_stride_outtoub(1, 32) # Inner loop strides +pto.set_loop_size_outtoub(16, 16) # Transfer size +pto.copy_gm_to_ubuf(gm_ptr, ub_ptr, ...) + +# Equivalent using high-level API (recommended) +pto.dma_load(input_tensor[0:16, 0:16], ub_tile) +# All loop strides and sizes automatically inferred +``` + +#### `pto.set_loop2_stride_outtoub(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_outtoub(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_outtoub(size0: pto.i64, size1: pto.i64) -> None` + +**Description**: Configures DMA transfer size for GM → UB transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `size0` | `pto.i64` | First dimension size | +| `size1` | `pto.i64` | Second dimension size | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size_outtoub(1, 1) +``` + +#### `pto.set_loop2_stride_ubtoout(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_ubtoout(stride0: pto.i64, stride1: pto.i64) -> None` + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `stride0` | `pto.i64` | First dimension stride | +| `stride1` | `pto.i64` | Second dimension stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_ubtoout(size0: pto.i64, size1: pto.i64) -> None` + +**Description**: Configures DMA transfer size for UB → GM transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `size0` | `pto.i64` | First dimension size | +| `size1` | `pto.i64` | Second dimension size | + +**Returns**: None (side-effect operation) + +#### DMA Execution Operations + +**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. The high-level `pto.dma_load` and `pto.dma_store` operations automatically handle both configuration and execution. + +The following operations provide direct control over DMA transfers but require manual stride and size configuration. Prefer the high-level Tile Data Movement operations for most use cases. + +#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, transpose: pto.i1, pad_left: pto.i64, pad_right: pto.i64, pad_value: pto.i64) -> None` + +**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `GMPtr` | Source GM pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `transpose` | `pto.i1` | Transpose flag | +| `pad_left` | `pto.i64` | Left padding size | +| `pad_right` | `pto.i64` | Right padding size | +| `pad_value` | `pto.i64` | Padding value | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.copy_gm_to_ubuf(gm_ptr, ub_ptr, 0, 32, 128, 0, 0, False, 0, 128, 128) +``` + +#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` + +**Description**: Copies data within Unified Buffer (UB → UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` + +**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `GMPtr` | Destination GM pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.copy_ubuf_to_gm(ub_ptr, gm_ptr, 0, 32, 128, 0, 128, 128) +``` + +### Tile Data Movement Operations + +High-level operations for moving data between TensorView partitions (GM) and Tile buffers (UB), as well as between Tile buffers. These operations **automatically handle all low-level DMA configuration** and provide an intuitive interface based on tile semantics. + +#### Automatic DMA Parameter Inference + +The `pto.dma_load` and `pto.dma_store` operations automatically infer DMA transfer parameters (loop strides, loop sizes) from: + +1. **TensorView slices** - Python slicing syntax captures stride information: + ```python + # Contiguous slice: [0:16, 0:16] + pto.dma_load(input_tensor[0:16, 0:16], ub_tile) + + # Strided slice: [0:64:2, 0:32] → stride=2 in first dimension + pto.dma_load(input_tensor[0:64:2, 0:32], ub_tile) + ``` + +2. **Tile properties** - Layout and memory space determine destination patterns: + ```python + # Row-major vs column-major layouts affect stride computation + row_major_tile = pto.tile((16, 16), pto.f32, pto.MemorySpace.UB, b_layout=pto.BLayout.ROW_MAJOR) + col_major_tile = pto.tile((16, 16), pto.f32, pto.MemorySpace.UB, b_layout=pto.BLayout.COL_MAJOR) + ``` + +3. **Transpose and padding requirements** - Specified via operation parameters. 
+ +#### Benefits of Automatic Inference + +- **Simplified API**: No need to manually call `set_loop*_stride_*` and `set_loop_size_*` operations +- **Reduced errors**: Automatic parameter validation and consistency checking +- **Hardware abstraction**: Focus on data movement semantics, not DMA engine details +- **Portable code**: Same TileLang code works across different DMA implementations + +For advanced use cases requiring manual DMA parameter control, see the [Low-level DMA Programming (Legacy)](#low-level-dma-programming-legacy) section. + +#### `pto.dma_load(src: TensorView, dst: Tile, pad_mode: PadMode = PadMode.PadNull, pad_value: ScalarType = None, left_padding: Index = 0, right_padding: Index = 0, init_out_buffer: bool = False) -> None` + +**Description**: Loads data from a TensorView partition (GM) into a Tile buffer (UB). This maps to `pto.copy_gm_to_ubuf` operation in VPTO IR. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `TensorView` | Source tensor view partition (must be in GM) | +| `dst` | `Tile` | Destination tile buffer (must be in UB memory space) | +| `pad_mode` | `PadMode` | Padding mode (PadNull, PadFirstElem, PadValue) | +| `pad_value` | `ScalarType` | Padding value (required if `pad_mode == PadValue`) | +| `left_padding` | `Index` | Left padding element count | +| `right_padding` | `Index` | Right padding element count | +| `init_out_buffer` | `bool` | Initialize output buffer before loading | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Destination tile must have `memory_space = MemorySpace.UB` +- Element types of source and destination must have same bitwidth +- Source partition shape must match destination tile valid shape (after accounting for padding) + +**Example**: +```python +# Load a 16x16 partition into a UB tile +pto.dma_load(input_tensor[0:16, 0:16], ub_tile) + +# Load with zero padding +pto.dma_load(input_tensor[0:16, 0:16], ub_tile, + 
pad_mode=PadMode.PadValue, + pad_value=pto.f32(0.0), + left_padding=2, + right_padding=2) +``` + +#### `pto.dma_store(src: Tile, dst: TensorView, pad_mode: PadMode = PadMode.PadNull, pad_value: ScalarType = None, left_padding: Index = 0, right_padding: Index = 0) -> None` + +**Description**: Stores data from a Tile buffer (UB) to a TensorView partition (GM). This maps to `pto.copy_ubuf_to_gm` operation in VPTO IR. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile buffer (must be in UB memory space) | +| `dst` | `TensorView` | Destination tensor view partition (must be in GM) | +| `pad_mode` | `PadMode` | Padding mode (PadNull, PadFirstElem, PadValue) | +| `pad_value` | `ScalarType` | Padding value (required if `pad_mode == PadValue`) | +| `left_padding` | `Index` | Left padding element count | +| `right_padding` | `Index` | Right padding element count | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Source tile must have `memory_space = MemorySpace.UB` +- Element types of source and destination must have same bitwidth +- Source tile valid shape must match destination partition shape (after accounting for padding) + +**Example**: +```python +# Store a UB tile to a GM partition +pto.dma_store(ub_tile, output_tensor[0:16, 0:16]) + +# Store with padding +pto.dma_store(ub_tile, output_tensor[0:16, 0:16], + pad_mode=PadMode.PadValue, + pad_value=pto.f32(0.0), + left_padding=1, + right_padding=1) +``` + +#### `pto.dma_copy(src: Tile, dst: Tile, src_offset: tuple[Index, Index] = (0, 0), dst_offset: tuple[Index, Index] = (0, 0), copy_shape: tuple[Index, Index] = None) -> None` + +**Description**: Copies data between Tile buffers within Unified Buffer (UB → UB). This maps to `pto.copy_ubuf_to_ubuf` operation in VPTO IR. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile buffer (must be in UB memory space) | +| `dst` | `Tile` | Destination tile buffer (must be in UB memory space) | +| `src_offset` | `tuple[Index, Index]` | Offset within source tile (row, col) in elements | +| `dst_offset` | `tuple[Index, Index]` | Offset within destination tile (row, col) in elements | +| `copy_shape` | `tuple[Index, Index]` | Shape of region to copy (rows, cols) in elements. If None, copies the maximum valid region starting from offsets. | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Both tiles must have `memory_space = MemorySpace.UB` +- Element types of source and destination must match +- Source and destination regions must be within tile valid shapes + +**Example**: +```python +# Copy entire tile +pto.dma_copy(src_tile, dst_tile) + +# Copy subregion: copy 8x8 block from (2,2) in src to (0,0) in dst +pto.dma_copy(src_tile, dst_tile, + src_offset=(2, 2), + dst_offset=(0, 0), + copy_shape=(8, 8)) +``` + +**Note**: These high-level operations automatically handle DMA stride and size configuration based on tile shapes, layouts, and offsets. For low-level control, see the [Low-level DMA Programming (Legacy)](#low-level-dma-programming-legacy) section. + +#### VPTO IR Mapping + +The high-level DMA operations in TileLang DSL map to corresponding operations in VPTO IR: + +| TileLang DSL Operation | VPTO IR Operation | Description | +|------------------------|-------------------|-------------| +| `pto.dma_load` | `pto.copy_gm_to_ubuf` | Loads data from GM tensor view to UB tile buffer | +| `pto.dma_store` | `pto.copy_ubuf_to_gm` | Stores data from UB tile buffer to GM tensor view | +| `pto.dma_copy` | `pto.copy_ubuf_to_ubuf` | Copies data between UB tile buffers | + +These mappings allow the TileLang compiler to generate efficient VPTO IR code while providing a higher-level, more intuitive API for developers. 
The compiler automatically handles the conversion between Tile/TensorView abstractions and the low-level pointer/stride representation required by VPTO IR operations. + + +### Address Generation Syntax Sugar + +To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. + +#### Indexing Syntax + +The syntax supports two indexing modes for different operations: + +1. **Vector-range indexing** (for vector load/store operations): + - **Row-major layout (default)**: `tile[row_index, col_start:]` + - `row_index`: Row index (0-based) + - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column + - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type + + - **Column-major layout**: `tile[row_start:, col_index]` + - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row + - `col_index`: Column index (0-based) + - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise + + - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) + - `start:`: Starting element index followed by colon + +2. 
**Single-element indexing** (for scalar load operations like `pto.vsld`): + - **Row-major layout (default)**: `tile[row_index, col_index]` + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + + - **Column-major layout**: `tile[row_index, col_index]` (same syntax) + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Same syntax as row-major; the layout determines how the offset is computed + + - **1D tile indexing**: `tile[pos]` + - `pos`: Element index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + +#### Vector Width Calculation + +The number of elements loaded/stored in a single vector operation is determined by: + +``` +vector_lanes = 256 // element_size_bytes(element_type) +``` + +**Convenience API**: Use `pto.get_lanes(dtype)` to compute vector lanes for a given element type (e.g., `pto.get_lanes(pto.f32)` returns 64, `pto.get_lanes(pto.f16)` returns 128). + +Where `element_size_bytes` is: +- 1 byte for `i8` +- 2 bytes for `i16`, `f16`, `bf16` +- 4 bytes for `i32`, `f32` +- 8 bytes for `i64` + +#### Offset Computation + +The byte offset is automatically computed based on tile layout: + +- **Row-major layout** (`BLayout.ROW_MAJOR`): + ``` + offset = (row_index * stride_row + col_start) * element_size_bytes + ``` + where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). 
+ +- **Column-major layout** (`BLayout.COL_MAJOR`): + - For syntax `tile[row_start:, col_index]`: + ``` + offset = (col_index * stride_col + row_start) * element_size_bytes + ``` + - For backward compatibility with traditional offset calculation: + ``` + offset = (col_start * stride_col + row_index) * element_size_bytes + ``` + where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. + +**Note**: +- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). +- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. +- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. + +#### Constraints + +1. **Boundary checks**: The requested region must be within tile bounds: + - **For vector-range indexing** (`:` syntax): + - **Row-major layout** (`tile[row_index, col_start:]`): + - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` + - **Column-major layout** (`tile[row_start:, col_index]`): + - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` + - **1D tile indexing**: `tile[start:]` + - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + - **For single-element indexing** (no `:` syntax): + - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) + - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + +2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. + +3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. 
For partial vectors, use the traditional byte offset approach with explicit mask handling. + +4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. + +#### Supported Operations + +The indexing syntax is supported for all vector load and store operations with the following syntax mapping: + +- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): + - Load operations: `vlds`, `vldas`, `vldus`, `vplds`, `vldx2` + - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstx2` + +- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): + - Load operations: `vsld` (scalar load with broadcast) + +#### Examples + +The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. + +```python +# 2D tile indexing (row-major layout) +vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[i, j:], mask) # Store vector with mask + +# 1D tile indexing +vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store vector with mask + +# Dual load with indexing +vec1, vec2 = pto.vldx2(tile_a[i, j:], tile_b[i, j:]) + +# Aligned load with indexing +vec = pto.vldas(tile[i, j:], align) + +# Scalar load (broadcast) +vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector +``` + +#### Comparison with Manual Offset Calculation + +**Traditional approach (error-prone):** +```python +# Manual byte offset calculation for f32 tile +rows, cols = tile.shape +row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 +col_offset = j * 4 +offset = row_offset + col_offset +vec = pto.vlds(tile, offset) +``` + +**New syntax (type-safe):** +```python +# Automatic offset calculation +vec = pto.vlds(tile[i, 
j:]) # Compiler computes correct offset for any element type +``` + +The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). + +### Vector Load Operations + +Operations for loading data from memory into vector registers. + +#### `pto.vlds(buf: UBRef, offset: Index) -> VRegType` +#### `pto.vlds(tile[row, col:]) -> VRegType` +#### `pto.vlds(tile[start:]) -> VRegType` + +**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements + +**Examples**: +```python +# Traditional byte-offset syntax +vec = pto.vlds(ub_ptr, lane * 256) + +# New element-indexing syntax +vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 +vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 + +# Generic kernel that works for both f16 and f32 +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, 
pto.AnyFloat)], priority=10) +def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): # vector_lanes computed from element type + # No manual byte calculation needed! + vec = pto.vlds(src[i, j:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, dst[i, j:], all_mask) +``` + +#### `pto.vldas(buf: UBRef, offset: Index, align: pto.align) -> VRegType` +#### `pto.vldas(tile[row, col:], align: pto.align) -> VRegType` +#### `pto.vldas(tile[start:], align: pto.align) -> VRegType` + +**Description**: Aligned vector load with explicit alignment carrier. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | +| `align` | `pto.align` | Alignment specification | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `align` | `pto.align` | Alignment specification | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vldas(ub_ptr, offset, align) + +# Element-indexing syntax +vec = pto.vldas(tile[i, j:], align) +vec = pto.vldas(tile[k:], align) +``` + +#### `pto.vldus(buf: UBRef, offset: Index) -> VRegType` +#### `pto.vldus(tile[row, col:]) -> VRegType` +#### `pto.vldus(tile[start:]) -> VRegType` + +**Description**: Unaligned vector load. Supports both byte-offset and element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vldus(ub_ptr, offset) + +# Element-indexing syntax +vec = pto.vldus(tile[i, j:]) +vec = pto.vldus(tile[k:]) +``` + +#### `pto.vplds(buf: UBRef, offset: Index, pred: MaskType) -> VRegType` +#### `pto.vplds(tile[row, col:], pred: MaskType) -> VRegType` +#### `pto.vplds(tile[start:], pred: MaskType) -> VRegType` + +**Description**: Predicated vector load stateless. Supports both byte-offset and element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | +| `pred` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `pred` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vplds(ub_ptr, offset, mask) + +# Element-indexing syntax +vec = pto.vplds(tile[i, j:], mask) +vec = pto.vplds(tile[k:], mask) +``` + +#### `pto.vldx2(buf1: UBRef, buf2: UBRef, offset: Index) -> (VRegType, VRegType)` +#### `pto.vldx2(tile1[row, col:], tile2[row, col:]) -> (VRegType, VRegType)` +#### `pto.vldx2(tile1[start:], tile2[start:]) -> (VRegType, VRegType)` + +**Description**: Dual vector load from two buffers. Supports both byte-offset and element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf1` | `UBRef` | First buffer or pointer | +| `buf2` | `UBRef` | Second buffer or pointer | +| `offset` | `Index` | Byte offset (applied to both buffers) | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile1[row, col:]` | `Tile` with indexing | First 2D tile with row index and starting column | +| `tile2[row, col:]` | `Tile` with indexing | Second 2D tile with row index and starting column | +| _or_ | | | +| `tile1[start:]` | `Tile` with indexing | First 1D tile with starting element index | +| `tile2[start:]` | `Tile` with indexing | Second 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | Vector from first buffer | +| `vec2` | `VRegType` | Vector from second buffer | + +**Examples**: +```python +# Byte-offset syntax +vec1, vec2 = pto.vldx2(ub_ptr1, ub_ptr2, offset) + +# Element-indexing syntax +vec1, vec2 = pto.vldx2(tile_a[i, j:], tile_b[i, j:]) +vec1, vec2 = pto.vldx2(tile_a[k:], tile_b[k:]) +``` + +#### `pto.vsld(buf: UBRef, offset: Index) -> VRegType` +#### `pto.vsld(tile[row, col]) -> VRegType` +#### `pto.vsld(tile[pos]) -> VRegType` + +**Description**: Scalar load to vector (broadcast scalar to all lanes). Supports both byte-offset and element-indexing syntax. The element-indexing syntax loads a single element (not a vector) and broadcasts it to all lanes. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `UBRef` | Buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Vector with scalar broadcast to all lanes | + +**Examples**: +```python +# Byte-offset syntax +vec = pto.vsld(ub_ptr, offset) + +# Element-indexing syntax +vec = pto.vsld(tile[i, j]) # Load single element at (i,j) and broadcast +vec = pto.vsld(tile[k]) # Load single element at position k and broadcast +``` + +### Predicate Operations + +Operations for creating and manipulating typed masks. + +**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. + +**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Creates an 8-bit granularity mask from a pattern. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Used with `i8` vector operations + +**Example**: +```python +mask8 = pto.make_mask(pto.i8, PAT.ALL) +``` + +#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Creates a 16-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations + +**Example**: +```python +mask16 = pto.make_mask(pto.f16, PAT.ALL) +``` + +#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Creates a 32-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_EVEN`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations + +**Example**: +```python +mask32 = pto.make_mask(pto.f32, PAT.ALL) +``` + +#### `pto.pge_b8(vec: VRegType, scalar: ScalarType) -> pto.mask_b8` + +**Description**: Creates 8-bit mask where vector elements ≥ scalar. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Vector element type must be `i8` or compatible + +#### `pto.pge_b16(vec: VRegType, scalar: ScalarType) -> pto.mask_b16` + +**Description**: Creates 16-bit mask where vector elements ≥ scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Vector element type must be `f16`/`bf16`/`i16` + +#### `pto.pge_b32(vec: VRegType, scalar: ScalarType) -> pto.mask_b32` + +**Description**: Creates 32-bit mask where vector elements ≥ scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Vector element type must be `f32`/`i32` + +**Example**: +```python +mask = pto.pge_b32(vec_f32, pto.f32(0.0)) +``` + +#### `pto.plt_b8(vec: VRegType, scalar: ScalarType) -> pto.mask_b8` + +**Description**: Creates 8-bit mask where vector elements < scalar. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +#### `pto.plt_b16(vec: VRegType, scalar: ScalarType) -> pto.mask_b16` + +**Description**: Creates 16-bit mask where vector elements < scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +#### `pto.plt_b32(vec: VRegType, scalar: ScalarType) -> (pto.mask_b32, pto.i32)` + +**Description**: Creates 32-bit mask where vector elements < scalar, returns mask and remaining count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (element type must match mask granularity) | +| `scalar` | `ScalarType` | Scalar comparison value | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | +| `remaining` | `pto.i32` | Remaining element count | + +**Example**: +```python +mask, remaining = pto.plt_b32(vec_f32, pto.f32(10.0)) +``` + +#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` + +**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. 
This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `value` | `pto.i32` \| `pto.MaskPattern` | Either a remaining element count (as `pto.i32`) for tail processing, or a mask pattern enum value for fixed mask generation (e.g., `pto.MaskPattern.PAT_ALL`, `pto.MaskPattern.PAT_VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Generated mask with appropriate granularity | +| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | + +**Constraints**: +- The `element_type` must be one of: `f32`, `i32`, `f16`, `bf16`, `i16`, `i8` +- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`, 16-bit for `f16`/`bf16`/`i16`, 8-bit for `i8` +- The function infers the operation mode from the `value` parameter type at compile time: + - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) + - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) + +**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: +- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) +- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) + +**Example**: +```python +# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 +mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) + +# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 +mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) + +# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 +mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) + +# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 +mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) + +# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 +mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) + +# 
Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 +mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) + +# Type annotations help clarify expected parameter types +remaining: pto.i32 = 1024 +mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing +mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode +``` + +#### `pto.ppack(mask: MaskType) -> pto.i32` + +**Description**: Packs mask bits into a 32-bit integer. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `packed` | `pto.i32` | Packed mask bits | + +#### `pto.punpack(packed: pto.i32) -> MaskType` + +**Description**: Unpacks 32-bit integer to mask (granularity determined by context). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `packed` | `pto.i32` | Packed mask bits | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Unpacked mask | + +#### `pto.pnot(mask: MaskType) -> MaskType` + +**Description**: Logical negation of mask bits. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `negated` | `MaskType` | Negated mask | + +#### `pto.psel(mask: MaskType, true_val: ScalarType, false_val: ScalarType) -> ScalarType` + +**Description**: Selects between two scalar values based on mask. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Selection mask | +| `true_val` | `ScalarType` | Value selected when mask bit is 1 | +| `false_val` | `ScalarType` | Value selected when mask bit is 0 | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `ScalarType` | Selected scalar value | + +### Unary Vector Operations + +Element-wise unary operations on vector registers. + +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Absolute value of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Absolute values | + +**Constraints**: +- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) + +**Example**: +```python +abs_vec = pto.vabs(vec_f32, mask32) +``` + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Exponential of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential values | + +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Natural logarithm of vector elements. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Natural logarithm values | + +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Square root of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Square root values | + +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: ReLU activation (max(0, x)) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated values | + +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitwise NOT of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise NOT values | + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex addition of vector elements (treating pairs as complex numbers). 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex addition result | + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex maximum of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex maximum result | + +### Binary Vector Operations + +Element-wise binary operations on vector registers. + +#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise addition of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum of vectors | + +**Example**: +```python +sum_vec = pto.vadd(vec_a, vec_b, mask32) +``` + +#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise subtraction of two vectors. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference of vectors | + +#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise multiplication of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Product of vectors | + +#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise division of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Quotient of vectors | + +#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of two vectors. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise maximum | + +#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise minimum | + +#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of two vectors. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift left (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift right (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +### Vector-Scalar Operations + +Operations between vectors and scalars. 
+ +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector multiplied by scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar multiplier | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Scaled vector | + +**Example**: +```python +scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) +``` + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector plus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar addend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Maximum values | + +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of vector and scalar. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Minimum values | + +#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU activation (max(αx, x)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Alpha coefficient | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Leaky ReLU activated values | + +#### `pto.vshls(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector shift left by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshrs(vec: VRegType, shift: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector shift right by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +### Carry & Select Operations + +Operations with carry propagation and selection. 
+ +#### `pto.vaddc(vec1: VRegType, vec2: VRegType, carry_in: ScalarType, mask: MaskType) -> (VRegType, ScalarType)` + +**Description**: Vector addition with carry input and output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `carry_in` | `ScalarType` | Input carry bit | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum vector | +| `carry_out` | `ScalarType` | Output carry bit | + +#### `pto.vsubc(vec1: VRegType, vec2: VRegType, borrow_in: ScalarType, mask: MaskType) -> (VRegType, ScalarType)` + +**Description**: Vector subtraction with borrow input and output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `borrow_in` | `ScalarType` | Input borrow bit | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | +| `borrow_out` | `ScalarType` | Output borrow bit | + +#### `pto.vsel(mask: MaskType, true_vec: VRegType, false_vec: VRegType) -> VRegType` + +**Description**: Vector select based on mask. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Selection mask | +| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | +| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +**Example**: +```python +result = pto.vsel(mask32, scaled_vec, original_vec) +``` + +### Data Rearrangement + +Operations for rearranging data within vectors. + +#### `pto.pdintlv_b8(mask: pto.mask_b8) -> pto.mask_b8` + +**Description**: Deinterleave 8-bit mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.mask_b8` | Input 8-bit mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `pto.mask_b8` | Deinterleaved mask | + +#### `pto.pintlv_b16(mask: pto.mask_b16) -> pto.mask_b16` + +**Description**: Interleave 16-bit mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.mask_b16` | Input 16-bit mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `pto.mask_b16` | Interleaved mask | + +#### `pto.vintlv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Interleave two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Interleaved vector | + +#### `pto.vdintlv(vec: VRegType, mask: MaskType) -> (VRegType, VRegType)` + +**Description**: Deinterleave vector into two vectors. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | First deinterleaved vector | +| `vec2` | `VRegType` | Second deinterleaved vector | + +### Conversion & Special Operations + +Type conversion and specialized operations. + +#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: str) -> VRegType` + +**Description**: Truncate/round floating-point vector elements to integer-valued +floating-point results under an explicit predicate mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask; granularity must match element width | +| `rnd` | `str` | Round mode: `R`, `A`, `F`, `C`, or `Z` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Rounded result with the same floating-point element type | + +#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType) -> VRegType` + +**Description**: Type conversion of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `to_type` | `Type` | Target element type | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Converted vector | + +#### `pto.vbitsort(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitonic sort of vector elements. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sorted vector | + +#### `pto.vmrgsort4(vec1: VRegType, vec2: VRegType, vec3: VRegType, vec4: VRegType, mask: MaskType) -> VRegType` + +**Description**: 4-way merge sort of vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `vec3` | `VRegType` | Third input vector | +| `vec4` | `VRegType` | Fourth input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Merged and sorted vector | + +### Stateless Store Operations + +Operations for storing data from vector registers to memory (stateless). + +#### `pto.vsts(vec: VRegType, buf: UBRef, offset: Index, mask: MaskType) -> None` +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType) -> None` + +**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | Destination buffer or pointer (UB memory space) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements + +**Examples**: +```python +# Byte-offset syntax +pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) + +# Element-indexing syntax +pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 + +# In a generic kernel +@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_store(src: pto.Tile, dst: pto.Tile): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): + vec = pto.vlds(src[i, j:]) + pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation +``` + +#### `pto.psts(mask: MaskType, buf: UBRef, offset: Index) -> None` +#### `pto.psts(mask: MaskType, tile[row, col:]) -> None` +#### `pto.psts(mask: MaskType, tile[start:]) -> None` + +**Description**: Predicate store to buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +#### `pto.vsst(scalar: ScalarType, buf: UBRef, offset: Index, mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` + +**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstx2(vec1: VRegType, vec2: VRegType, buf1: UBRef, buf2: UBRef, offset: Index, mask: MaskType) -> None` +#### `pto.vstx2(vec1: VRegType, vec2: VRegType, tile1[row, col:], tile2[row, col:], mask: MaskType) -> None` +#### `pto.vstx2(vec1: VRegType, vec2: VRegType, tile1[start:], tile2[start:], mask: MaskType) -> None` + +**Description**: Dual vector store to two buffers. Supports both traditional byte-offset syntax and new element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First vector to store | +| `vec2` | `VRegType` | Second vector to store | +| `buf1` | `UBRef` | First destination buffer | +| `buf2` | `UBRef` | Second destination buffer | +| `offset` | `Index` | Byte offset (applied to both buffers) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First vector to store | +| `vec2` | `VRegType` | Second vector to store | +| `tile1[row, col:]` | `Tile` with indexing | First 2D tile with row index and starting column (vector-width range) | +| `tile2[row, col:]` | `Tile` with indexing | Second 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First vector to store | +| `vec2` | `VRegType` | Second vector to store | +| `tile1[start:]` | `Tile` with indexing | First 1D tile with starting element index (vector-width range) | +| `tile2[start:]` | `Tile` with indexing | Second 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vsta(vec: VRegType, buf: UBRef, offset: Index, align: pto.align, mask: MaskType) -> None` +#### `pto.vsta(vec: VRegType, tile[row, col:], align: pto.align, mask: MaskType) -> None` +#### `pto.vsta(vec: VRegType, tile[start:], align: pto.align, mask: MaskType) -> None` + +**Description**: Aligned vector store with explicit alignment carrier. Supports both traditional byte-offset syntax and new element-indexing syntax. 
+ +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | +| `align` | `pto.align` | Alignment specification | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `align` | `pto.align` | Alignment specification | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `align` | `pto.align` | Alignment specification | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +### Stateful Store Operations + +Operations for storing data with stateful semantics. + +#### `pto.pstu(mask: MaskType, buf: UBRef, offset: Index) -> None` + +**Description**: Predicate stateful store. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Mask to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | + +**Returns**: None (side-effect operation) + +#### `pto.vstu(vec: VRegType, buf: UBRef, offset: Index, mask: MaskType) -> None` + +**Description**: Vector stateful store. 
+ +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | Destination buffer or pointer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstus(align_in: AlignType, offset: i32, vec: VRegType, buf: UBRef) -> AlignType` + +**Description**: No-post unaligned vector store with scalar offset. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming unaligned-store state | +| `offset` | `i32` | Stream advance offset in elements | +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | UB destination base pointer | + +**Returns**: Updated align-state token for a later flush op such as `pto.vstas`. + +#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: UBRef, mode: str) -> AlignType` + +**Description**: Unaligned vector store using the SPR-AR-driven stateful form. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming unaligned-store state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `UBRef` | UB destination base pointer | +| `mode` | `str` | `POST_UPDATE` or `NO_POST_UPDATE` | + +**Returns**: Updated align-state token for a later flush op such as `pto.vstar`. + +## Examples + +### Simple Vector Copy + +```python +@pto.vkernel(...) +def vector_copy(src: pto.memref(256, pto.f32, MemorySpace.UB), + dst: pto.memref(256, pto.f32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for offset in range(0, 256, 64): + vec = pto.vlds(src, offset) + pto.vsts(vec, dst, offset, all_mask) +``` + +### Conditional Computation + +```python +@pto.vkernel(...) +def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM), + threshold: pto.f32): + # ... setup ... 
+ + with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): + for i in range(0, 1024, 64): + vec = pto.vlds(vin, i) + + # Compare with threshold + mask = pto.pge_b32(vec, thresh) + + # Scale values above threshold + scaled = pto.vmuls(vec, pto.f32(2.0), mask) + + # Keep original values below threshold + result = pto.vsel(mask, scaled, vec) + + pto.vsts(result, vout, i, all_mask) +``` + +### Loop with Carry + +```python +@pto.vkernel(...) +def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), + dst: pto.ptr(pto.i32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.i32, PAT.ALL) + carry = pto.i32(0) + + for i in range(0, 256, 64): + vec = pto.vlds(src, i) + result, carry = pto.vaddcs(vec, carry, all_mask) + pto.vsts(result, dst, i, all_mask) +``` + +## Common Errors + +### Typed Mask Mismatch + +``` +Error: f32 vector operation cannot consume mask_b16 +``` + +**Solution:** Ensure mask granularity matches vector element size: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +### Strict Scope Implicit Capture + +``` +Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly +``` + +**Solution:** Pass all required values in the capture list: + +```python +# Wrong: +with pto.strict_vecscope() as (): + vec = pto.vlds(ub_in, offset) # ub_in from outer scope + +# Correct: +with pto.strict_vecscope(ub_in) as (ub): + vec = pto.vlds(ub, offset) +``` + +### Untyped Loop Carried State + +``` +Error: loop-carried value must have explicit machine type +``` + +**Solution:** Add type annotations to loop-carried variables: + +```python +# Wrong: +remaining = 1024 # Plain Python int +for i in range(0, N, step): + mask, remaining = pto.make_mask(pto.f32, remaining) + +# Correct: +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + +## Compatibility Notes + +The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: + +1. 
**Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` +2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` +3. **MemRef support**: Does not yet support `pto.memref()` types +4. **Operation coverage**: Implements only a subset of operations + +When implementing new code, follow this specification. The experimental implementation will be updated to match over time. + +## Next Steps + +- Explore the ISA documentation in `docs/isa/` for detailed operation semantics +- Check `test/samples/` for example kernels +- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification + +For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. diff --git a/docs/tilelang-dsl-syntax-sugar-proposals.md b/docs/tilelang-dsl-syntax-sugar-proposals.md new file mode 100644 index 000000000..8a60466d9 --- /dev/null +++ b/docs/tilelang-dsl-syntax-sugar-proposals.md @@ -0,0 +1,404 @@ +# TileLang DSL Syntax Sugar Proposals + +## Overview + +This document proposes syntax sugar enhancements for the TileLang Python DSL to improve programming ergonomics while maintaining close correspondence with the underlying VPTO IR. The current DSL design closely mirrors VPTO instructions, which can lead to verbose and error-prone code. These proposals aim to provide higher-level abstractions that compile down to the existing VPTO operations. + +## Current Usability Challenges + +### 1. **Low-Level Pointer Operations** +```python +# Current: manual byte offset management +ub_in = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +ub_out = pto.castptr(4096, pto.ptr(pto.f32, MemorySpace.UB)) +next_ptr = pto.addptr(ub_ptr, 4096) +``` +**Problem**: Users must manage byte offsets and memory spaces manually. + +### 2. 
**Verbose Copy Operations** +The `pto.copy_ubuf_to_ubuf` operation has 7 parameters: +- `src_offset`, `src_stride0`, `src_stride1` +- `dst_offset`, `dst_stride0`, `dst_stride1` + +**Problem**: Correctly setting stride parameters is error-prone, especially for multi-dimensional data. + +### 3. **Precise Mask Type Matching** +```python +# Must ensure mask granularity matches element type +mask32 = pto.pset_b32("PAT_ALL") # f32 requires b32 mask +mask16 = pto.pset_b16("PAT_ALL") # f16 requires b16 mask +``` +**Problem**: Type error messages are not intuitive and easy to confuse. + +### 4. **Strict Vector Scope Requirements** +```python +# strict_vecscope requires explicit capture of all variables +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Can only use captured variables +``` +**Problem**: Increases boilerplate code, especially when multiple variables need capture. + +### 5. **Manual Synchronization Management** +```python +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` +**Problem**: Easy to forget synchronization or use wrong event IDs. + +### 6. **Byte Offsets vs. Element Indices** +```python +# Need to calculate byte offsets +vec = pto.vlds(ub_ptr, lane * 256) # Assuming f32, 4 bytes per element +``` +**Problem**: Users must understand underlying memory layout. + +## Proposed Syntax Sugar Enhancements + +### 1. 
**Array View Abstraction** + +#### Current API +```python +# Low-level pointer operations +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +vec = pto.vlds(ub_ptr, 64 * 4) # Load 64th f32 element +``` + +#### Proposed Syntax Sugar +```python +# Create array views +ub_array = pto.ub_array(256, pto.f32, base_offset=0) # 256-element f32 UB array +gm_array = pto.gm_array(1024, pto.f32, src) # GM pointer array view + +# Element access with automatic offset calculation +element = ub_array[64] # Get 64th element (auto-calculates byte offset) +slice = ub_array[128:256] # Slice operation + +# Array assignment (compiles to appropriate copy operations) +ub_array[0:64] = gm_array[0:64] # Compiles to copy_gm_to_ubuf + +# Multi-dimensional arrays +ub_2d = pto.ub_array((256, 128), pto.f32) # 2D array +row = ub_2d[32, :] # Row slice +col = ub_2d[:, 64] # Column slice +``` + +#### Implementation Notes +- `ub_array[64]` → `pto.vlds(ub_ptr, 64 * sizeof(f32))` +- `ub_array[0:64] = gm_array[0:64]` → Appropriate `copy_gm_to_ubuf` call with stride calculations +- Array views are compile-time constructs with no runtime overhead + +### 2. **Simplified Copy Operations** + +#### Current API +```python +pto.copy_gm_to_ubuf(src, dst, 0, 32, 128, 0, 0, False, 0, 128, 128) +``` + +#### Proposed Syntax Sugar +```python +# Full array copy +pto.copy_gm_to_ub(gm_array, ub_array) + +# Slice copy with automatic stride calculation +pto.copy_gm_to_ub(gm_array[0:64], ub_array[128:192]) + +# Copy with element count +pto.copy_gm_to_ub(gm_array, ub_array, count=64) + +# Transpose copy +pto.copy_gm_to_ub(gm_array, ub_array, transpose=True) + +# Multi-dimensional copy with automatic stride inference +pto.copy_gm_to_ub(gm_2d[0:32, :], ub_2d[:, 0:64]) + +# Chained operations +(pto.copy_gm_to_ub(gm_array, ub_array) + .then(pto.copy_ub_to_ub(ub_array, ub_temp)) + .then(pto.copy_ub_to_gm(ub_temp, dst_array))) +``` + +### 3. 
**Automatic Mask Inference** + +#### Current API +```python +# Must specify mask type explicitly +mask32 = pto.pset_b32("PAT_ALL") +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) +``` + +#### Proposed Syntax Sugar +```python +# Automatic mask type inference +mask = pto.pset("PAT_ALL") # Inferred as mask_b32 for f32 vectors +out = pto.vabs(vec_f32, mask) # Type-safe, auto-matched + +# Vector method syntax (more Pythonic) +out = vec_f32.abs(mask="PAT_ALL") +out = vec_f32.add(other_vec, mask=pto.pset("PAT_EVEN")) +out = vec_f32.max(scalar, mask="PAT_ALL") + +# Mask creation from comparison +mask = vec_f32 >= pto.f32(0.0) # Creates appropriate mask_b32 +mask = vec_f32 < threshold # Auto-infers mask type + +# Mask operations with auto-typing +combined = mask1 & mask2 # Bitwise AND with type preservation +inverted = ~mask # Logical NOT +``` + +### 4. **Simplified Synchronization Primitives** + +#### Current API +```python +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +# ... computation ... +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### Proposed Syntax Sugar +```python +# Context manager for automatic synchronization +with pto.sync_between(PIPE.MTE2, PIPE.V, event=EVENT.ID0): + # set_flag called on entry, wait_flag on exit + pto.copy_gm_to_ub(src, dst) + compute_block() + +# Decorator for function-level synchronization +@pto.synchronized(from_pipe=PIPE.MTE2, to_pipe=PIPE.V) +def compute_block(): + # Automatic synchronization before and after + pass + +# Pipeline synchronization chain +with pto.pipeline([ + (PIPE.MTE2, PIPE.V, EVENT.ID0), + (PIPE.V, PIPE.MTE3, EVENT.ID1), + (PIPE.MTE3, PIPE.S, EVENT.ID2) +]): + # Multi-stage synchronization + stage1() + stage2() + stage3() +``` + +### 5. 
**Element-Level Indexing Operations** + +#### Current API +```python +# Byte offset calculation required +vec = pto.vlds(ub_ptr, lane * 256) # Need to know f32 is 4 bytes +``` + +#### Proposed Syntax Sugar +```python +# Element-level indexing +vec = pto.vlde(ub_array, lane) # Automatic byte offset calculation +pto.vste(vec, ub_array, lane) # Element-level store + +# Array view methods +vec = ub_array.load_element(lane) +ub_array.store_element(lane, vec) + +# Batch operations +vectors = ub_array.load_elements([0, 64, 128, 192]) +ub_array.store_elements([256, 320, 384], vectors) + +# Strided access +stride = ub_array.load_stride(start=0, end=1024, step=64) +``` + +### 6. **Type Inference Simplification** + +#### Current API +```python +# Explicit type annotations required +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + +#### Proposed Syntax Sugar +```python +# Automatic type inference for constants +remaining = pto.constant(1024) # Inferred as i32 or i64 from context +step = pto.constant(64, type=pto.i32) # Explicit type specification + +# Typed range with automatic inference +for i in pto.range(0, 1024, 64): # i automatically gets correct machine type + # i is pto.i32 + +# Function argument type inference +@pto.vkernel +def kernel(x): # Type inferred from usage + return x * pto.constant(2) # x type inferred from multiplication + +# Variable type inference from operations +result = pto.constant(10) + pto.constant(20) # result is pto.i32 +``` + +### 7. 
**More Flexible Vector Scopes** + +#### Current API +```python +# Explicit capture required +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + for i in range(lb, ub, step): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, mask) +``` + +#### Proposed Syntax Sugar +```python +# Automatic variable capture +with pto.vector_scope(): + # Variables used in scope are automatically captured + for i in pto.range(start, end, step): + vec = src_array.load_element(i) + dst_array.store_element(i, vec.abs()) + +# Decorator for vectorized functions +@pto.vectorize +def compute_element(src, dst, index): + vec = src.load_element(index) + dst.store_element(index, vec.abs()) + +# Apply vectorized function across range +pto.vector_map(compute_element, src_array, dst_array, range(0, 1024, 64)) + +# Lambda support +pto.vector_map(lambda x: x.abs(), src_array, dst_array) +``` + +### 8. **Built-in Utility Functions** + +#### Common Pattern Encapsulation +```python +# Vector map/reduce operations +result = pto.vector_map(abs, src_array, dst_array) # Element-wise mapping +sum = pto.vector_reduce(add, array) # Reduction +max_val = pto.vector_reduce(max, array) # Maximum reduction + +# Vector zip/unzip +zipped = pto.vector_zip(src1, src2, dst) # Interleave +unzipped1, unzipped2 = pto.vector_unzip(src, dst1, dst2) # Deinterleave + +# Mathematical functions +result = pto.vector_sin(array) +result = pto.vector_exp(array) +result = pto.vector_relu(array) +result = pto.vector_sigmoid(array) + +# Statistical operations +mean = pto.vector_mean(array) +variance = pto.vector_variance(array) +min_val, max_val = pto.vector_minmax(array) + +# Linear algebra (small-scale) +dot_product = pto.vector_dot(vec1, vec2) +norm = pto.vector_norm(array) +``` + +## Implementation Strategy + +These syntax sugar enhancements can be implemented through: + +1. **Python Decorators and Context Managers**: For synchronization and vector scopes +2. 
**Wrapper Classes**: `UBArray`, `GMArray`, `Vector` classes that encapsulate low-level operations +3. **Operator Overloading**: Support for `[]`, `:`, arithmetic operators on wrapper classes +4. **Type Inference System**: Context-based machine type inference +5. **Compile-time Transformation**: Conversion of high-level syntax to low-level VPTO operations before IR generation + +## Compatibility with VPTO IR + +**Key Principle**: All syntax sugar must ultimately lower to existing VPTO operations. + +### Lowering Examples + +| Syntax Sugar | VPTO IR Equivalent | +|--------------|-------------------| +| `ub_array[64]` | `pto.vlds(ub_ptr, 64 * sizeof(f32))` | +| `pto.copy_gm_to_ub(src_array, dst_array)` | Appropriate `copy_gm_to_ubuf` call with calculated strides | +| `with pto.sync_between(...):` | `set_flag` + `wait_flag` pair | +| `mask = vec_f32 >= pto.f32(0.0)` | `pto.pge_b32(vec_f32, pto.f32(0.0))` | +| `vec_f32.abs(mask="PAT_ALL")` | `pto.vabs(vec_f32, pto.pset_b32("PAT_ALL"))` | + +## Prioritization + +### High Priority (Immediate Value) +1. Array view abstraction +2. Simplified copy operations +3. Automatic mask inference + +### Medium Priority (Significant Ergonomics Improvement) +4. Element-level indexing +5. Type inference simplification +6. Flexible vector scopes + +### Low Priority (Advanced Features) +7. Enhanced synchronization primitives +8. Built-in utility functions + +## Migration Path + +The existing low-level API will remain available for performance-critical code or direct VPTO IR correspondence. Syntax sugar will be provided as an optional layer that can be mixed with low-level operations. 
+ +```python +# Mixed usage example +@pto.vkernel +def mixed_kernel(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM)): + # Low-level: manual pointer setup + ub_in = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) + + # High-level: array view for computation + ub_array = pto.ub_array(256, pto.f32, base_ptr=ub_in) + + # Mixed: low-level copy, high-level computation + pto.copy_gm_to_ubuf(src, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + with pto.vector_scope(): + for i in pto.range(0, 256, 64): + vec = ub_array.load_element(i) + result = vec.abs(mask="PAT_ALL") + ub_array.store_element(i, result) + + # Low-level: copy back + pto.copy_ubuf_to_gm(ub_in, dst, 0, 32, 128, 0, 128, 128) +``` + +## Next Steps + +1. **Prototype Implementation**: Start with array view abstraction and simplified copy operations +2. **User Feedback**: Gather feedback from performance engineers on the proposed syntax +3. **Gradual Rollout**: Implement enhancements in phases, starting with high-priority items +4. **Documentation**: Update DSL guide with syntax sugar examples and migration guides +5. **Testing**: Ensure all syntax sugar correctly lowers to VPTO IR and maintains performance + +These enhancements will significantly improve the TileLang DSL's usability while maintaining the close correspondence with underlying VPTO IR that performance engineers require. + +1. 软件流水线(Software Pipelining)的表达成本 +在 NPU 上写 Vector 级算子,最难的往往不是数值计算,而是利用 UB (Unified Buffer) 进行 Double/Multi-Buffering(乒乓缓存),并手动排布内存搬运与计算的流水线。 + +现状挑战:如果开发者全靠手写 set_flag、wait_flag,以及手动维护 Ping-Pong 缓冲的偏移量,代码会迅速膨胀且极易死锁或读写冲突。 + +优化建议:DSL 在保留底层原语的同时,可以提供稍微高级一点的流水线抽象。例如,引入 pto.CircularBuffer(tile, num_stages=2) 的概念,让开发者可以专注于“当前 stage 的计算”,而由底层生成器自动完成不同 stage 的指针轮转和 Flag 同步。 + +2. 
Python 宿主变量 vs MLIR SSA 变量的心智模型边界 +因为 DSL 的本质是用 Python 元编程来生成 MLIR(静态图),开发者在写代码时很容易混淆“Python 运行期的值”和“NPU 运行期的值”。 + +现状挑战:手册中提到“变量的自动合并”(比如 if 分支产生合并),这涉及到复杂的 SSA 转换。特别是在 for 循环中,**循环携带状态(Loop-carried state)**的处理往往是个痛点。如果开发者在循环外定义了一个 Python 列表或字典,在循环内去修改它,这在生成 MLIR 的 scf.for 时是无法正确映射的。 + +优化建议:需要有极其明确的类型系统提示或语法边界,强制区分编译期求值的变量(Meta-variables)和生成的 MLIR Value。可以考虑借鉴 Triton 的方式,提供类似 tl.constexpr 的装饰或类型,让开发者清楚哪些分支在生成 MLIR 时会被静态展开,哪些会真正生成 scf.if。 + +3. 地址计算(Address Generation)的易错性 +即使是对底层开发者,手动计算字节偏移也是痛苦且容易出 Bug 的。 + +现状挑战:i * cols * 4 这种强依赖 f32 占用 4 字节的硬编码,在泛型算子开发中会带来负担(比如想写一个同时兼容 f16 和 f32 的模板算子)。 + +优化建议:提供基于语义的视图(View)操作。保留控制力不代表必须算字节。可以提供类似 tile.get_vector_slice(row_idx, vec_idx) 的接口,它在内部自动 Emit(发射)对应的 MLIR 乘法和加法指令来计算 offset。这不仅防呆,还能让生成的 MLIR 结构更规范。 + +4. Mask 的隐式推导(针对边界处理) +NPU 算子经常要处理尾部不对齐的数据(Tail processing)。 + +优化建议:虽然底层需要具体的 Mask 寄存器配置(如 PAT_ALL),但在 for 循环的最后一步边界处理时,能否提供一个类似 pto.make_mask(remaining_elements) 的宏/内联函数?让它在生成 MLIR 时,自动展开为对应的硬件 plt_b32 等指令,这样可以大幅减少手写冗长边界判断的样板代码。 \ No newline at end of file diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md new file mode 100644 index 000000000..c20942b75 --- /dev/null +++ b/docs/vpto-spec.md @@ -0,0 +1,996 @@ +# PTO micro Instruction Spec — Merged Draft (A5) + +> **Status:** DRAFT for review +> **Base:** [vpto-spec.md](https://github.com/mouliangyu/PTOAS/blob/feature-vpto-backend/docs/vpto-spec.md) (2026-03-20) +> **Updated:** 2026-03-27 + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. 
+ +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. 
+ +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. 
+ +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. 
+ +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). 
+ +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](isa/03-vector-load-store.md) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. 
+ +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. 
+ +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. 
+- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, 
i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. 
In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. 
+ +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. 
+
+| Type | Bits | Description |
+|------|------|-------------|
+| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer |
+| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer |
+| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer |
+| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer |
+| `f16` | 16 | IEEE 754 half precision |
+| `bf16` | 16 | Brain floating point |
+| `f32` | 32 | IEEE 754 single precision |
+
+### Mask Types
+
+`mask`: `!pto.mask<G>` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers.
+
+Typed masks are also the primary legality contract for predicated VPTO code:
+
+- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask<b32>`
+- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume
+  `!pto.mask<b16>`
+- vector ops over 8-bit element families consume `!pto.mask<b8>`
+- compare families keep seed-mask and result-mask granularity aligned with the
+  compared vector family
+- carry families keep carry-in, carry-out, and execution-mask granularity
+  aligned with the data-vector family
+- mask-only ops that do not explicitly change granularity preserve the same `G`
+
+### Address Space Conventions
+
+PTO micro Instruction memory operands use `!pto.ptr<T, space>`. This specification models the following memory-space attributes:
+
+| Space | Interpretation |
+|-------|----------------|
+| `gm` | Global Memory (GM), off-chip HBM/DDR storage |
+| `ub` | Unified Buffer (UB), on-chip vector buffer |
+
+Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form:
+
+```mlir
+%0 = pto.castptr %c0 : i64 -> !pto.ptr<f32, ub>
+%1 = pto.addptr %0, %c1024 : !pto.ptr<f32, ub> -> !pto.ptr<f32, ub>
+```
+
+### `!pto.ptr<T, space>`
+
+`!pto.ptr<T, space>` is the typed pointer form used for explicit memory operands in PTO micro Instruction.
+
+- `T` is the element type associated with the pointed-to storage.
+- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. 
+
+#### `pto.store_scalar`
+
+- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr<T, space>, T`
+- **semantics:** Store one scalar element to a pointer-like operand.
+
+```c
+ptr[offset] = value;
+```
+
+- **inputs:**
+  `%value` is the scalar value to store. `%ptr` is a typed PTO pointer
+  `!pto.ptr<T, space>`, and `%offset` is an `index` displacement counted in
+  elements.
+- **constraints and limitations:**
+  The stored value type MUST match the element type of `%ptr`. This op is a
+  scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does
+  not target vector-store `dist` families.
+
+#### Pointer-Based Vector Access Example
+
+The following lowered-style fragment shows how typed PTO pointers flow through
+pointer construction, pointer arithmetic, structured control flow, and PTO
+memory ops. Scalar memory access is expressed on `!pto.ptr<T, space>` in
+general, but the common VPTO pattern here is UB-local scalar access alongside
+UB vector loads/stores:
+
+```mlir
+%0 = pto.castptr %c0 : i64 -> !pto.ptr<f32, ub>
+%1 = pto.addptr %0, %c1024 : !pto.ptr<f32, ub> -> !pto.ptr<f32, ub>
+pto.vecscope {
+  %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) {
+    %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask<b32>, i32
+    %s = pto.load_scalar %1[%c4] : !pto.ptr<f32, ub> -> f32
+    pto.store_scalar %s, %1[%c8] : !pto.ptr<f32, ub>, f32
+    %17 = pto.vlds %1[%arg3] : !pto.ptr<f32, ub> -> !pto.vreg<64xf32>
+    %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.vreg<64xf32>
+    pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr<f32, ub>, !pto.mask<b32>
+    scf.yield %scalar_out : i32
+  }
+}
+```
+
+In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base.
+
+### Special Types
+
+#### `!pto.mask<G>`
+
+`!pto.mask<G>` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector.
+
+`G` is part of the type and MUST be one of:
+
+- `b32`
+- `b16`
+- `b8`
+
+All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules.
+
+In the ISA chapters below, this document uses `!pto.mask<G>` as shorthand when a
+family is generic over granularity. For op families whose names already encode
+the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`,
+`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed
+mask.
+
+**Mask Granularity:**
+
+The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot:
+
+| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes |
+|-----------|----------------------|------------------------|-----------------------|
+| `!pto.mask<b32>` | 4 | `f32` / `i32` | 64 |
+| `!pto.mask<b16>` | 2 | `f16` / `bf16` / `i16` | 128 |
+| `!pto.mask<b8>` | 1 | 8-bit element family | 256 |
+
+This is intentionally different from a lane-vector model such as `mask<64xi1>`:
+
+- `!pto.mask<b32>` still denotes a 256-bit predicate register;
+- `64` is only the derived logical lane count for the `b32` view;
+- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type.
+- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask<G>`
+  to carry their per-lane carry results, interpreted with this same
+  granularity.
+
+**Predication Behavior (Zero-Merge):**
+
+The native hardware predication mode is **ZEROING** — inactive lanes produce zero:
+
+```c
+dst[i] = mask[i] ?
op(src0[i], src1[i]) : 0 // ZEROING mode
+```
+
+```mlir
+// Predicated compare: inactive lanes produce zero in the result mask
+%mask = pto.pset_b32 "PAT_VL32" : !pto.mask<b32> // first 32 logical b32 lanes active
+%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.mask<b32>
+```
+
+```mlir
+// Compare and select: generate mask from comparison, use for conditional select
+%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.mask<b32>
+%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask<b32> -> !pto.vreg<64xf32>
+```
+
+#### `!pto.align`
+
+`!pto.align` models the A5 vector-align carrier state. It is not payload data.
+
+```mlir
+%align = pto.vldas %ub : !pto.ptr<f32, ub> -> !pto.align
+%vec, %align_out = pto.vldus %ub, %align : !pto.ptr<f32, ub>, !pto.align -> !pto.vreg<64xf32>, !pto.align
+
+%store_align = pto.init_align : !pto.align
+%next_align = pto.vstus %store_align, %offset, %vec, %ub
+  : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr<f32, ub> -> !pto.align
+```
+
+---
+
+## Part II: Notation Convention
+
+This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III).
+
+### MLIR Op Syntax Patterns
+
+All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are:
+
+**Unary (one vector in, one vector out):**
+
+```mlir
+%result = pto.<op> %input : !pto.vreg<NxT> -> !pto.vreg<NxT>
+```
+
+**Binary (two vectors in, one vector out):**
+
+```mlir
+%result = pto.<op> %lhs, %rhs : !pto.vreg<NxT>, !pto.vreg<NxT> -> !pto.vreg<NxT>
+```
+
+**Vec-Scalar (one vector + one scalar in, one vector out):**
+
+```mlir
+%result = pto.<op>
%input, %scalar : !pto.vreg<NxT>, T -> !pto.vreg<NxT>
+```
+
+**Load (memory to register):**
+
+```mlir
+%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr<T, ub> -> !pto.vreg<NxT>
+```
+
+**Store (register to memory):**
+
+```mlir
+pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg<NxT>, !pto.ptr<T, ub>
+```
+
+**Dual Load (one load, two results — deinterleave):**
+
+```mlir
+%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr<T, ub>, index -> !pto.vreg<NxT>, !pto.vreg<NxT>
+```
+
+**Dual Store (two inputs, one interleaved store):**
+
+```mlir
+pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg<NxT>, !pto.vreg<NxT>, !pto.ptr<T, ub>, index, !pto.mask<G>
+```
+
+**Compare (two vectors + seed mask in, mask out):**
+
+```mlir
+%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg<NxT>, !pto.vreg<NxT>, !pto.mask<G> -> !pto.mask<G>
+```
+
+**Conversion (one vector in, different-typed vector out):**
+
+```mlir
+%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg
+```
+
+**Predicate construction:**
+
+```mlir
+%mask = pto.pset_b32 "PAT_ALL" : !pto.mask<b32>
+%tail = pto.pge_b32 "PAT_VL16" : !pto.mask<b32>
+```
+
+**Sync operations:**
+
+```mlir
+pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]
+```
+
+**Pointer construction and arithmetic:**
+
+```mlir
+%ptr = pto.castptr %addr : i64 -> !pto.ptr<T, space>
+%ptr2 = pto.addptr %ptr, %offset : !pto.ptr<T, space> -> !pto.ptr<T, space>
+```
+
+### Shared Dialect Syntax Patterns
+
+PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops.
+The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic.
+ +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. 
The convention:
+
+```c
+// Vector register contents as arrays:
+T dst[N]; // destination
+T src0[N]; // first source
+T src1[N]; // second source (binary ops)
+T scalar; // scalar operand (vec-scalar ops)
+int mask[N]; // per-lane predicate (0 or 1)
+
+// N = lane count determined by type:
+// N = 256 for i8/si8/ui8
+// N = 128 for i16/si16/ui16/f16/bf16
+// N = 64 for i32/si32/ui32/f32
+// N = 32 for i64/si64/ui64
+```
+
+**Example — pto.vadd semantics:**
+
+```c
+for (int i = 0; i < N; i++)
+  dst[i] = src0[i] + src1[i];
+```
+
+**Example — pto.vcgadd (group reduction per VLane) semantics:**
+
+```c
+int K = N / 8; // elements per VLane
+for (int g = 0; g < 8; g++) {
+  T sum = 0;
+  for (int i = 0; i < K; i++)
+    sum += src[g*K + i];
+  dst[g*K] = sum;
+  for (int i = 1; i < K; i++)
+    dst[g*K + i] = 0;
+}
+```
+
+### Template Placeholder Conventions
+
+| Placeholder | Meaning |
+|-------------|---------|
+| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` |
+| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. |
+| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) |
+| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` |
+| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` |
+| `"SAT"` | Saturation: `SAT \| NOSAT` |
+| `"PART"` | Half selector: `EVEN \| ODD` |
+| `"PAT_*"` | Predicate pattern literal |
+| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) |
+| `N` | Lane count (`N * bitwidth(T) = 2048`) |
+
+---
+
+## Part III: ISA Instruction Reference — Summary
+
+This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is available in the linked files.
+ +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](isa/01-pipeline-sync.md) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](isa/02-dma-copy.md) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](isa/03-vector-load-store.md) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](isa/04-predicate-load-store.md) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](isa/05-materialization-predicate.md) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. 
| +| 6 | [Unary Vector Ops](isa/06-unary-vector-ops.md) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](isa/07-binary-vector-ops.md) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](isa/08-vec-scalar-ops.md) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](isa/10-reduction-ops.md) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](isa/11-compare-select.md) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](isa/12-data-rearrangement.md) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](isa/13-dsa-sfu-ops.md) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](isa/14-shared-arith.md) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. 
| +| 15 | [SCF (Shared MLIR Dialect)](isa/15-shared-scf.md) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. 
+ +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. 
Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +*For detailed semantics, C-style pseudocode, and CCE mappings, see the individual group documentation files.* + +--- + +## Appendix: Discussion Points + +### Part I + +1. **mem_bar as pto op:** Should `pto.mem_bar` be a formal pto dialect op, or is there an existing mechanism? +2. **UB size parameterization:** Is 256KB always fixed, or should spec allow for architecture variants? +3. **MERGING predication:** Intentionally omitted (SW-emulated, perf overhead). Revisit if needed later. + +### Part II + +1. **Predication in C semantics:** Should every op's C code explicitly show the `if (mask[i])` guard, or assume all-active and note predication separately? +2. **VLane terminology:** Using "VLane" instead of "DataBlock" — confirm this naming is preferred. + +### Part 3A + +1. 
**pto.vdupi:** Is this distinct from `pto.vdup` with an immediate operand, or can `pto.vdup` handle both? +2. **Predicate ops (pand/por/pxor and predicate movement forms):** These need MLIR op definitions and verifier rules. Confirm priority. + +### Part 3B + +1. **Section 10 removals:** 4 interleave ops removed (not on A5). If multi-arch support is needed later, these would need conditional inclusion. + +### Part 3C + +2. **Store dist family completeness:** `vsts` currently covers `NORM`, `1PT`, `PK`, `PK4`, `MRG4CHN`, and `MRG2CHN`, while `vstsx2` covers `INTLV`. Confirm whether the surface constraints for these families are already sufficiently clear and complete. +3. **vcvt width-changing pattern:** The even/odd + `vor` pattern for forms such as `f32 -> f16` is the standard compiler lowering. Confirm this is the intended representation in the spec. +4. **Stateful store ops (Section 14):** These are complex with SSA state threading. Are they all needed for A5, or can some be simplified? 
diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 193fe44c5..97352aaa6 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -49,7 +49,7 @@ def TileBufOrMemRef : def ScalarPtrOrMemRef : TypeConstraint< CPred<"::mlir::pto::isScalarPtrOrMemRef($_self)">, - "Ptr or MemRef in GM">; + "Ptr or GM MemRef">; def ScalarType : AnyTypeOf<[AnySignlessInteger, AnyFloat], "numeric (integer/float)">; @@ -72,6 +72,8 @@ class PTO_DpsOp traits = []> class PTO_Op traits = []> : Op; +include "PTO/IR/VPTOOps.td" + //===----------------------------------------------------------------------===// // Pointer/View Ops (for your front-end IR) //===----------------------------------------------------------------------===// @@ -101,6 +103,31 @@ def AddPtrOp : PTO_Op<"addptr", [ }]; } +def CastPtrOp : PTO_Op<"castptr", [Pure]> { + let summary = "Cast between integer and !pto.ptr, or between !pto.ptr types"; + let description = [{ + Performs an explicit pointer-domain cast. + + Supported cases: + - integer -> !pto.ptr + - !pto.ptr -> integer + - !pto.ptr -> !pto.ptr + - memref<..., space> -> !pto.ptr (extract the aligned base ptr) + + Pointer-to-pointer casts must stay within the same PTO memory space. Cross + space casts such as gm <-> ub are rejected by the verifier. 
+ }]; + + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + //===----------------------------------------------------------------------===// // Scalar pointer load/store //===----------------------------------------------------------------------===// @@ -1968,8 +1995,10 @@ def SetFlagOp : PTO_Op<"set_flag"> { PTO_EventAttr:$event_id ); let results = (outs); - let assemblyFormat = [{ - `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -1981,8 +2010,10 @@ def WaitFlagOp : PTO_Op<"wait_flag"> { PTO_EventAttr:$event_id ); let results = (outs); - let assemblyFormat = [{ - `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2016,6 +2047,9 @@ def WaitFlagDynOp : PTO_Op<"wait_flag_dyn"> { // Buffer-ID Synchronization (A5) //===----------------------------------------------------------------------===// +def PTO_PipeLikeAttr + : AnyAttrOf<[PTO_PipeEventTypeAttr, PTO_SyncOpTypeAttr, PTO_PipeAttr]>; + def GetBufOp : PTO_Op<"get_buf"> { let summary = "Acquire a buffer-id token for a sync op type (A5)"; let description = [{ @@ -2032,7 +2066,7 @@ def GetBufOp : PTO_Op<"get_buf"> { }]; let arguments = (ins - PTO_PipeEventTypeLikeAttr:$op_type, + PTO_PipeLikeAttr:$op_type, I32Attr:$buf_id, DefaultValuedAttr:$mode ); @@ -2041,8 +2075,10 @@ def GetBufOp : PTO_Op<"get_buf"> { let hasVerifier = 1; - let assemblyFormat = [{ - `[` $op_type `,` $buf_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult 
parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2056,7 +2092,7 @@ def RlsBufOp : PTO_Op<"rls_buf"> { }]; let arguments = (ins - PTO_PipeEventTypeLikeAttr:$op_type, + PTO_PipeLikeAttr:$op_type, I32Attr:$buf_id, DefaultValuedAttr:$mode ); @@ -2065,8 +2101,10 @@ def RlsBufOp : PTO_Op<"rls_buf"> { let hasVerifier = 1; - let assemblyFormat = [{ - `[` $op_type `,` $buf_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 21050a2d3..03bd0ddc0 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -16,18 +16,25 @@ include "mlir/IR/AttrTypeBase.td" include "PTO/IR/PTODialect.td" include "PTO/IR/PTOAttrs.td" -// ---- !pto.ptr ---- +// ---- !pto.ptr ---- def PtrType : TypeDef { let mnemonic = "ptr"; let parameters = (ins - "mlir::Type":$elementType + "mlir::Type":$elementType, + "mlir::pto::AddressSpaceAttr":$memorySpace ); - let assemblyFormat = "`<` $elementType `>`"; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; let builders = [ TypeBuilder<(ins "Type":$elementType), [{ - return Base::get($_ctxt, elementType); + return Base::get($_ctxt, elementType, + mlir::pto::AddressSpaceAttr::get($_ctxt, + mlir::pto::AddressSpace::GM)); + }]>, + TypeBuilder<(ins "Type":$elementType, + "mlir::pto::AddressSpaceAttr":$memorySpace), [{ + return Base::get($_ctxt, elementType, memorySpace); }]> ]; } @@ -233,3 +240,5 @@ def AsyncEventType : TypeDef { let mnemonic = "async_event"; let summary = "Opaque async DMA event handle type"; } + +include "PTO/IR/VPTOTypeDefs.td" diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td new file mode 100644 index 000000000..dc416e868 --- /dev/null +++ b/include/PTO/IR/VPTOOps.td @@ -0,0 +1,1449 @@ +//===- VPTOOps.td - PTO low-level operations 
----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VPTOOPS +#define MLIR_DIALECT_PTO_IR_VPTOOPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def PTO_VectorType : Type($_self)">, + "PTO low-level vector type">; +def PTO_MaskTypeConstraint : Type($_self)">, + "PTO low-level mask type">; +def PTO_B8MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB8()">, + "PTO low-level b8 mask type">; +def PTO_B16MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB16()">, + "PTO low-level b16 mask type">; +def PTO_B32MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB32()">, + "PTO low-level b32 mask type">; +def PTO_AlignTypeConstraint : Type($_self)">, + "PTO low-level align type">; + +def PTO_BufferType : Type< + CPred<"::llvm::isa<::mlir::pto::PtrType>($_self)">, + "pointer-like buffer type">; +def PTO_BufferLikeType : AnyTypeOf<[AnyMemRef, PTO_BufferType], + "memref or pointer-like buffer type">; + +def VecScopeOp : PTO_Op<"vecscope", [SingleBlock, NoTerminator]> { + let summary = "Structured region container for one VPTO vector scope"; + let description = [{ + `pto.vecscope` marks a structured vector-scope interval without overloading + a dummy carrier loop with scope metadata. Lowering and emission passes may + use the region boundary to preserve loop shape while treating the enclosed + body as one VPTO vector interval. 
+ }]; + + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = "$body attr-dict"; +} + +def StrictVecScopeOp : PTO_Op<"strict_vecscope", [SingleBlock, NoTerminator, + IsolatedFromAbove]> { + let summary = "Structured VPTO vector scope with explicit captures only"; + let description = [{ + `pto.strict_vecscope` is the strict form of `pto.vecscope`. Values used by + the body must be passed explicitly through op operands and corresponding + block arguments; implicit SSA capture from the surrounding scope is + rejected. + }]; + + let arguments = (ins Variadic:$captures); + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = [{ + `(` $captures `)` $body attr-dict `:` functional-type($captures, results) + }]; +} + +class PTO_BinaryI64ConfigOp : PTO_Op { + let arguments = (ins + I64:$first, + I64:$second + ); + + let results = (outs); + + let assemblyFormat = [{ + $first `,` $second attr-dict `:` type($first) `,` type($second) + }]; +} + +def PTO_SetLoop2StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_outtoub">; +def PTO_SetLoop1StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_outtoub">; +def PTO_SetLoopSizeOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop_size_outtoub">; +def PTO_SetLoop2StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_ubtoout">; +def PTO_SetLoop1StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_ubtoout">; +def PTO_SetLoopSizeUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop_size_ubtoout">; + +def PTO_CopyGmToUbufOp : PTO_Op<"copy_gm_to_ubuf", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$left_padding_count, + I64:$right_padding_count, + I1:$data_select_bit, + I64:$l2_cache_ctl, + I64:$gm_stride, + I64:$ub_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid 
`,` $n_burst `,` $len_burst `,` + $left_padding_count `,` $right_padding_count `,` $data_select_bit `,` $l2_cache_ctl `,` $gm_stride `,` $ub_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($left_padding_count) `,` + type($right_padding_count) `,` type($data_select_bit) `,` type($l2_cache_ctl) `,` type($gm_stride) `,` type($ub_stride) + }]; +} + +def PTO_CopyUbufToUbufOp : PTO_Op<"copy_ubuf_to_ubuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_VldsOp : PTO_Op<"vlds", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_VldsPostOp : PTO_Op<"vlds_post", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result, + PTO_BufferLikeType:$updated_source); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) `,` type($updated_source) + }]; +} + +def PTO_Vldsx2Op : PTO_Op<"vldsx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); 
+ + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($low) `,` type($high) + }]; +} + +def PTO_VldasOp : PTO_Op<"vldas", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source + ); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_InitAlignOp : PTO_Op<"init_align", []> { + let arguments = (ins); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; +} + +def PTO_SprclrOp : PTO_Op<"sprclr", []> { + let arguments = (ins StrAttr:$spr); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $spr attr-dict + }]; +} + +def PTO_VldusOp : PTO_Op<"vldus", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_AlignTypeConstraint:$align + ); + + let results = (outs + PTO_VectorType:$result, + PTO_AlignTypeConstraint:$updated_align + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $align attr-dict `:` type($source) `,` type($align) `->` type($result) `,` type($updated_align) + }]; +} + +def PTO_UvldOp : PTO_Op<"uvld", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_VbrOp : PTO_Op<"vbr", [Pure]> { + let arguments = (ins AnyType:$value); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value attr-dict `:` type($value) `->` type($result) + }]; +} + +def PTO_VdupOp : PTO_Op<"vdup", [Pure]> { + let arguments = (ins + AnyType:$input, + 
PTO_MaskTypeConstraint:$mask, + OptionalAttr:$position + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PsetB8Op : PTO_Op<"pset_b8", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B8MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PsetB16Op : PTO_Op<"pset_b16", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B16MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +// NOTE: The op families introduced below are intentionally marked as +// unvalidated scaffolding. They are added to preserve missing CCE builtin +// semantics at the dialect layer, but they have not yet been validated through +// PTO lowering or end-to-end sample execution. 
+def PTO_PsetB32Op : PTO_Op<"pset_b32", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B32MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB8Op : PTO_Op<"pge_b8", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B8MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB16Op : PTO_Op<"pge_b16", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B16MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB32Op : PTO_Op<"pge_b32", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B32MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PltB8Op : PTO_Op<"plt_b8", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B8MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +def PTO_PltB16Op : PTO_Op<"plt_b16", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B16MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +def PTO_PltB32Op : PTO_Op<"plt_b32", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B32MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +class PTO_MaskUnaryOp : PTO_Op { + let arguments = (ins PTO_MaskTypeConstraint:$input); + let results 
= (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PpackOp : PTO_MaskUnaryOp<"ppack"> { + let arguments = (ins PTO_MaskTypeConstraint:$input, StrAttr:$part); + let assemblyFormat = [{ + $input `,` $part attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PunpackOp : PTO_MaskUnaryOp<"punpack"> { + let arguments = (ins PTO_MaskTypeConstraint:$input, StrAttr:$part); + let assemblyFormat = [{ + $input `,` $part attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PnotOp : PTO_Op<"pnot", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PselOp : PTO_Op<"psel", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PandOp : PTO_Op<"pand", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PorOp : PTO_Op<"por", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` 
$src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PxorOp : PTO_Op<"pxor", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PldsOp : PTO_Op<"plds", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($result) + }]; +} + +def PTO_PldiOp : PTO_Op<"pldi", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($result) + }]; +} + +def PTO_PstiOp : PTO_Op<"psti", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_VabsOp : PTO_Op<"vabs", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +class 
PTO_UnaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VexpOp : PTO_UnaryVecOp<"vexp">; +def PTO_VlnOp : PTO_UnaryVecOp<"vln">; +def PTO_VsqrtOp : PTO_UnaryVecOp<"vsqrt">; +def PTO_VnegOp : PTO_UnaryVecOp<"vneg">; +def PTO_VreluOp : PTO_UnaryVecOp<"vrelu">; +def PTO_VnotOp : PTO_UnaryVecOp<"vnot">; +def PTO_VcaddOp : PTO_UnaryVecOp<"vcadd">; +def PTO_VcmaxOp : PTO_UnaryVecOp<"vcmax">; +def PTO_VcminOp : PTO_UnaryVecOp<"vcmin">; +def PTO_VcgaddOp : PTO_UnaryVecOp<"vcgadd">; +def PTO_VcgmaxOp : PTO_UnaryVecOp<"vcgmax">; +def PTO_VcgminOp : PTO_UnaryVecOp<"vcgmin">; +def PTO_VcpaddOp : PTO_UnaryVecOp<"vcpadd">; + +class PTO_BinaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VaddOp : PTO_BinaryVecOp<"vadd">; +def PTO_VsubOp : PTO_BinaryVecOp<"vsub">; +def PTO_VmulOp : PTO_BinaryVecOp<"vmul">; +def PTO_VdivOp : PTO_BinaryVecOp<"vdiv">; +def PTO_VmaxOp : PTO_BinaryVecOp<"vmax">; +def PTO_VminOp : PTO_BinaryVecOp<"vmin">; +def PTO_VandOp : PTO_BinaryVecOp<"vand">; +def PTO_VorOp : PTO_BinaryVecOp<"vor">; +def PTO_VxorOp : PTO_BinaryVecOp<"vxor">; + +def PTO_VaddcOp : PTO_Op<"vaddc", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` 
type($result) `,` type($carry) + }]; +} + +def PTO_VsubcOp : PTO_Op<"vsubc", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VaddcsOp : PTO_Op<"vaddcs", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$carry_in, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $carry_in `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($carry_in) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VsubcsOp : PTO_Op<"vsubcs", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$carry_in, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $carry_in `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($carry_in) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VshlOp : PTO_BinaryVecOp<"vshl">; +def PTO_VshrOp : PTO_BinaryVecOp<"vshr">; + +def PTO_VselOp : PTO_Op<"vsel", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VcmpOp : PTO_Op<"vcmp", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + 
PTO_VectorType:$src1, + PTO_MaskTypeConstraint:$mask, + StrAttr:$cmp_mode + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask `,` $cmp_mode attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VcmpsOp : PTO_Op<"vcmps", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask, + StrAttr:$cmp_mode + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $scalar `,` $mask `,` $cmp_mode attr-dict `:` type($src) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +class PTO_PredicatePairReorderOp + : PTO_Op { + let arguments = (ins + operandTy:$lhs, + operandTy:$rhs + ); + let results = (outs + operandTy:$low, + operandTy:$high + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_PdintlvB8Op : PTO_PredicatePairReorderOp<"pdintlv_b8", + PTO_B8MaskTypeConstraint>; +def PTO_PdintlvB16Op : PTO_PredicatePairReorderOp<"pdintlv_b16", + PTO_B16MaskTypeConstraint>; +def PTO_PdintlvB32Op : PTO_PredicatePairReorderOp<"pdintlv_b32", + PTO_B32MaskTypeConstraint>; + +def PTO_PintlvB8Op : PTO_PredicatePairReorderOp<"pintlv_b8", + PTO_B8MaskTypeConstraint>; +def PTO_PintlvB16Op : PTO_PredicatePairReorderOp<"pintlv_b16", + PTO_B16MaskTypeConstraint>; +def PTO_PintlvB32Op : PTO_PredicatePairReorderOp<"pintlv_b32", + PTO_B32MaskTypeConstraint>; + +class PTO_VecScalarOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + AnyType:$scalar + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $scalar attr-dict `:` type($input) `,` type($scalar) `->` type($result) + }]; +} + +class PTO_VecScalarMaskedOp : PTO_Op { + let arguments = (ins + 
PTO_VectorType:$input, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $scalar `,` $mask attr-dict `:` type($input) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VtrcOp : PTO_Op<"vtrc", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask, + StrAttr:$round_mode + ); + let results = (outs PTO_VectorType:$result); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def PTO_VcvtOp : PTO_Op<"vcvt", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + OptionalAttr:$rnd, + OptionalAttr:$sat, + OptionalAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_VciOp : PTO_Op<"vci", [Pure]> { + let arguments = (ins + AnyInteger:$index, + OptionalAttr:$order + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $index attr-dict `:` type($index) `->` type($result) + }]; +} + +def PTO_VbitsortOp : PTO_Op<"vbitsort", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$destination, + PTO_BufferType:$source, + PTO_BufferType:$indices, + Index:$repeat_times + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $destination `,` $source `,` $indices `,` $repeat_times attr-dict `:` type($destination) `,` + type($source) `,` type($indices) `,` type($repeat_times) + }]; +} + +def PTO_Vmrgsort4Op : PTO_Op<"vmrgsort4"> { + let arguments = (ins + PTO_BufferType:$destination, + PTO_BufferType:$source0, + PTO_BufferType:$source1, + PTO_BufferType:$source2, + PTO_BufferType:$source3, + I64:$count, + I64:$config + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $destination `,` $source0 `,` $source1 `,` $source2 `,` $source3 `,` $count `,` $config + attr-dict `:` 
type($destination) `,` type($source0) `,` type($source1) `,` type($source2) `,` + type($source3) `,` type($count) `,` type($config) + }]; +} + +def PTO_Vgather2Op : PTO_Op<"vgather2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + Index:$active_lanes + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $active_lanes attr-dict `:` type($source) `,` type($offsets) `,` type($active_lanes) `->` type($result) + }]; +} + +def PTO_VgatherbOp : PTO_Op<"vgatherb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +// NOTE: Unvalidated new gather/select/interleave-family abstractions. Added to +// cover CCE builtin families not yet exercised through end-to-end PTO seams. 
+def PTO_Vgather2BcOp : PTO_Op<"vgather2_bc", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VmulsOp : PTO_VecScalarMaskedOp<"vmuls">; +def PTO_VaddsOp : PTO_VecScalarMaskedOp<"vadds">; +def PTO_VmaxsOp : PTO_VecScalarMaskedOp<"vmaxs">; +def PTO_VminsOp : PTO_VecScalarMaskedOp<"vmins">; +def PTO_VlreluOp : PTO_VecScalarMaskedOp<"vlrelu">; +def PTO_VshlsOp : PTO_VecScalarMaskedOp<"vshls">; +def PTO_VshrsOp : PTO_VecScalarMaskedOp<"vshrs">; + +def PTO_VstsOp : PTO_Op<"vsts", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask) + }]; +} + +def PTO_VstsPostOp : PTO_Op<"vsts_post", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs PTO_BufferLikeType:$updated_destination); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask) `->` type($updated_destination) + }]; +} + +def PTO_VscatterOp : PTO_Op<"vscatter", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferType:$destination, + PTO_VectorType:$offsets, + Index:$active_lanes + ); + + let results = (outs); + + let hasVerifier = 1; + 
+ let assemblyFormat = [{ + $value `,` $destination `,` $offsets `,` $active_lanes attr-dict `:` type($value) `,` type($destination) `,` type($offsets) `,` type($active_lanes) + }]; +} + +def PTO_PstsOp : PTO_Op<"psts", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_CopyUbufToGmOp : PTO_Op<"copy_ubuf_to_gm", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$reserved, + I64:$burst_dst_stride, + I64:$burst_src_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` + $reserved `,` $burst_dst_stride `,` $burst_src_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($reserved) `,` + type($burst_dst_stride) `,` type($burst_src_stride) + }]; +} + +// NOTE: Unvalidated new x2 / pair / align-store-family abstractions. Added to +// reflect CCE builtin families but not yet end-to-end validated. 
+def PTO_VselrOp : PTO_Op<"vselr", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1 + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result) + }]; +} + +def PTO_VsqzOp : PTO_UnaryVecOp<"vsqz">; + +def PTO_VusqzOp : PTO_Op<"vusqz", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $mask attr-dict `:` type($src) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VpackOp : PTO_Op<"vpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_VsunpackOp : PTO_Op<"vsunpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + Index:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_VzunpackOp : PTO_Op<"vzunpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + Index:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_Vselrv2Op : PTO_Op<"vselrv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1 + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result) + }]; +} + +def PTO_VintlvOp : PTO_Op<"vintlv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs 
PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_VdintlvOp : PTO_Op<"vdintlv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_Vintlvv2Op : PTO_Op<"vintlvv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $part attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_Vdintlvv2Op : PTO_Op<"vdintlvv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $part attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VmullOp : PTO_Op<"vmull", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($low) `,` type($high) + }]; +} + +def PTO_VmulaOp : PTO_Op<"vmula", [Pure]> { + let arguments = (ins + PTO_VectorType:$acc, + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $acc `,` $lhs `,` $rhs `,` $mask attr-dict `:` type($acc) `,` type($lhs) `,` type($rhs) `,` type($mask) `->` 
type($result) + }]; +} + +class PTO_UnmaskedBinaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VpreluOp : PTO_UnmaskedBinaryVecOp<"vprelu">; +def PTO_VexpdiffOp : PTO_Op<"vexpdiff", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_VectorType:$max, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $max `,` $part attr-dict `:` type($input) `,` type($max) `->` type($result) + }]; +} + +def PTO_VaxpyOp : PTO_Op<"vaxpy", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + AnyType:$alpha + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $alpha attr-dict `:` type($src0) `,` type($src1) `,` type($alpha) `->` type($result) + }]; +} + +def PTO_VaddreluconvOp : PTO_Op<"vaddreluconv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VmulconvOp : PTO_Op<"vmulconv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_Vstsx2Op : PTO_Op<"vstsx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$low, + PTO_VectorType:$high, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let 
hasVerifier = 1; + + let assemblyFormat = [{ + $low `,` $high `,` $destination `[` $offset `]` `,` $dist `,` $mask attr-dict `:` type($low) `,` type($high) `,` type($destination) `,` type($offset) `,` type($mask) + }]; +} + +def PTO_VsldbOp : PTO_Op<"vsldb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + I16:$block_stride, + I16:$repeat_stride, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($source) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VsstbOp : PTO_Op<"vsstb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + I16:$block_stride, + I16:$repeat_stride, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) + }]; +} + +def PTO_VstasOp : PTO_Op<"vstas", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination, + I32:$offset + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $offset attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_VstarOp : PTO_Op<"vstar", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination attr-dict `:` type($value) `,` type($destination) + }]; +} + +// NOTE: Unvalidated stateful store-family abstractions. 
These preserve +// align/base/offset update results explicitly in SSA form instead of relying on +// implicit CCE reference updates. +// Keep `base/base_out` pointer-only (`PTO_BufferType`): memref semantics for +// stateful post-update addresses are intentionally out of scope in this change. +def PTO_PstuOp : PTO_Op<"pstu", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_MaskTypeConstraint:$value, + PTO_BufferType:$base + ); + let results = (outs PTO_AlignTypeConstraint:$align_out, PTO_BufferType:$base_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $value `,` $base attr-dict `:` type($align_in) `,` type($value) `,` type($base) `->` type($align_out) `,` type($base_out) + }]; +} + +def PTO_VstusOp : PTO_Op<"vstus", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + I32:$offset, + PTO_VectorType:$value, + PTO_BufferType:$base + ); + let results = (outs PTO_AlignTypeConstraint:$align_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $offset `,` $value `,` $base attr-dict `:` type($align_in) `,` type($offset) `,` type($value) `,` type($base) `->` type($align_out) + }]; +} + +def PTO_VsturOp : PTO_Op<"vstur", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_VectorType:$value, + PTO_BufferType:$base, + StrAttr:$mode + ); + let results = (outs PTO_AlignTypeConstraint:$align_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $value `,` $base `,` $mode attr-dict `:` type($align_in) `,` type($value) `,` type($base) `->` type($align_out) + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOOPS diff --git a/include/PTO/IR/VPTOTypeDefs.td b/include/PTO/IR/VPTOTypeDefs.td new file mode 100644 index 000000000..04e8ac583 --- /dev/null +++ b/include/PTO/IR/VPTOTypeDefs.td @@ -0,0 +1,53 @@ +//===- VPTOTypeDefs.td ---------------------------------------*- 
tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS +#define MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +def VRegType : TypeDef { + let mnemonic = "vreg"; + let summary = "A 256-byte PTO low-level vector"; + + let parameters = (ins + "int64_t":$elementCount, + "Type":$elementType + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def MaskType : TypeDef { + let mnemonic = "mask"; + let summary = "A PTO low-level predicate/mask register"; + + let parameters = (ins + StringRefParameter<"mask granularity view">:$granularity + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static bool isSupportedGranularity(::llvm::StringRef granularity); + + bool isB8() const { return getGranularity() == "b8"; } + bool isB16() const { return getGranularity() == "b16"; } + bool isB32() const { return getGranularity() == "b32"; } + }]; +} + +def AlignType : TypeDef { + let mnemonic = "align"; + let summary = "A PTO low-level vector_align carrier"; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS diff --git a/include/PTO/Transforms/HIVMIntrinsicNaming.h b/include/PTO/Transforms/HIVMIntrinsicNaming.h new file mode 100644 index 000000000..7ba956168 --- /dev/null +++ b/include/PTO/Transforms/HIVMIntrinsicNaming.h @@ -0,0 +1,60 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H +#define MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H + +#include +#include + +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::pto { + +struct NamingInputs { + std::string sourceOpName; + std::string family; + std::string 
vectorShape; + std::string elementType; + std::vector usedFields; + std::vector missingFields; +}; + +struct UnresolvedEmissionRecord { + std::string sourceOpName; + std::string placeholderName; + std::string candidateName; + std::vector usedFields; + std::vector missingFields; + std::string resultTypeFragment; + std::string location; +}; + +struct IntrinsicSelection { + bool resolved = false; + std::string sourceOpName; + std::string calleeName; + std::string placeholderName; + std::string candidateName; + std::vector usedFields; + std::vector missingFields; + std::string resultTypeFragment; + std::string location; + + std::string getEmittedCallee() const { + return resolved ? calleeName : placeholderName; + } + + UnresolvedEmissionRecord asUnresolvedRecord() const { + return UnresolvedEmissionRecord{sourceOpName, placeholderName, candidateName, + usedFields, missingFields, resultTypeFragment, + location}; + } +}; + +FailureOr selectIntrinsic(Operation *op); +FailureOr selectLoadIntrinsic(Operation *op); +FailureOr selectUnaryIntrinsic(Operation *op); +FailureOr selectStoreIntrinsic(Operation *op); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_HIVMINTRINSICNAMING_H diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index cafdb784c..b19a90947 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -61,6 +61,12 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createInferPTOLayoutPass(); +std::unique_ptr createPTOVPTOExpandBridgeOpsPass(); +std::unique_ptr createPTOVPTOPtrBoundaryPass(); +std::unique_ptr createPTOValidateVPTOIRPass(); +std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +std::unique_ptr createLowerPTOToVPTOPass(); +std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy); // Declare register function void 
registerPTOPasses(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 37979bf21..504954243 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -15,6 +15,9 @@ // //===----------------------------------------------------------------------===// +// The VPTO backend is emitted from tools/ptoas rather than a TableGen pass; +// these registrations continue to describe the shared pre-backend pipeline. + #ifndef MLIR_DIALECT_PTO_PASSES #define MLIR_DIALECT_PTO_PASSES @@ -178,4 +181,99 @@ def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> { ]; } +def PTOValidateVPTOIR : Pass<"pto-validate-vpto-ir", "ModuleOp"> { + let summary = + "Validate authoring-stage VPTO legality before ptr-boundary canonicalization"; + let description = [{ + Runs the authoring-stage VPTO legality verifier on post-mainline VPTO IR. + This stage keeps the memref-first authoring surface legal, while checking + the shared structural contracts that must hold before emission-boundary + canonicalization. + }]; + let constructor = "mlir::pto::createPTOValidateVPTOIRPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOValidateVPTOEmissionIR + : Pass<"pto-validate-vpto-emission-ir", "ModuleOp"> { + let summary = + "Validate emission-stage VPTO legality after ptr-boundary canonicalization"; + let description = [{ + Runs the emission-stage VPTO legality verifier on ptr-form VPTO IR after + `PTOVPTOPtrBoundary`. This stage re-checks the shared authoring contracts + and confirms the final emission surface no longer carries memref boundary + state or residual bridge scaffold. 
+ }]; + let constructor = "mlir::pto::createPTOValidateVPTOEmissionIRPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOVPTOExpandBridgeOps + : Pass<"pto-vpto-expand-bridge-ops", "func::FuncOp"> { + let summary = + "Expand temporary VPTO bridge ops back to emission-ready VPTO IR"; + let description = [{ + Low-level fusion may keep temporary bridge ops in VPTO IR so legality and + alias analysis can still see memref-form operands. This pass expands those + bridge ops back to the existing emission-ready pointer-level VPTO forms + before backend emission. + }]; + let constructor = "mlir::pto::createPTOVPTOExpandBridgeOpsPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::pto::PTODialect"]; +} + +def PTOVPTOPtrBoundary + : Pass<"pto-vpto-ptr-boundary", "ModuleOp"> { + let summary = + "Canonicalize the final VPTO emission boundary from memref-first IR to ptr ABI"; + let description = [{ + Runs the final emission-boundary ptr canonicalization after the backend + mainline has finished its memref-first optimization pipeline. This pass + rewrites eligible memref function arguments to same-space `!pto.ptr`, + rejects memref function results, canonicalizes supported body-level VPTO + buffer-like ops to ptr-form, and drops dead boundary/view scaffold such as + trivial `pto.castptr`, `pto.bind_tile`, `memref.subview`, + `memref.reinterpret_cast`, and `memref.memory_space_cast` once they become + unused. 
+ }]; + let constructor = "mlir::pto::createPTOVPTOPtrBoundaryPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect"]; +} + +def PTOToVPTO : Pass<"pto-to-vpto", "ModuleOp"> { + let summary = "Lower PTO tile ops to VPTO backend ops"; + let description = [{ + Lowers PTO tile ops to VPTO backend ops. For already-planned fusion groups, + the pass rewrites the `pto.fusion_region` body in place and preserves the + wrapper until explicit flatten. Residual non-fused PTO ops may continue to + be lowered directly in their parent block and are not wrapped into + synthetic `pto.fusion_region` containers solely for backend lowering. + }]; + let constructor = "mlir::pto::createLowerPTOToVPTOPass()"; + let options = [ + Option<"loweringStrategy", "pto-lowering-strategy", "std::string", + "\"post-update\"", + "vector lowering strategy: post-update or no-post-update"> + ]; + let dependentDialects = [ + "pto::PTODialect", + "func::FuncDialect", + "arith::ArithDialect", + "memref::MemRefDialect", + "scf::SCFDialect" + ]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/include/PTO/Transforms/VPTOLLVMEmitter.h b/include/PTO/Transforms/VPTOLLVMEmitter.h new file mode 100644 index 000000000..dc56f64b2 --- /dev/null +++ b/include/PTO/Transforms/VPTOLLVMEmitter.h @@ -0,0 +1,43 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H + +#include + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class ModuleOp; +} + +namespace llvm { +class raw_ostream; +} + +namespace mlir::pto { + +struct VPTOEmissionOptions { + bool dumpVPTOIR = false; + bool printIntrinsicSelections = false; + bool allowUnresolved = true; + std::string unresolvedReportPath; + std::string targetTriple; + std::string march; + std::string aicoreArch; + std::string defaultTargetCPU; + std::string defaultTargetFeatures; +}; + +LogicalResult 
+translateVPTOModuleToLLVMText(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); + +LogicalResult +translateVPTOModuleToLLVMBitcode(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H diff --git a/include/PTO/Transforms/VPTOLLVMEmitterHelper.h b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h new file mode 100644 index 000000000..555bbe274 --- /dev/null +++ b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h @@ -0,0 +1,6 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H + +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H diff --git a/include/PTO/Transforms/VPTOLowering.h b/include/PTO/Transforms/VPTOLowering.h new file mode 100644 index 000000000..17730ab4e --- /dev/null +++ b/include/PTO/Transforms/VPTOLowering.h @@ -0,0 +1,241 @@ +//===- VPTOLowering.h - PTO to VPTO lowering contracts ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ + +#include "PTO/IR/PTO.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { + +enum class VPTOTileDomain { + Vec, + Acc, + Mat, +}; + +enum class VPTOLoweringStrategy { + PostUpdate, + NoPostUpdate, +}; + +struct VPTOPartitionTrace { + SmallVector offsets; + SmallVector sizes; + bool hasDynamicOffsets = false; + bool hasDynamicSizes = false; +}; + +struct VPTOLoopProgramming { + int64_t loop2 = 1; + int64_t loop1 = 1; + int64_t srcLoop2Stride = 1; + int64_t srcLoop1Stride = 1; + int64_t dstLoop2Stride = 1; + int64_t dstLoop1Stride = 1; +}; + +enum class VPTOLoopScopeKind { + None, + AIVVectorScope, +}; + +struct VPTOLoopScopeContract { + VPTOLoopScopeKind kind = VPTOLoopScopeKind::None; + StringRef loweredAttr = "llvm.loop.aivector_scope"; + int64_t loopDepth = 0; +}; + +struct VPTOLoadContract { + StringRef sourceLayout; + SmallVector sourceShape; + SmallVector sourceStrides; + StringRef tileLayout; + VPTOTileDomain tileDomain = VPTOTileDomain::Vec; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + StringRef padMode; + Value padValue; + Value leftPaddingNum; + Value rightPaddingNum; + bool initOutBuffer = false; + Value initCondition; + VPTOPartitionTrace trace; +}; + +struct VPTOUnaryContract { + StringRef family; + VPTOTileDomain tileDomain = VPTOTileDomain::Vec; + StringRef tileLayout; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + Type 
elementType; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOBinaryContract { + StringRef family; + VPTOTileDomain tileDomain = VPTOTileDomain::Vec; + StringRef tileLayout; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + Type elementType; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOStoreContract { + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + StringRef destinationLayout; + SmallVector destinationShape; + SmallVector destinationStrides; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + VPTOPartitionTrace trace; +}; + +void set_loop2_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder); +void set_loop1_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder); +void set_loop_size_outtoub(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder); +void set_loop2_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder); +void set_loop1_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder); +void set_loop_size_ubtoout(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder); +FailureOr +createLoopScopeRegion(Location loc, const VPTOLoopScopeContract &contract, + PatternRewriter &rewriter); +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc); + +LogicalResult lowerTLOAD(TLoadOp op, PatternRewriter &rewriter); +LogicalResult lowerTABS(TAbsOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTADD(TAddOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSUB(TSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult 
lowerTMUL(TMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTDIV(TDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMAX(TMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMIN(TMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTAND(TAndOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTANDS(TAndSOp op, PatternRewriter &rewriter); +LogicalResult lowerTOR(TOrOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTORS(TOrSOp op, PatternRewriter &rewriter); +LogicalResult lowerTXOR(TXorOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTXORS(TXorSOp op, PatternRewriter &rewriter); +LogicalResult lowerTEXP(TExpOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTLOG(TLogOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSQRT(TSqrtOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRSQRT(TRsqrtOp op, PatternRewriter &rewriter); +LogicalResult lowerTRECIP(TRecipOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTNEG(TNegOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTLRELU(TLReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTCI(TCIOp op, PatternRewriter &rewriter); +LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter); +LogicalResult lowerTCmp(TCmpOp op, PatternRewriter &rewriter); +LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter); +LogicalResult lowerTSel(TSelOp op, PatternRewriter &rewriter); +LogicalResult lowerTAddC(TAddCOp op, PatternRewriter &rewriter); +LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy 
strategy); +LogicalResult lowerTAddSC(TAddSCOp op, PatternRewriter &rewriter); +LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSubC(TSubCOp op, PatternRewriter &rewriter); +LogicalResult lowerTSubS(TSubSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSubSC(TSubSCOp op, PatternRewriter &rewriter); +LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTSelS(TSelSOp op, PatternRewriter &rewriter); +LogicalResult lowerTRELU(TReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTNOT(TNotOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter); +LogicalResult lowerTFILLPAD(TFillPadOp op, PatternRewriter &rewriter); +LogicalResult lowerTFILLPADExpand(TFillPadExpandOp op, PatternRewriter &rewriter); +LogicalResult lowerTRowMax(TRowMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowMin(TRowMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowSum(TRowSumOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTColMax(TColMaxOp op, PatternRewriter &rewriter); +LogicalResult lowerTColMin(TColMinOp op, PatternRewriter &rewriter); +LogicalResult lowerTColSum(TColSumOp op, PatternRewriter &rewriter); +LogicalResult lowerTRowExpand(TRowExpandOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTColExpand(TColExpandOp op, PatternRewriter &rewriter); +LogicalResult lowerTRowExpandMul(TRowExpandMulOp op, PatternRewriter &rewriter, + 
VPTOLoweringStrategy strategy); +LogicalResult lowerTRowExpandDiv(TRowExpandDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTRowExpandSub(TRowExpandSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy); +LogicalResult lowerTPartAdd(TPartAddOp op, PatternRewriter &rewriter); +LogicalResult lowerTPartMax(TPartMaxOp op, PatternRewriter &rewriter); +LogicalResult lowerTPartMin(TPartMinOp op, PatternRewriter &rewriter); +LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter); +LogicalResult lowerTGather(TGatherOp op, PatternRewriter &rewriter); +LogicalResult lowerTGatherB(TGatherBOp op, PatternRewriter &rewriter); +LogicalResult lowerTScatter(TScatterOp op, PatternRewriter &rewriter); +LogicalResult lowerTMrgSort(TMrgSortOp op, PatternRewriter &rewriter); +LogicalResult lowerTSort32(TSort32Op op, PatternRewriter &rewriter); +LogicalResult lowerTSTORE(TStoreOp op, PatternRewriter &rewriter); +LogicalResult lowerSetFlag(SetFlagOp op, PatternRewriter &rewriter); +LogicalResult lowerWaitFlag(WaitFlagOp op, PatternRewriter &rewriter); +LogicalResult lowerBarrier(BarrierOp op, PatternRewriter &rewriter); +LogicalResult lowerGetBuf(GetBufOp op, PatternRewriter &rewriter); +LogicalResult lowerRlsBuf(RlsBufOp op, PatternRewriter &rewriter); +LogicalResult convertVPTOEmissionBoundaryToPtr( + ModuleOp module, llvm::raw_ostream *diagOS = nullptr); + +} // namespace pto +} // namespace mlir + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index 787749157..02f122cb6 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -26,7 +26,10 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PTO, pto); // ---- !pto.ptr ---- bool mlirPTOTypeIsAPtrType(MlirType type); MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType); +MlirType mlirPTOPtrTypeGetWithMemorySpace(MlirContext ctx, MlirType elementType, + MlirAttribute 
memorySpace); MlirType mlirPTOPtrTypeGetElementType(MlirType type); +MlirAttribute mlirPTOPtrTypeGetMemorySpace(MlirType type); // ---- !pto.async_session / !pto.async_event ---- bool mlirPTOTypeIsAAsyncSessionType(MlirType type); diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index e9e32ba98..6fe560a06 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -39,6 +39,7 @@ target_link_libraries(_pto PRIVATE MLIRSupport MLIRArithDialect MLIRMemRefDialect + MLIRSCFDialect MLIRDestinationStyleOpInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces @@ -47,6 +48,7 @@ target_link_libraries(_pto PRIVATE MLIRLoopLikeInterface MLIRViewLikeInterface MLIRFunctionInterfaces + MLIRLLVMDialect ) # 关键:放到 mlir/_mlir_libs 下(匹配 MLIR dialect python 的 import 习惯) diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index c8dff3109..418358947 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -520,20 +520,34 @@ PYBIND11_MODULE(_pto, m) { [](MlirType type) -> bool { return mlirPTOTypeIsAPtrType(type); }) .def_classmethod( "get", - [](py::object cls, MlirType elementType, + [](py::object cls, MlirType elementType, py::object memorySpace, MlirContext context) -> py::object { MlirContext ctx = context; if (!ctx.ptr) ctx = mlirTypeGetContext(elementType); - MlirType t = mlirPTOPtrTypeGet(ctx, elementType); + MlirType t = {nullptr}; + if (memorySpace.is_none()) { + t = mlirPTOPtrTypeGet(ctx, elementType); + } else { + MlirAttribute memorySpaceAttr = + py::cast(memorySpace); + t = mlirPTOPtrTypeGetWithMemorySpace(ctx, elementType, + memorySpaceAttr); + } return cls.attr("__call__")(t); }, py::arg("cls"), py::arg("element_type"), + py::arg("memory_space") = py::none(), py::arg("context") = py::none()) .def_property_readonly( "element_type", [](MlirType self) -> MlirType { return mlirPTOPtrTypeGetElementType(self); + }) + .def_property_readonly( + 
"memory_space", + [](MlirType self) -> MlirAttribute { + return mlirPTOPtrTypeGetMemorySpace(self); }); mlir_type_subclass( diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index 162519fa6..a3659b265 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -60,6 +60,14 @@ MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType) { return wrap(mlir::pto::PtrType::get(c, elem)); } +MlirType mlirPTOPtrTypeGetWithMemorySpace(MlirContext ctx, MlirType elementType, + MlirAttribute memorySpace) { + auto c = unwrap(ctx); + auto elem = unwrap(elementType); + auto space = mlir::cast(unwrap(memorySpace)); + return wrap(mlir::pto::PtrType::get(c, elem, space)); +} + MlirType mlirPTOPtrTypeGetElementType(MlirType type) { auto t = cast(unwrap(type));; return wrap(t.getElementType()); @@ -81,6 +89,11 @@ MlirType mlirPTOAsyncEventTypeGet(MlirContext ctx) { return wrap(mlir::pto::AsyncEventType::get(unwrap(ctx))); } +MlirAttribute mlirPTOPtrTypeGetMemorySpace(MlirType type) { + auto t = cast(unwrap(type)); + return wrap(t.getMemorySpace()); +} + bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr) { return mlir::isa(unwrap(attr)); } diff --git a/lib/PTO/IR/CMakeLists.txt b/lib/PTO/IR/CMakeLists.txt index b73e35090..f0c37824d 100644 --- a/lib/PTO/IR/CMakeLists.txt +++ b/lib/PTO/IR/CMakeLists.txt @@ -14,6 +14,7 @@ # [关键] 库名重命名为 PTOIR,避免与 LLVM 里的 PTODialect/MLIRPTODialect 冲突 add_mlir_dialect_library(PTOIR PTO.cpp + VPTO.cpp PTOAttrs.cpp PTOSyncUtils.cpp PTOTypeDefs.cpp @@ -28,6 +29,7 @@ add_mlir_dialect_library(PTOIR MLIRIR MLIRFuncDialect MLIRMemRefDialect + MLIRSCFDialect MLIRControlFlowInterfaces MLIRInferTypeOpInterface MLIRSideEffectInterfaces diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a419713ce..4ddf5afc8 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -150,6 +150,11 @@ static std::optional getConstantIntegerValue(Value value); static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, Type 
src1Ty, Type dstTy); static bool isRowMajorTileBuf(Type ty); +static ParseResult parseQuotedPipeToken(OpAsmParser &parser, PipeAttr &attr); +static ParseResult parseQuotedEventToken(OpAsmParser &parser, EventAttr &attr); +static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr); +static ParseResult parseLegacyOrAttrEvent(OpAsmParser &parser, EventAttr &attr); +static ParseResult parseI32LiteralAttr(OpAsmParser &parser, IntegerAttr &attr); #define GET_ENUM_CLASSES #include "PTO/IR/PTOEnums.cpp.inc" @@ -305,6 +310,24 @@ static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, return verifyA5(); } } +static std::optional parsePtrAddressSpaceKeyword(StringRef keyword) { + return llvm::StringSwitch>(keyword) + .Case("gm", pto::AddressSpace::GM) + .Case("ub", pto::AddressSpace::VEC) + .Default(std::nullopt); +} + +static StringRef printPtrAddressSpaceKeyword(pto::AddressSpace space) { + switch (space) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return "gm"; + case pto::AddressSpace::VEC: + return "ub"; + default: + return {}; + } +} static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { mlir::Type ty; @@ -359,17 +382,22 @@ static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { mlir::Type elem; if (failed(parser.parseType(elem))) return mlir::Type(); + auto memorySpace = pto::AddressSpaceAttr::get(ctx, pto::AddressSpace::GM); if (succeeded(parser.parseOptionalComma())) { - // ptr no longer accepts an address space; consume the attr for recovery. 
- mlir::Attribute memorySpace; - (void)parser.parseAttribute(memorySpace); - parser.emitError(parser.getCurrentLocation(), - "!pto.ptr no longer accepts address space; use !pto.ptr"); - return mlir::Type(); + StringRef memorySpaceKeyword; + if (failed(parser.parseKeyword(&memorySpaceKeyword))) + return mlir::Type(); + auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); + if (!parsed) { + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr address space must be `gm` or `ub`"); + return mlir::Type(); + } + memorySpace = pto::AddressSpaceAttr::get(ctx, *parsed); } if (failed(parser.parseGreater())) return mlir::Type(); - return mlir::pto::PtrType::get(ctx, elem); + return mlir::pto::PtrType::get(ctx, elem, memorySpace); } if (head == "pto.tensor_view") { @@ -395,6 +423,40 @@ void TensorViewType::print(::mlir::AsmPrinter &printer) const { printShapeAndElem(printer, getShape(), getElementType()); } +mlir::Type PtrType::parse(::mlir::AsmParser &parser) { + Type elementType; + if (failed(parser.parseLess()) || failed(parser.parseType(elementType))) + return {}; + + auto memorySpace = + pto::AddressSpaceAttr::get(parser.getContext(), pto::AddressSpace::GM); + if (succeeded(parser.parseOptionalComma())) { + StringRef memorySpaceKeyword; + if (failed(parser.parseKeyword(&memorySpaceKeyword))) + return {}; + auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); + if (!parsed) { + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr address space must be `gm` or `ub`"); + return {}; + } + memorySpace = pto::AddressSpaceAttr::get(parser.getContext(), *parsed); + } + + if (failed(parser.parseGreater())) + return {}; + return PtrType::get(parser.getContext(), elementType, memorySpace); +} + +void PtrType::print(::mlir::AsmPrinter &printer) const { + printer << "<" << getElementType(); + StringRef memorySpaceKeyword = + printPtrAddressSpaceKeyword(getMemorySpace().getAddressSpace()); + if (!memorySpaceKeyword.empty()) + printer << ", " << 
memorySpaceKeyword; + printer << ">"; +} + //===----------------------------------------------------------------------===// // pto.tdivs custom asm to support both: // pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) @@ -1182,6 +1244,43 @@ LogicalResult mlir::pto::AddPtrOp::verify() { return success(); } +LogicalResult mlir::pto::CastPtrOp::verify() { + Type inputType = getInput().getType(); + Type resultType = getResult().getType(); + + auto inputPtrType = dyn_cast(inputType); + auto resultPtrType = dyn_cast(resultType); + auto inputMemRefType = dyn_cast(inputType); + bool inputIsInteger = isa(inputType); + bool resultIsInteger = isa(resultType); + + if (!inputPtrType && !inputMemRefType && !inputIsInteger) + return emitOpError("input must be an integer, memref, or !pto.ptr<...>"); + if (!resultPtrType && !resultIsInteger) + return emitOpError("result must be an integer or !pto.ptr<...>"); + + if (inputIsInteger && resultIsInteger) + return emitOpError("integer-to-integer cast is not a ptr cast"); + + if (inputMemRefType && resultIsInteger) + return emitOpError("memref-to-integer cast is unsupported"); + + if (inputMemRefType && resultPtrType) { + auto memrefSpace = dyn_cast_or_null( + inputMemRefType.getMemorySpace()); + auto resultSpace = resultPtrType.getMemorySpace(); + if (memrefSpace && memrefSpace != resultSpace) + return emitOpError("memref-to-ptr cast must stay within the same PTO memory space"); + } + + if (inputPtrType && resultPtrType && + inputPtrType.getMemorySpace() != resultPtrType.getMemorySpace()) { + return emitOpError("ptr-to-ptr cast must stay within the same PTO memory space"); + } + + return success(); +} + @@ -1204,6 +1303,8 @@ void PTODialect::initialize() { AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); auto memRefType = dyn_cast(type); assert(memRefType && "input type must be a memref type"); auto scopeAttr = 
dyn_cast(memRefType.getMemorySpace()); @@ -1213,7 +1314,7 @@ AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { bool mlir::pto::isScalarPtrOrMemRef(Type type) { if (auto pty = dyn_cast(type)) - return true; + return static_cast(pty); if (auto memTy = dyn_cast(type)) return isGmAddressSpaceAttr(memTy.getMemorySpace()); return false; @@ -5039,14 +5140,19 @@ static LogicalResult verifyBufSyncOp(Operation *op, Attribute opTypeAttr, if (!opTypeAttr) return op->emitOpError("expects 'op_type' attribute"); - auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); - if (failed(opTypeOr)) { - auto diag = - op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); - diag << opTypeAttr; - return failure(); + pto::PIPE pipe = pto::PIPE::PIPE_UNASSIGNED; + if (auto pipeAttr = dyn_cast(opTypeAttr)) { + pipe = pipeAttr.getPipe(); + } else { + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) { + auto diag = op->emitOpError( + "expects 'op_type' to be pipe_event_type/sync_op_type/pipe, got "); + diag << opTypeAttr; + return failure(); + } + pipe = mapSyncOpTypeToPipe(*opTypeOr); } - pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); if (!isConcreteSyncPipe(pipe)) return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); @@ -5074,6 +5180,282 @@ LogicalResult RlsBufOp::verify() { return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), getModeAttr()); } + +static ParseResult parseQuotedPipeToken(OpAsmParser &parser, PipeAttr &attr) { + std::string pipeName; + auto loc = parser.getCurrentLocation(); + if (failed(parser.parseString(&pipeName))) + return failure(); + auto pipe = symbolizePIPE(pipeName); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << pipeName; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); +} + +static ParseResult parseQuotedEventToken(OpAsmParser &parser, EventAttr &attr) { + std::string eventName; + auto 
loc = parser.getCurrentLocation(); + if (failed(parser.parseString(&eventName))) + return failure(); + auto event = symbolizeEVENT(eventName); + if (!event) + return parser.emitError(loc) << "invalid event token: " << eventName; + attr = EventAttr::get(parser.getContext(), *event); + return success(); +} + +static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto pipe = symbolizePIPE(token); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << token; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + + if (succeeded(parser.parseOptionalLess())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseGreater()) + return failure(); + auto pipe = symbolizePIPE(keyword); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << keyword; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto pipeAttr = dyn_cast(parsed); + if (!pipeAttr) + return parser.emitError(loc, "expected pipe attribute"); + attr = pipeAttr; + return success(); +} + +static ParseResult parseLegacyOrAttrEvent(OpAsmParser &parser, EventAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto event = symbolizeEVENT(token); + if (!event) + return parser.emitError(loc) << "invalid event token: " << token; + attr = EventAttr::get(parser.getContext(), *event); + return success(); + } + + if (succeeded(parser.parseOptionalLess())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseGreater()) + return failure(); + auto event = symbolizeEVENT(keyword); + if (!event) + return parser.emitError(loc) << "invalid event token: " << keyword; + attr = 
EventAttr::get(parser.getContext(), *event); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto eventAttr = dyn_cast(parsed); + if (!eventAttr) + return parser.emitError(loc, "expected event attribute"); + attr = eventAttr; + return success(); +} + +static ParseResult parseI32LiteralAttr(OpAsmParser &parser, IntegerAttr &attr) { + auto loc = parser.getCurrentLocation(); + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) + return parser.emitError(loc, "expected 32-bit integer literal"); + attr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), value); + return success(); +} + +static void printLegacySyncTriplet(OpAsmPrinter &p, PipeAttr srcPipe, + PipeAttr dstPipe, EventAttr eventId, + ArrayRef attrs) { + p << "[\"" << stringifyPIPE(srcPipe.getPipe()) << "\", \"" + << stringifyPIPE(dstPipe.getPipe()) << "\", \"" + << stringifyEVENT(eventId.getEvent()) << "\"]"; + p.printOptionalAttrDict(attrs, {"src_pipe", "dst_pipe", "event_id"}); +} + +ParseResult SetFlagOp::parse(OpAsmParser &parser, OperationState &result) { + PipeAttr srcPipe; + PipeAttr dstPipe; + EventAttr eventId; + if (parser.parseLSquare() || parseLegacyOrAttrPipe(parser, srcPipe) || + parser.parseComma() || parseLegacyOrAttrPipe(parser, dstPipe) || + parser.parseComma() || parseLegacyOrAttrEvent(parser, eventId) || + parser.parseRSquare()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("src_pipe", srcPipe); + result.addAttribute("dst_pipe", dstPipe); + result.addAttribute("event_id", eventId); + return success(); +} + +void SetFlagOp::print(OpAsmPrinter &p) { + printLegacySyncTriplet(p, getSrcPipe(), getDstPipe(), getEventId(), + (*this)->getAttrs()); +} + +ParseResult WaitFlagOp::parse(OpAsmParser &parser, OperationState &result) { + PipeAttr 
srcPipe; + PipeAttr dstPipe; + EventAttr eventId; + if (parser.parseLSquare() || parseLegacyOrAttrPipe(parser, srcPipe) || + parser.parseComma() || parseLegacyOrAttrPipe(parser, dstPipe) || + parser.parseComma() || parseLegacyOrAttrEvent(parser, eventId) || + parser.parseRSquare()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("src_pipe", srcPipe); + result.addAttribute("dst_pipe", dstPipe); + result.addAttribute("event_id", eventId); + return success(); +} + +void WaitFlagOp::print(OpAsmPrinter &p) { + printLegacySyncTriplet(p, getSrcPipe(), getDstPipe(), getEventId(), + (*this)->getAttrs()); +} + +static ParseResult parseLegacyOrAttrOpType(OpAsmParser &parser, + Attribute &opTypeAttr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + if (auto pipe = symbolizePIPE(token)) { + opTypeAttr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + if (auto opType = symbolizeSyncOpType(token)) { + opTypeAttr = PipeEventTypeAttr::get(parser.getContext(), *opType); + return success(); + } + return parser.emitError(loc) << "invalid get_buf/rls_buf token: " << token; + } + + if (succeeded(parser.parseOptionalLSquare())) { + if (failed(parser.parseAttribute(opTypeAttr))) + return failure(); + return success(); + } + + if (failed(parser.parseAttribute(opTypeAttr))) + return failure(); + return success(); +} + +static ParseResult parseBufSyncOp(OpAsmParser &parser, OperationState &result) { + Attribute opTypeAttr; + IntegerAttr bufIdAttr; + IntegerAttr modeAttr; + + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + if (auto pipe = symbolizePIPE(token)) + opTypeAttr = PipeAttr::get(parser.getContext(), *pipe); + else if (auto opType = symbolizeSyncOpType(token)) + opTypeAttr = PipeEventTypeAttr::get(parser.getContext(), *opType); + else + return 
parser.emitError(loc) << "invalid get_buf/rls_buf token: " << token; + + if (parser.parseComma() || parseI32LiteralAttr(parser, bufIdAttr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parseI32LiteralAttr(parser, modeAttr)) + return failure(); + } else { + modeAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), 0); + } + } else if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(opTypeAttr) || parser.parseComma() || + parseI32LiteralAttr(parser, bufIdAttr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parseI32LiteralAttr(parser, modeAttr)) + return failure(); + } else { + modeAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), 0); + } + if (parser.parseRSquare()) + return failure(); + } else { + return parser.emitError(loc, "expected string pipe/op_type or '['"); + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("op_type", opTypeAttr); + result.addAttribute("buf_id", bufIdAttr); + result.addAttribute("mode", modeAttr); + return success(); +} + +static void printBufSyncOp(OpAsmPrinter &p, Attribute opTypeAttr, + IntegerAttr bufIdAttr, IntegerAttr modeAttr, + ArrayRef attrs) { + if (auto pipeAttr = dyn_cast(opTypeAttr)) { + p << " \"" << stringifyPIPE(pipeAttr.getPipe()) << "\", " + << bufIdAttr.getInt() << ", " << modeAttr.getInt(); + } else if (auto pipeEventType = dyn_cast(opTypeAttr)) { + auto pipe = mapSyncOpTypeToPipe(pipeEventType.getOpType()); + if (isConcreteSyncPipe(pipe)) { + p << " \"" << stringifyPIPE(pipe) << "\", " << bufIdAttr.getInt() + << ", " << modeAttr.getInt(); + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + } else if (auto syncOpType = dyn_cast(opTypeAttr)) { + auto pipe = mapSyncOpTypeToPipe(syncOpType.getOpType()); + if (isConcreteSyncPipe(pipe)) { + p << " \"" << stringifyPIPE(pipe) << "\", " << bufIdAttr.getInt() + 
<< ", " << modeAttr.getInt(); + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + p.printOptionalAttrDict(attrs, {"op_type", "buf_id", "mode"}); +} + +ParseResult GetBufOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBufSyncOp(parser, result); +} + +void GetBufOp::print(OpAsmPrinter &p) { + printBufSyncOp(p, getOpTypeAttr(), getBufIdAttr(), getModeAttr(), + (*this)->getAttrs()); +} + +ParseResult RlsBufOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBufSyncOp(parser, result); +} + +void RlsBufOp::print(OpAsmPrinter &p) { + printBufSyncOp(p, getOpTypeAttr(), getBufIdAttr(), getModeAttr(), + (*this)->getAttrs()); +} // ---- TOp ---- LogicalResult TGemvBiasOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp new file mode 100644 index 000000000..02351975d --- /dev/null +++ b/lib/PTO/IR/VPTO.cpp @@ -0,0 +1,2877 @@ +//===- VPTO.cpp - VPTO dialect -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +static llvm::cl::opt disableVPTOAlignChainVerification( + "vpto-disable-align-chain-verification", + llvm::cl::desc("Disable !pto.align linear-chain verifier checks"), + llvm::cl::init(false), llvm::cl::Hidden); + +static std::string formatVRegType(int64_t elementCount, Type elementType) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.vreg<" << elementCount << "x" << elementType << ">"; + return storage; +} + +static std::string formatMaskType(StringRef granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.mask<" << granularity << ">"; + return storage; +} + +static LogicalResult verifyVRegTypeLike(Operation *op, Type type, + StringRef roleDescription) { + auto vecType = dyn_cast(type); + if (!vecType) + return op->emitOpError() << roleDescription << " must be !pto.vreg<...>"; + + return VRegType::verify( + [&]() { return op->emitOpError() << roleDescription << " "; }, + vecType.getElementCount(), vecType.getElementType()); +} + +static LogicalResult verifyMaskTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription 
<< " must be !pto.mask<...>"; + return success(); +} + +static LogicalResult verifyMaskTypeWithGranularityLike(Operation *op, Type type, + StringRef roleDescription, + StringRef granularity) { + auto maskType = dyn_cast(type); + if (!maskType) + return op->emitOpError() << roleDescription << " must be !pto.mask<...>"; + if (maskType.getGranularity() != granularity) { + return op->emitOpError() + << roleDescription << " must be " << formatMaskType(granularity); + } + return success(); +} + +static LogicalResult verifyEnclosingLoopLike(Operation *op, + StringRef opNameForDiag) { + if (!op->getParentOfType()) { + return op->emitOpError() + << "requires enclosing loop structure for " << opNameForDiag + << " lowering"; + } + return success(); +} + +static LogicalResult verifyNotNestedInVecScope(Operation *op, + StringRef opNameForDiag) { + if (op->getParentOfType() || + op->getParentOfType()) { + return op->emitOpError() + << "must not be nested under pto.vecscope/pto.strict_vecscope; " + << opNameForDiag << " is a UB helper op rather than a vecscope op"; + } + return success(); +} + +static LogicalResult verifyNestedInVecScope(Operation *op, + StringRef opNameForDiag) { + if (op->getParentOfType() || op->getParentOfType()) + return success(); + return op->emitOpError() + << "must be nested under pto.vecscope/pto.strict_vecscope; " + << opNameForDiag << " is part of the vecscope control sequence"; +} + +static LogicalResult verifyAlignTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription << " must be !pto.align"; + return success(); +} + +static bool isSupportedVdupPosition(std::optional position) { + return !position || *position == "LOWEST" || *position == "HIGHEST"; +} + +static std::optional getVdupMaskGranularity(Type elementType) { + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return StringRef("b8"); + case 16: + return StringRef("b16"); + case 
32: + return StringRef("b32"); + default: + return std::nullopt; + } + } + if (elementType.isF16() || elementType.isBF16()) + return StringRef("b16"); + if (elementType.isF32()) + return StringRef("b32"); + return std::nullopt; +} + +static bool isSupportedVtrcRoundMode(StringRef mode) { + return mode == "R" || mode == "A" || mode == "F" || mode == "C" || + mode == "Z"; +} + +static bool isStoreAlignProducer(Operation *op) { + return isa(op); +} + +static bool isStoreAlignSink(Operation *op) { + return isa(op); +} + +static bool isLoadAlignProducer(Operation *op) { + return isa(op); +} + +static bool isValueOwnedByRegion(Value value, Region *region) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getParentRegion() == region; + if (Operation *def = value.getDefiningOp()) + return def->getParentRegion() == region; + return false; +} + +static FailureOr resolveStoreAlignRoot(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (true) { + if (!visited.insert(current.getAsOpaquePointer()).second) { + return failure(); + } + + if (auto blockArg = dyn_cast(current)) { + auto *owner = blockArg.getOwner(); + auto forOp = dyn_cast(owner->getParentOp()); + if (!forOp) + return failure(); + unsigned argNumber = blockArg.getArgNumber(); + unsigned ivCount = forOp.getNumInductionVars(); + if (argNumber < ivCount) + return failure(); + unsigned iterIdx = argNumber - ivCount; + if (iterIdx >= forOp.getInitArgs().size()) + return failure(); + current = forOp.getInitArgs()[iterIdx]; + continue; + } + + if (Operation *def = current.getDefiningOp()) { + if (isStoreAlignProducer(def)) + return current; + if (auto forOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result) + return failure(); + unsigned resultIdx = result.getResultNumber(); + if (resultIdx >= forOp.getYieldedValues().size()) + return failure(); + current = forOp.getYieldedValues()[resultIdx]; + continue; + } + } + + return failure(); + } +} + +static 
LogicalResult verifyStoreAlignLoopThreading(Value align, Operation *user, + StringRef roleDescription) { + Operation *cursor = user; + while (auto forOp = cursor->getParentOfType()) { + Region *body = &forOp.getRegion(); + if (!isValueOwnedByRegion(align, body)) { + return user->emitOpError() + << roleDescription + << " must be threaded through scf.for iter_args when used inside a " + "loop"; + } + cursor = forOp; + } + return success(); +} + +static LogicalResult verifyStoreAlignLinearUses(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (visited.insert(current.getAsOpaquePointer()).second) { + SmallVector nextValues; + SmallVector terminalUsers; + + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (isStoreAlignSink(owner)) { + terminalUsers.push_back(owner); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + continue; + } + if (auto forOp = dyn_cast(owner)) { + unsigned firstInitArg = forOp.getNumControlOperands(); + if (use.getOperandNumber() < firstInitArg) + return user->emitOpError() + << "found unexpected scf.for control operand use for !pto.align"; + unsigned iterIdx = use.getOperandNumber() - firstInitArg; + if (iterIdx >= forOp.getRegionIterArgs().size()) + return user->emitOpError() + << "found invalid scf.for iter_args use for !pto.align"; + nextValues.push_back(forOp.getRegionIterArgs()[iterIdx]); + continue; + } + if (auto yieldOp = dyn_cast(owner)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return user->emitOpError() + << "found !pto.align yielded from non-scf.for loop"; + unsigned resultIdx = use.getOperandNumber(); + if (resultIdx >= forOp.getNumResults()) + return user->emitOpError() + 
<< "found invalid scf.yield result mapping for !pto.align"; + nextValues.push_back(forOp.getResult(resultIdx)); + continue; + } + return user->emitOpError() + << "found unsupported !pto.align consumer " << owner->getName(); + } + + if (nextValues.size() + terminalUsers.size() > 1) { + return user->emitOpError() + << "!pto.align value must form a single linear store-state chain"; + } + if (nextValues.empty()) + return success(); + current = nextValues.front(); + } + + return success(); +} + +static LogicalResult verifyStoreAlignChain(Value align, Operation *user, + StringRef roleDescription) { + if (disableVPTOAlignChainVerification) + return success(); + + if (failed(verifyAlignTypeLike(user, align.getType(), roleDescription))) + return failure(); + + if (failed(verifyStoreAlignLoopThreading(align, user, roleDescription))) + return failure(); + + FailureOr root = resolveStoreAlignRoot(align, user); + if (failed(root)) { + if (Operation *def = align.getDefiningOp()) { + if (!isa(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op, got " + << def->getName(); + } + } + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op"; + } + + Operation *def = (*root).getDefiningOp(); + if (!isStoreAlignProducer(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op, got " + << def->getName(); + } + + return verifyStoreAlignLinearUses(*root, user); +} + +static FailureOr resolveLoadAlignRoot(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (true) { + if (!visited.insert(current.getAsOpaquePointer()).second) + return failure(); + + if (auto blockArg = dyn_cast(current)) { + auto *owner = blockArg.getOwner(); + auto forOp = dyn_cast(owner->getParentOp()); + if (!forOp) + return failure(); + unsigned argNumber = 
blockArg.getArgNumber(); + unsigned ivCount = forOp.getNumInductionVars(); + if (argNumber < ivCount) + return failure(); + unsigned iterIdx = argNumber - ivCount; + if (iterIdx >= forOp.getInitArgs().size()) + return failure(); + current = forOp.getInitArgs()[iterIdx]; + continue; + } + + if (Operation *def = current.getDefiningOp()) { + if (isLoadAlignProducer(def)) + return current; + if (auto forOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result) + return failure(); + unsigned resultIdx = result.getResultNumber(); + if (resultIdx >= forOp.getYieldedValues().size()) + return failure(); + current = forOp.getYieldedValues()[resultIdx]; + continue; + } + } + + return failure(); + } +} + +static LogicalResult verifyLoadAlignLoopThreading(Value align, Operation *user, + StringRef roleDescription) { + Operation *cursor = user; + while (auto forOp = cursor->getParentOfType()) { + Region *body = &forOp.getRegion(); + if (!isValueOwnedByRegion(align, body)) { + return user->emitOpError() + << roleDescription + << " must be threaded through scf.for iter_args when used inside a " + "loop"; + } + cursor = forOp; + } + return success(); +} + +static LogicalResult verifyLoadAlignLinearUses(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (visited.insert(current.getAsOpaquePointer()).second) { + SmallVector nextValues; + + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getUpdatedAlign()); + continue; + } + if (auto forOp = dyn_cast(owner)) { + unsigned firstInitArg = forOp.getNumControlOperands(); + if (use.getOperandNumber() < firstInitArg) { + return user->emitOpError() + << "found unexpected scf.for control operand use for !pto.align"; + } + unsigned iterIdx = use.getOperandNumber() - firstInitArg; + if (iterIdx >= forOp.getRegionIterArgs().size()) { + return user->emitOpError() + << "found invalid scf.for 
iter_args use for !pto.align"; + } + nextValues.push_back(forOp.getRegionIterArgs()[iterIdx]); + continue; + } + if (auto yieldOp = dyn_cast(owner)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) { + return user->emitOpError() + << "found !pto.align yielded from non-scf.for loop"; + } + unsigned resultIdx = use.getOperandNumber(); + if (resultIdx >= forOp.getNumResults()) { + return user->emitOpError() + << "found invalid scf.yield result mapping for !pto.align"; + } + nextValues.push_back(forOp.getResult(resultIdx)); + continue; + } + return user->emitOpError() + << "found unsupported !pto.align consumer " << owner->getName(); + } + + if (nextValues.size() > 1) { + return user->emitOpError() + << "!pto.align value must form a single linear load-state chain"; + } + if (nextValues.empty()) + return success(); + current = nextValues.front(); + } + + return success(); +} + +static LogicalResult verifyLoadAlignChain(Value align, Operation *user, + StringRef roleDescription) { + if (disableVPTOAlignChainVerification) + return success(); + + if (failed(verifyAlignTypeLike(user, align.getType(), roleDescription))) + return failure(); + + if (failed(verifyLoadAlignLoopThreading(align, user, roleDescription))) + return failure(); + + FailureOr root = resolveLoadAlignRoot(align, user); + if (failed(root)) { + if (Operation *def = align.getDefiningOp()) { + if (!isa(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op, got " + << def->getName(); + } + } + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op"; + } + + Operation *def = (*root).getDefiningOp(); + if (!isLoadAlignProducer(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op, got " + << def->getName(); + } + + return verifyLoadAlignLinearUses(*root, user); +} + +static bool 
isSupportedPredicatePattern(StringRef pattern) { + return pattern == "PAT_ALL" || pattern == "PAT_VL1" || pattern == "PAT_VL2" || + pattern == "PAT_VL3" || pattern == "PAT_VL4" || pattern == "PAT_VL8" || + pattern == "PAT_VL16" || pattern == "PAT_VL32" || + pattern == "PAT_VL64" || pattern == "PAT_VL128" || + pattern == "PAT_M3" || pattern == "PAT_M4" || pattern == "PAT_H" || + pattern == "PAT_Q" || pattern == "PAT_ALLF"; +} + +static bool isSupportedPredicateLoadDist(StringRef dist) { + return dist == "NORM" || dist == "US" || dist == "DS"; +} + +static bool isSupportedPredicateStoreDist(StringRef dist) { + return dist == "NORM" || dist == "PK"; +} + +static bool isSupportedStrideToken(StringRef stride) { + return stride == "STRIDE_S3_B16" || stride == "STRIDE_S4_B64" || + stride == "STRIDE_S8_B32" || stride == "STRIDE_S2_B64" || + stride == "STRIDE_VSST_S8_B16"; +} + +static bool isSupportedPartToken(StringRef part) { + return part == "LOWER" || part == "HIGHER"; +} + +static bool isSupportedSprToken(StringRef spr) { return spr == "AR"; } + +static std::optional normalizeRoundModeToken(StringRef token) { + if (token == "R" || token == "ROUND_R") + return StringRef("R"); + if (token == "A" || token == "ROUND_A") + return StringRef("A"); + if (token == "F" || token == "ROUND_F") + return StringRef("F"); + if (token == "C" || token == "ROUND_C") + return StringRef("C"); + if (token == "Z" || token == "ROUND_Z") + return StringRef("Z"); + if (token == "O" || token == "ROUND_O") + return StringRef("O"); + return std::nullopt; +} + +static std::optional normalizeSaturationToken(StringRef token) { + if (token == "SAT" || token == "RS_ENABLE") + return StringRef("SAT"); + if (token == "NOSAT" || token == "RS_DISABLE") + return StringRef("NOSAT"); + return std::nullopt; +} + +static std::optional normalizeEvenOddPartToken(StringRef token) { + if (token == "EVEN" || token == "PART_EVEN") + return StringRef("EVEN"); + if (token == "ODD" || token == "PART_ODD") + return 
StringRef("ODD"); + return std::nullopt; +} + +namespace { + +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, +}; + +struct VcvtContract { + bool requiresRnd; + bool requiresSat; + bool requiresPart; +}; + +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; +} + +static std::optional getVcvtElemBitWidth(VcvtElemKind kind) { + switch (kind) { + case VcvtElemKind::F16: + case VcvtElemKind::BF16: + case VcvtElemKind::S16: + case VcvtElemKind::U16: + return 16; + case VcvtElemKind::F32: + case VcvtElemKind::S32: + case VcvtElemKind::U32: + return 32; + case VcvtElemKind::S8: + case VcvtElemKind::U8: + return 8; + case VcvtElemKind::S64: + return 64; + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::BF16: + case VcvtElemKind::S16: + case VcvtElemKind::S64: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + case 
VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + case VcvtElemKind::S8: + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::U16: + case VcvtElemKind::U32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::S16: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::U32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/false}; + case VcvtElemKind::F32: + case VcvtElemKind::U32: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + 
/*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + case VcvtElemKind::U16: + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/false}; + case VcvtElemKind::U8: + case VcvtElemKind::U16: + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::S64: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +} // namespace + +static std::optional getDistElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (type.isF16() || type.isBF16()) + return 16; + if (type.isF32()) + return 32; + if (type.isF64()) + return 64; + return std::nullopt; +} + +static bool matchesWidthFamily(StringRef dist, unsigned width, + ArrayRef allowedWidths) { + return llvm::is_contained(allowedWidths, width); +} + +static bool isSupportedVldx2DistToken(StringRef dist) { + return dist == "BDINTLV" || dist == "DINTLV"; +} + +static bool isSupportedVldsDistToken(StringRef dist) { + return dist == "NORM" || dist == "BRC" || dist == "US" || dist == "DS" || + dist == "UNPK" || dist == "BRC_BLK" || dist == "E2B" || + dist == "UNPK4" || dist == "SPLT4CHN" || dist == 
"SPLT2CHN"; +} + +static bool isSupportedVstsDistToken(StringRef dist) { + return dist == "NORM" || dist == "1PT" || dist == "PK" || + dist == "PK4" || dist == "MRG4CHN" || dist == "MRG2CHN"; +} + +static bool isSupportedVstsx2DistToken(StringRef dist) { + return dist == "INTLV"; +} + +static LogicalResult verifyVldsDistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError("requires load element type with a concrete bit width"); + + if (dist == "NORM" || dist == "BRC_BLK") + return success(); + if (dist == "BRC") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist BRC only supports 8/16/32-bit elements"); + if (dist == "US") + return matchesWidthFamily(dist, *width, {8, 16}) + ? success() + : op->emitOpError("dist US only supports 8/16-bit elements"); + if (dist == "DS") + return matchesWidthFamily(dist, *width, {8, 16}) + ? success() + : op->emitOpError("dist DS only supports 8/16-bit elements"); + if (dist == "UNPK") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist UNPK only supports 8/16/32-bit elements"); + if (dist == "E2B") + return matchesWidthFamily(dist, *width, {16, 32}) + ? success() + : op->emitOpError("dist E2B only supports 16/32-bit elements"); + if (dist == "UNPK4") + return *width == 8 + ? success() + : op->emitOpError("dist UNPK4 only supports 8-bit elements"); + if (dist == "SPLT4CHN") + return *width == 8 + ? success() + : op->emitOpError("dist SPLT4CHN only supports 8-bit elements"); + if (dist == "SPLT2CHN") + return matchesWidthFamily(dist, *width, {8, 16}) + ? 
success() + : op->emitOpError("dist SPLT2CHN only supports 8/16-bit elements"); + + return op->emitOpError("requires a supported load distribution token"); +} + +static LogicalResult verifyVldsx2DistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError( + "requires x2 load element type with a concrete bit width"); + if (dist == "BDINTLV") + return success(); + if (dist == "DINTLV") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist DINTLV only supports 8/16/32-bit elements"); + return op->emitOpError("requires a supported x2 load distribution token"); +} + +static LogicalResult verifyVstsDistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError( + "requires store element type with a concrete bit width"); + + if (dist == "NORM") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist NORM only supports 8/16/32-bit elements"); + if (dist == "1PT") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist 1PT only supports 8/16/32-bit elements"); + if (dist == "PK") + return matchesWidthFamily(dist, *width, {16, 32, 64}) + ? success() + : op->emitOpError("dist PK only supports 16/32/64-bit elements"); + if (dist == "PK4") + return *width == 32 + ? success() + : op->emitOpError("dist PK4 only supports 32-bit elements"); + if (dist == "MRG4CHN") + return *width == 8 + ? success() + : op->emitOpError("dist MRG4CHN only supports 8-bit elements"); + if (dist == "MRG2CHN") + return matchesWidthFamily(dist, *width, {8, 16}) + ? 
success() + : op->emitOpError("dist MRG2CHN only supports 8/16-bit elements"); + + return op->emitOpError("requires a supported store distribution token"); +} + +static LogicalResult verifyVstsx2DistWidth(Operation *op, StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return op->emitOpError( + "requires x2 store element type with a concrete bit width"); + if (dist == "INTLV") + return matchesWidthFamily(dist, *width, {8, 16, 32}) + ? success() + : op->emitOpError("dist INTLV only supports 8/16/32-bit elements"); + return op->emitOpError("requires a supported x2 store distribution token"); +} + +static bool isSupportedPostMode(StringRef mode) { + return mode == "NO_POST_UPDATE" || mode == "POST_UPDATE"; +} + +static std::optional getOptionalPostModeAttr(Operation *op) { + if (auto mode = op->getAttrOfType("mode")) + return mode.getValue(); + return std::nullopt; +} + +static unsigned getIntOrFloatBitWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + return 0; +} + +static bool isIntegerOrFloatLike(Type type) { + return isa(type) || isa(type); +} + +static std::optional getVRegStorageBitWidth(Type type) { + auto vecType = dyn_cast(type); + if (!vecType) + return std::nullopt; + unsigned elemWidth = getIntOrFloatBitWidth(vecType.getElementType()); + if (!elemWidth) + return std::nullopt; + return vecType.getElementCount() * static_cast(elemWidth); +} + +static LogicalResult verifyIntegerVRegTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (failed(verifyVRegTypeLike(op, type, roleDescription))) + return failure(); + auto vecType = cast(type); + if (!isa(vecType.getElementType())) + return op->emitOpError() + << roleDescription << " must use integer vector element type"; + return success(); +} + +enum class MemoryRole { + Unknown, + GM, + UB, + Other, +}; + +static MemoryRole 
classifyMemoryRole(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) { + if (auto ptrType = dyn_cast(type)) { + switch (ptrType.getMemorySpace().getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return MemoryRole::GM; + case pto::AddressSpace::VEC: + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + return MemoryRole::Other; + } + + Attribute memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return MemoryRole::Unknown; + + if (auto addrSpace = dyn_cast(memorySpace)) { + switch (addrSpace.getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return MemoryRole::GM; + case pto::AddressSpace::VEC: + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + + if (auto intAttr = dyn_cast(memorySpace)) { + switch (intAttr.getInt()) { + case static_cast(pto::AddressSpace::GM): + case static_cast(pto::AddressSpace::Zero): + return MemoryRole::GM; + case static_cast(pto::AddressSpace::VEC): + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + + return MemoryRole::Other; +} + +static bool isBufferLike(Type type) { + return isa(type); +} + +static int64_t getPtrElementByteSize(Type type) { + auto ptrType = dyn_cast(type); + if (!ptrType) + return 0; + + Type elementType = ptrType.getElementType(); + if (auto floatType = dyn_cast(elementType)) + return (floatType.getWidth() + 7) / 8; + if (auto intType = dyn_cast(elementType)) + return (intType.getWidth() + 7) / 8; + return 0; +} + +template +static LogicalResult verifyCopyGmToUbufOp(CopyOp op, bool expectSourceGM) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto destinationType = dyn_cast(op.getDestination().getType()); + if (!sourceType || !destinationType) + return op.emitOpError("requires typed !pto.ptr source and destination"); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + MemoryRole destinationRole = 
classifyMemoryRole(op.getDestination().getType()); + bool directionMatches = true; + if (expectSourceGM) { + directionMatches &= sourceRole != MemoryRole::UB; + directionMatches &= destinationRole != MemoryRole::GM; + } else { + directionMatches &= sourceRole != MemoryRole::GM; + directionMatches &= destinationRole != MemoryRole::UB; + } + + if (!directionMatches) { + return op.emitOpError() + << "requires " + << (expectSourceGM ? "GM source and UB destination" + : "UB source and GM destination"); + } + + int64_t sourceElemBytes = getPtrElementByteSize(sourceType); + int64_t destinationElemBytes = getPtrElementByteSize(destinationType); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +template +static LogicalResult verifyCopyUbufToGmOp(CopyOp op, bool expectSourceGM) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto destinationType = dyn_cast(op.getDestination().getType()); + if (!sourceType || !destinationType) + return op.emitOpError("requires typed !pto.ptr source and destination"); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + bool directionMatches = true; + if (expectSourceGM) { + directionMatches &= sourceRole != MemoryRole::UB; + directionMatches &= destinationRole != MemoryRole::GM; + } else { + directionMatches &= sourceRole != MemoryRole::GM; + directionMatches &= destinationRole != MemoryRole::UB; + } + + if (!directionMatches) { + return op.emitOpError() + << "requires " + << (expectSourceGM ? 
"GM source and UB destination" + : "UB source and GM destination"); + } + + int64_t sourceElemBytes = getPtrElementByteSize(sourceType); + int64_t destinationElemBytes = getPtrElementByteSize(destinationType); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +Type VRegType::parse(AsmParser &parser) { + SmallVector shape; + Type elementType; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseType(elementType)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), shape.front(), + elementType); +} + +void VRegType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x"; + printer.printType(getElementType()); + printer << ">"; +} + +LogicalResult VRegType::verify(function_ref emitError, + int64_t elementCount, Type elementType) { + if (elementCount <= 0) + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected a positive element count"; + + auto intOrFloat = mlir::dyn_cast(elementType); + unsigned elementBitWidth = 0; + if (intOrFloat) { + elementBitWidth = intOrFloat.getWidth(); + } else if (auto floatType = mlir::dyn_cast(elementType)) { + elementBitWidth = floatType.getWidth(); + } else { + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected an integer or floating-point element type"; + } + + if (elementCount * static_cast(elementBitWidth) != 2048) + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected exactly 256 bytes"; + + return success(); +} + 
+LogicalResult VecScopeOp::verify() { + Region &bodyRegion = getBody(); + if (bodyRegion.empty()) + return emitOpError("expects a non-empty body region"); + + Block &body = bodyRegion.front(); + if (body.getNumArguments() != 0) + return emitOpError() << "expects body block to have no arguments, got " + << body.getNumArguments(); + + return success(); +} + +LogicalResult StrictVecScopeOp::verify() { + Region &bodyRegion = getBody(); + if (bodyRegion.empty()) + return emitOpError("expects a non-empty body region"); + + Block &body = bodyRegion.front(); + if (body.getNumArguments() != getCaptures().size()) + return emitOpError() << "expects body block to have " + << getCaptures().size() + << " arguments to match explicit captures, got " + << body.getNumArguments(); + + for (auto [idx, pair] : + llvm::enumerate(llvm::zip(body.getArguments(), getCaptures()))) { + BlockArgument blockArg = std::get<0>(pair); + Value capture = std::get<1>(pair); + if (blockArg.getType() != capture.getType()) + return emitOpError() << "expects body block argument #" << idx + << " to have type " << capture.getType() + << ", got " << blockArg.getType(); + } + return success(); +} + +bool MaskType::isSupportedGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || + granularity == "b32"; +} + +Type MaskType::parse(AsmParser &parser) { + auto loc = parser.getCurrentLocation(); + StringRef granularity; + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&granularity)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), granularity); +} + +void MaskType::print(AsmPrinter &printer) const { + printer << "<" << getGranularity() << ">"; +} + +LogicalResult +MaskType::verify(function_ref emitError, + StringRef granularity) { + if (!isSupportedGranularity(granularity)) + return emitError() << "'" << formatMaskType(granularity) + << "' expected granularity to be one of b8, b16, b32"; + return success(); 
+} + +void CopyGmToUbufOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyGmToUbufOp::verify() { + return verifyCopyGmToUbufOp(*this, true); +} + +LogicalResult VbrOp::verify() { + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + + auto resultVecType = cast(getResult().getType()); + Type elementType = getValue().getType(); + if (isa(elementType)) + return emitOpError("value must be a scalar matching the result element type"); + if (elementType != resultVecType.getElementType()) + return emitOpError("value type must match result element type"); + return success(); +} + +LogicalResult VcaddOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VcmaxOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VcminOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VciOp::verify() { + auto resultType = dyn_cast(getResult().getType()); + if (!resultType) + return emitOpError("result must be 
!pto.vreg<...>"); + if (!isa(resultType.getElementType())) + return emitOpError("result element type must be integer"); + auto indexType = dyn_cast(getIndex().getType()); + if (!indexType) + return emitOpError("index must be an integer scalar"); + if (indexType != resultType.getElementType()) + return emitOpError("index type must match result element type"); + return success(); +} + +void Vgather2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vgather2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + if (!isa(offsetsType.getElementType())) + return emitOpError("offset vector must use integer element type"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + if (!getActiveLanes().getType().isIndex()) + return emitOpError("active_lanes must be index"); + return success(); +} + +LogicalResult CopyUbufToUbufOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed source and destination"); + return success(); +} + +void VgatherbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + 
+LogicalResult VgatherbOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + if (failed(verifyMaskTypeWithGranularityLike(getOperation(), getMask().getType(), + "mask type", "b32"))) + return failure(); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + return success(); +} + +void Vgather2BcOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vgather2BcOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + 
return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + return success(); +} + +LogicalResult VbitsortOp::verify() { + if (!isBufferLike(getDestination().getType()) || !isBufferLike(getSource().getType()) || + !isBufferLike(getIndices().getType())) + return emitOpError("requires pointer-like destination/source/indices"); + if (classifyMemoryRole(getDestination().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getIndices().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed destination/source/indices"); + if (!getRepeatTimes().getType().isIndex()) + return emitOpError("repeat_times must be index"); + if (failed(verifyNotNestedInVecScope(*this, "pto.vbitsort"))) + return failure(); + return success(); +} + +void VbitsortOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getIndicesMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vmrgsort4Op::verify() { + if (!isBufferLike(getDestination().getType()) || !isBufferLike(getSource0().getType()) || + !isBufferLike(getSource1().getType()) || !isBufferLike(getSource2().getType()) || + !isBufferLike(getSource3().getType())) + return emitOpError("requires pointer-like destination and sources"); + if (classifyMemoryRole(getDestination().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource0().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource1().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource2().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource3().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed destination and sources"); 
+ return success(); +} + +LogicalResult VmaxOp::verify() { + if (failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError("lhs, rhs, and result must have the same vector type"); + return success(); +} + +LogicalResult VminOp::verify() { + if (failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError("lhs, rhs, and result must have the same vector type"); + return success(); +} + +void VldsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +template +static LogicalResult verifyVldsCommon(LoadOp op) { + if (!isBufferLike(op.getSource().getType())) + return op.emitOpError("requires a pointer-like source"); + + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + if (sourceRole == MemoryRole::GM) + return op.emitOpError("requires a UB-backed source"); + + if (op.getDistAttr()) { + StringRef dist = *op.getDist(); + if (!isSupportedVldsDistToken(dist)) + return op.emitOpError( + "supports only NORM, BRC, US, DS, UNPK, BRC_BLK, E2B, UNPK4, " + "and SPLT2CHN/SPLT4CHN load distributions"); + if (failed(verifyVldsDistWidth( + op.getOperation(), dist, + cast(op.getResult().getType()).getElementType()))) + return failure(); + } + + return success(); +} + +LogicalResult VldsOp::verify() { + if (failed(verifyVldsCommon(*this))) + return 
failure(); + if (std::optional mode = getOptionalPostModeAttr(getOperation()); + mode && !isSupportedPostMode(*mode)) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} +void VldsPostOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldsPostOp::verify() { + if (failed(verifyVldsCommon(*this))) + return failure(); + if (getUpdatedSource().getType() != getSource().getType()) + return emitOpError("requires updated source result to match source type"); + return success(); +} + +void VldasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldasOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyAlignTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + +LogicalResult InitAlignOp::verify() { + return verifyAlignTypeLike(*this, getResult().getType(), "result type"); +} + +LogicalResult SprclrOp::verify() { + if (!isSupportedSprToken(getSpr())) + return emitOpError("requires spr to be \"AR\""); + if (failed(verifyNestedInVecScope(*this, "pto.sprclr"))) + return failure(); + return success(); +} + +void VldusOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldusOp::verify() { + if (failed(verifyLoadAlignChain(getAlign(), *this, "align type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type")) || + failed(verifyAlignTypeLike(*this, getUpdatedAlign().getType(), + "updated align type"))) + return failure(); + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a 
pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + +void UvldOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult UvldOp::verify() { + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a buffer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + auto sourceMemRef = dyn_cast(getSource().getType()); + if (!sourceMemRef) + return success(); + + Type sourceElementType = sourceMemRef.getElementType(); + Type vectorElementType = cast(getResult().getType()).getElementType(); + if (sourceElementType != vectorElementType) + return emitOpError( + "requires source element type to match vector element type"); + return success(); +} + +LogicalResult VdupOp::verify() { + auto resultType = dyn_cast(getResult().getType()); + if (!resultType) + return emitOpError("result must be !pto.vreg<...>"); + + std::optional granularity = + getVdupMaskGranularity(resultType.getElementType()); + if (!granularity) + return emitOpError("result element type must use b8, b16, or b32 mask granularity"); + if (failed(verifyMaskTypeWithGranularityLike( + getOperation(), getMask().getType(), "mask type", *granularity))) + return failure(); + + if (!isSupportedVdupPosition(getPosition())) + return emitOpError("position must be LOWEST or HIGHEST"); + + Type inputType = getInput().getType(); + if (auto inputVecType = dyn_cast(inputType)) { + if (inputVecType != resultType) + return emitOpError("vector input must match result vector type"); + return success(); + } + + if (getPosition()) + return emitOpError("position is only supported for vector input"); + + if (inputType != 
resultType.getElementType()) + return emitOpError("scalar input must match result element type"); + + return success(); +} + +LogicalResult PsetB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b8"))) + return failure(); + + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PsetB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b16"))) + return failure(); + + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PsetB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b32"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b8"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b16"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b32"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +template +static LogicalResult 
verifyPredicateLaneCountOp(PltOp op, + StringRef granularity) { + if (failed(verifyMaskTypeWithGranularityLike(op, op.getMask().getType(), + "mask type", granularity))) + return failure(); + Type scalarType = op.getScalar().getType(); + auto scalarIntType = dyn_cast(scalarType); + if (!scalarIntType || scalarIntType.getWidth() != 32) + return op.emitOpError("requires scalar to be i32"); + if (op.getScalarOut().getType() != scalarType) + return op.emitOpError("requires scalar_out to match scalar type"); + return success(); +} + +LogicalResult PltB8Op::verify() { return verifyPredicateLaneCountOp(*this, "b8"); } +LogicalResult PltB16Op::verify() { + return verifyPredicateLaneCountOp(*this, "b16"); +} +LogicalResult PltB32Op::verify() { + return verifyPredicateLaneCountOp(*this, "b32"); +} + +LogicalResult PpackOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getPart() != "LOWER") + return emitOpError("currently supports only LOWER part"); + return success(); +} + +LogicalResult PunpackOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getPart() != "LOWER") + return emitOpError("currently supports only LOWER part"); + return success(); +} + +LogicalResult PnotOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PselOp::verify() { + if (failed(verifyMaskTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyMaskTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask 
type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyBinaryMaskOp(BinaryMaskOp op) { + if (failed(verifyMaskTypeLike(op, op.getSrc0().getType(), "src0 type")) || + failed(verifyMaskTypeLike(op, op.getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PandOp::verify() { return verifyBinaryMaskOp(*this); } +LogicalResult PorOp::verify() { return verifyBinaryMaskOp(*this); } +LogicalResult PxorOp::verify() { return verifyBinaryMaskOp(*this); } + +void PldsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PldsOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + if (failed(verifyEnclosingLoopLike(*this, "pto.plds"))) + return failure(); + return success(); +} + +void PldiOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PldiOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if 
(classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!matchPattern(getOffset(), m_Constant())) + return emitOpError("requires offset to be a constant index immediate"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + if (failed(verifyEnclosingLoopLike(*this, "pto.pldi"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyElementwiseVecScalarOpLike(OpTy op) { + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return op.emitOpError("input and result must be !pto.vreg<...>"); + if (inputType != resultType) + return op.emitOpError("input and result vector types must match"); + + Type elemType = inputType.getElementType(); + Type scalarType = op.getScalar().getType(); + if (scalarType == elemType) + return success(); + + auto elemInt = dyn_cast(elemType); + auto scalarInt = dyn_cast(scalarType); + if (!elemInt || !scalarInt || elemInt.getWidth() != scalarInt.getWidth()) + return op.emitOpError("scalar type must match vector element type"); + + if (elemInt.isSigned() && (scalarInt.isSigned() || scalarInt.isSignless())) + return success(); + if (elemInt.isUnsigned() && + (scalarInt.isUnsigned() || scalarInt.isSignless())) + return success(); + if (elemInt.isSignless() && scalarInt.isSignless()) + return success(); + + return op.emitOpError( + "integer scalar type must match vector element width and use matching signedness or signless i"); +} + +template +static LogicalResult verifyVecScalarOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + return success(); +} + +template +static LogicalResult verifyVecScalarMaskedOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), 
"mask type"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyCarryVecOp(CarryOp op) { + if (failed(verifyIntegerVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyIntegerVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyIntegerVRegTypeLike(op, op.getResult().getType(), + "result type")) || + failed(verifyMaskTypeLike(op, op.getCarry().getType(), "carry type"))) + return failure(); + + auto lhsType = cast(op.getLhs().getType()); + auto rhsType = cast(op.getRhs().getType()); + auto resultType = cast(op.getResult().getType()); + auto lhsElemType = cast(lhsType.getElementType()); + if (lhsType != rhsType || lhsType != resultType) + return op.emitOpError("requires lhs, rhs, and result to have matching vector types"); + if (lhsElemType.getWidth() != 32) + return op.emitOpError("currently requires 32-bit integer vector elements"); + return success(); +} + +template +static LogicalResult verifyCarryVecOpWithInput(CarryWithInputOp op) { + if (failed(verifyCarryVecOp(op)) || + failed(verifyMaskTypeLike(op, op.getCarryIn().getType(), + "carry_in type"))) + return failure(); + return success(); +} + +LogicalResult VmulsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VaddsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VmaxsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VminsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VlreluOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VshlsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if 
(inputType != resultType) + return emitOpError("input and result vector types must match"); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector and integer scalar"); + auto scalarType = dyn_cast(getScalar().getType()); + if (!scalarType || !scalarType.isSignlessInteger(16)) + return emitOpError("requires signless i16 scalar"); + return success(); +} +LogicalResult VshrsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("input and result vector types must match"); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector and integer scalar"); + auto scalarType = dyn_cast(getScalar().getType()); + if (!scalarType || !scalarType.isSignlessInteger(16)) + return emitOpError("requires signless i16 scalar"); + return success(); +} + +LogicalResult VabsOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "operand type"))) + return failure(); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("requires matching register vector shape"); + return success(); +} + +template +static LogicalResult verifyUnaryVecOp(UnaryOp op) { + if (failed(verifyVRegTypeLike(op, op.getInput().getType(), "operand type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getInput().getType() != op.getResult().getType()) + 
return op.emitOpError("requires matching register vector shape"); + return success(); +} + +LogicalResult VexpOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VlnOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VsqrtOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VnegOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VreluOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VnotOp::verify() { return verifyUnaryVecOp(*this); } + +template +static LogicalResult verifyBinaryVecOp(BinaryOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires matching register vector shapes"); + return success(); +} + +LogicalResult VaddOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VsubOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VmulOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VdivOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VandOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VxorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VshlOp::verify() { + if (failed(verifyBinaryVecOp(*this))) + return failure(); + auto lhsType = cast(getLhs().getType()); + if (!isa(lhsType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VshrOp::verify() { + if (failed(verifyBinaryVecOp(*this))) + return failure(); + auto 
lhsType = cast(getLhs().getType()); + if (!isa(lhsType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VaddcOp::verify() { return verifyCarryVecOp(*this); } +LogicalResult VsubcOp::verify() { return verifyCarryVecOp(*this); } +LogicalResult VaddcsOp::verify() { return verifyCarryVecOpWithInput(*this); } +LogicalResult VsubcsOp::verify() { return verifyCarryVecOpWithInput(*this); } + +template +static LogicalResult verifyReductionVecOp(ReductionOp op) { + return verifyUnaryVecOp(op); +} + +template +static LogicalResult verifyGroupReductionVecOp(ReductionOp op) { + if (failed(verifyReductionVecOp(op))) + return failure(); + auto inputType = cast(op.getInput().getType()); + Type elemType = inputType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() < 16 || intType.getWidth() > 32) + return op.emitOpError( + "requires 16-bit or 32-bit integer vector element type"); + return success(); + } + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires i16/i32/f16/f32 vector element type"); + return success(); +} + +LogicalResult VcgaddOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcgmaxOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcgminOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcpaddOp::verify() { + if (failed(verifyReductionVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + Type elemType = inputType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +template +static LogicalResult verifyLaneSelectOp(SelectOp op) { + if (failed(verifyVRegTypeLike(op, op.getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(op, op.getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result 
type"))) + return failure(); + + auto src0Type = cast(op.getSrc0().getType()); + auto src1Type = cast(op.getSrc1().getType()); + auto resultType = cast(op.getResult().getType()); + if (src0Type != resultType) + return op.emitOpError("requires src0 and result to have identical vector types"); + if (src1Type.getElementCount() != src0Type.getElementCount()) + return op.emitOpError("requires src0/src1 to have identical element counts"); + auto src1ElemType = dyn_cast(src1Type.getElementType()); + if (!src1ElemType) + return op.emitOpError("requires src1 to use integer vector elements"); + if (src1ElemType.getWidth() != getIntOrFloatBitWidth(src0Type.getElementType())) + return op.emitOpError("requires src1 integer element width to match src0 element width"); + return success(); +} + +template +static LogicalResult verifyPairVecResults(PairOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(op, op.getHigh().getType(), "high result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getLow().getType() || + op.getLhs().getType() != op.getHigh().getType()) + return op.emitOpError("requires operands and results to share one vector type"); + return success(); +} + +template +static LogicalResult verifyPartVecOp(PartOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires operands and result to share one vector type"); + if (!isSupportedPartToken(op.getPart())) + return 
op.emitOpError("requires part to be LOWER or HIGHER"); + return success(); +} + +LogicalResult VselOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType() || + getSrc0().getType() != getResult().getType()) + return emitOpError("requires src0, src1, and result to have identical vector types"); + return success(); +} + +LogicalResult VselrOp::verify() { return verifyLaneSelectOp(*this); } +LogicalResult Vselrv2Op::verify() { return verifyLaneSelectOp(*this); } + +LogicalResult VsqzOp::verify() { return verifyUnaryVecOp(*this); } + +LogicalResult VusqzOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc().getType() != getResult().getType()) + return emitOpError("requires src and result to share one vector type"); + auto srcType = cast(getSrc().getType()); + auto elemType = dyn_cast(srcType.getElementType()); + if (!elemType) + return emitOpError("requires signed integer vector element type"); + if (elemType.isUnsigned()) + return emitOpError("requires signed integer vector element type"); + unsigned width = elemType.getWidth(); + if (width != 8 && width != 16 && width != 32) + return emitOpError("requires s8/s16/s32 vector element type"); + return success(); +} + +LogicalResult VpackOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires 
part to be LOWER or HIGHER"); + auto srcType = cast(getSrc().getType()); + auto resultType = cast(getResult().getType()); + Type srcElemType = srcType.getElementType(); + Type resultElemType = resultType.getElementType(); + if (!isa(srcElemType) || !isa(resultElemType)) + return emitOpError("currently requires integer source and result element types"); + if (resultType.getElementCount() != srcType.getElementCount() * 2) + return emitOpError( + "requires result element count to be twice the source element count"); + unsigned srcWidth = getIntOrFloatBitWidth(srcElemType); + unsigned resultWidth = getIntOrFloatBitWidth(resultElemType); + if (!srcWidth || resultWidth * 2 != srcWidth) + return emitOpError( + "requires result element width to be half the source element width"); + auto srcIntType = cast(srcElemType); + auto resultIntType = cast(resultElemType); + if (!resultIntType.isUnsigned()) + return emitOpError("requires unsigned result element type"); + if (!((srcIntType.getWidth() == 32 && resultIntType.getWidth() == 16) || + (srcIntType.getWidth() == 16 && resultIntType.getWidth() == 8))) + return emitOpError( + "currently supports only s32/u32 -> u16 and s16/u16 -> u8"); + return success(); +} + +template +static LogicalResult verifyUnpackVecOp(UnpackOp op) { + if (failed(verifyVRegTypeLike(op, op.getSrc().getType(), "src type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + auto srcType = cast(op.getSrc().getType()); + auto resultType = cast(op.getResult().getType()); + Type srcElemType = srcType.getElementType(); + Type resultElemType = resultType.getElementType(); + if (!isa(srcElemType) || !isa(resultElemType)) + return op.emitOpError( + "currently requires integer source and result element types"); + if (srcType.getElementCount() != resultType.getElementCount() * 2) + return op.emitOpError( + "requires source element count to be twice the result element count"); + unsigned srcWidth = 
getIntOrFloatBitWidth(srcElemType); + unsigned resultWidth = getIntOrFloatBitWidth(resultElemType); + if (!srcWidth || srcWidth * 2 != resultWidth) + return op.emitOpError( + "requires result element width to be twice the source element width"); + return success(); +} + +LogicalResult VsunpackOp::verify() { return verifyUnpackVecOp(*this); } +LogicalResult VzunpackOp::verify() { return verifyUnpackVecOp(*this); } + +static bool isSupportedCmpMode(StringRef mode) { + return mode == "eq" || mode == "ne" || mode == "lt" || mode == "le" || + mode == "gt" || mode == "ge"; +} + +LogicalResult VcmpOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType()) + return emitOpError("requires src0 and src1 to have identical vector types"); + if (!isSupportedCmpMode(getCmpMode())) + return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); + return success(); +} + +LogicalResult VcmpsOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + auto srcType = cast(getSrc().getType()); + if (getScalar().getType() != srcType.getElementType()) + return emitOpError("requires scalar type to match source element type"); + if (!isSupportedCmpMode(getCmpMode())) + return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); + return success(); +} + +ParseResult VtrcOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + OpAsmParser::UnresolvedOperand mask; + std::string roundModeToken; + NamedAttrList attrs; 
+ Type inputType, maskType, resultType; + + if (parser.parseOperand(input) || parser.parseComma() || + parser.parseOperand(mask) || parser.parseComma() || + parser.parseKeywordOrString(&roundModeToken) || + parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseComma() || + parser.parseType(maskType) || parser.parseArrow() || + parser.parseType(resultType)) + return failure(); + + auto normalized = normalizeRoundModeToken(roundModeToken); + if (!normalized || !isSupportedVtrcRoundMode(*normalized)) + return parser.emitError(parser.getCurrentLocation()) + << "round mode must be one of R/A/F/C/Z or " + "ROUND_R/ROUND_A/ROUND_F/ROUND_C/ROUND_Z"; + + attrs.set("round_mode", parser.getBuilder().getStringAttr(*normalized)); + result.addAttributes(attrs); + if (parser.resolveOperand(input, inputType, result.operands) || + parser.resolveOperand(mask, maskType, result.operands)) + return failure(); + result.addTypes(resultType); + return success(); +} + +void VtrcOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInput() << ", " << getMask() << ", "; + Builder builder(getContext()); + auto normalized = normalizeRoundModeToken(getRoundMode()); + printer.printAttributeWithoutType( + builder.getStringAttr(normalized.value_or(getRoundMode()))); + printer.printOptionalAttrDict((*this)->getAttrs(), {"round_mode"}); + printer << " : " << getInput().getType() << ", " << getMask().getType() + << " -> " << getResult().getType(); +} + +LogicalResult VtrcOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("requires input and result to have identical vreg type"); + auto elemType = inputType.getElementType(); + if (!(elemType.isF16() || 
elemType.isF32() || elemType.isBF16())) + return emitOpError("requires f16/f32/bf16 vector element type"); + auto expectedGranularity = getVdupMaskGranularity(elemType); + if (!expectedGranularity) + return emitOpError("requires element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); + auto normalized = normalizeRoundModeToken(getRoundMode()); + if (!normalized || !isSupportedVtrcRoundMode(*normalized)) + return emitOpError("round mode must be one of R/A/F/C/Z"); + return success(); +} + +ParseResult VcvtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + NamedAttrList attrs; + Type inputType, resultType; + + if (parser.parseOperand(input) || parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseArrow() || + parser.parseType(resultType)) + return failure(); + + Attribute legacyRndAttr = attrs.get("round_mode"); + Attribute rndAttr = attrs.get("rnd"); + if (legacyRndAttr && rndAttr) + return parser.emitError(parser.getCurrentLocation()) + << "rnd and round_mode cannot be specified together"; + + auto normalizeNamedStringAttr = + [&](StringRef sourceName, StringRef canonicalName, + auto normalizeFn) -> ParseResult { + Attribute rawAttr = attrs.get(sourceName); + if (!rawAttr) + return success(); + auto strAttr = dyn_cast(rawAttr); + if (!strAttr) + return parser.emitError(parser.getCurrentLocation()) + << sourceName << " must be a string literal"; + auto normalized = normalizeFn(strAttr.getValue()); + if (!normalized) + return parser.emitError(parser.getCurrentLocation()) + << sourceName << " has unsupported value '" << strAttr.getValue() + << "'"; + attrs.erase(sourceName); + attrs.set(canonicalName, parser.getBuilder().getStringAttr(*normalized)); + return success(); + }; + + if (failed(normalizeNamedStringAttr("round_mode", "rnd", + 
normalizeRoundModeToken)) || + failed(normalizeNamedStringAttr("rnd", "rnd", normalizeRoundModeToken)) || + failed(normalizeNamedStringAttr("sat", "sat", normalizeSaturationToken)) || + failed( + normalizeNamedStringAttr("part", "part", normalizeEvenOddPartToken))) + return failure(); + + result.addAttributes(attrs); + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + result.addTypes(resultType); + return success(); +} + +void VcvtOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInput(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getInput().getType() << " -> " << getResult().getType(); +} + +LogicalResult VcvtOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + + VcvtElemKind inputElemKind = classifyVcvtElemType(inputType.getElementType()); + VcvtElemKind resultElemKind = classifyVcvtElemType(resultType.getElementType()); + auto contract = lookupVcvtContract(inputElemKind, resultElemKind); + if (!contract) + return emitOpError("unsupported vcvt source/result element type pair"); + + auto inputElemBits = getVcvtElemBitWidth(inputElemKind); + auto resultElemBits = getVcvtElemBitWidth(resultElemKind); + if (!inputElemBits || !resultElemBits) + return emitOpError("could not determine vcvt element bit width"); + if (inputType.getElementCount() * static_cast(*inputElemBits) != + resultType.getElementCount() * static_cast(*resultElemBits)) { + return emitOpError("requires source and result vectors to carry the same " + "total number of bits"); + } + + if (getRndAttr()) { + StringRef roundMode = *getRnd(); + if (!normalizeRoundModeToken(roundMode)) + return emitOpError("rnd must be one of R/A/F/C/Z/O"); + } + if (static_cast(getRndAttr()) != contract->requiresRnd) { + return contract->requiresRnd ? 
emitOpError("requires rnd attr for this vcvt type pair") + : emitOpError("rnd attr is not valid for this vcvt type pair"); + } + + if (getSatAttr()) { + StringRef sat = *getSat(); + if (!normalizeSaturationToken(sat)) + return emitOpError("sat must be SAT or NOSAT"); + } + if (static_cast(getSatAttr()) != contract->requiresSat) { + return contract->requiresSat ? emitOpError("requires sat attr for this vcvt type pair") + : emitOpError("sat attr is not valid for this vcvt type pair"); + } + + if (getPartAttr()) { + StringRef part = *getPart(); + if (!normalizeEvenOddPartToken(part)) + return emitOpError("part must be EVEN or ODD"); + } + if (static_cast(getPartAttr()) != contract->requiresPart) { + return contract->requiresPart ? emitOpError("requires part attr for this vcvt type pair") + : emitOpError("part attr is not valid for this vcvt type pair"); + } + + return success(); +} + +LogicalResult PdintlvB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b8"))) + return failure(); + return success(); +} + +LogicalResult PdintlvB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b16"))) + return failure(); + return success(); +} + +LogicalResult PdintlvB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b32")) || + 
failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b32"))) + return failure(); + return success(); +} + +LogicalResult PintlvB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b8"))) + return failure(); + return success(); +} + +LogicalResult PintlvB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b16"))) + return failure(); + return success(); +} + +LogicalResult PintlvB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b32"))) + return failure(); + return success(); +} + +LogicalResult VintlvOp::verify() { return verifyPairVecResults(*this); } +LogicalResult VdintlvOp::verify() { return verifyPairVecResults(*this); } +LogicalResult Vintlvv2Op::verify() { return verifyPartVecOp(*this); } +LogicalResult Vdintlvv2Op::verify() { return 
verifyPartVecOp(*this); } + +LogicalResult VmullOp::verify() { + if (failed(verifyPairVecResults(*this)) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + auto lhsType = cast(getLhs().getType()); + auto lhsElemType = dyn_cast(lhsType.getElementType()); + if (!lhsElemType) + return emitOpError("requires integer vector element type"); + if (lhsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit integer vector elements"); + return success(); +} + +LogicalResult VmulaOp::verify() { + if (failed(verifyVRegTypeLike(*this, getAcc().getType(), "acc type")) || + failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getAcc().getType() != getLhs().getType() || + getAcc().getType() != getRhs().getType() || + getAcc().getType() != getResult().getType()) + return emitOpError("requires acc, lhs, rhs, and result to share one vector type"); + return success(); +} + +template +static LogicalResult verifyBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires lhs, rhs, and result to share one vector type"); + return success(); +} + +template +static LogicalResult verifyFloatBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { + if (failed(verifyBinaryVecNoMaskOp(op))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + Type elemType = lhsType.getElementType(); + if (!elemType.isF16() && 
!elemType.isF32()) + return op.emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecNoMaskOp(*this); } +LogicalResult VexpdiffOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyVRegTypeLike(*this, getMax().getType(), "max type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + + auto inputType = cast(getInput().getType()); + auto maxType = cast(getMax().getType()); + auto resultType = cast(getResult().getType()); + if (inputType != maxType) + return emitOpError("requires input and max to share one vector type"); + + Type inputElemType = inputType.getElementType(); + if (!inputElemType.isF16() && !inputElemType.isF32()) + return emitOpError("requires f16 or f32 input vector element type"); + if (!resultType.getElementType().isF32()) + return emitOpError("requires f32 result vector element type"); + + auto inputBits = getVRegStorageBitWidth(inputType); + auto resultBits = getVRegStorageBitWidth(resultType); + if (!inputBits || !resultBits || *inputBits != *resultBits) + return emitOpError( + "requires source and result to preserve total vector storage width"); + + StringRef part = getPart(); + if (part != "EVEN" && part != "ODD") + return emitOpError("part must be EVEN or ODD"); + return success(); +} + +LogicalResult VaxpyOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + auto src0Type = cast(getSrc0().getType()); + auto src1Type = cast(getSrc1().getType()); + auto resultType = cast(getResult().getType()); + if (src0Type != src1Type || src0Type != resultType) + return emitOpError("requires src0, src1, and result to share one vector type"); + Type elemType = 
src0Type.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + if (getAlpha().getType() != elemType) + return emitOpError("requires alpha type to match vector element type"); + return success(); +} + +template +static LogicalResult verifyFusedConvVecOp(ConvOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + auto rhsType = cast(op.getRhs().getType()); + auto resultType = cast(op.getResult().getType()); + if (lhsType != rhsType) + return op.emitOpError("requires lhs and rhs to share one vector type"); + if (!isIntegerOrFloatLike(lhsType.getElementType()) || + !isIntegerOrFloatLike(resultType.getElementType())) + return op.emitOpError( + "requires integer or floating-point vector element types"); + auto lhsBits = getVRegStorageBitWidth(lhsType); + auto resultBits = getVRegStorageBitWidth(resultType); + if (!lhsBits || !resultBits || *lhsBits != *resultBits) + return op.emitOpError( + "requires source and result to preserve total vector storage width"); + return success(); +} + +LogicalResult VaddreluconvOp::verify() { + return verifyFusedConvVecOp(*this); +} +LogicalResult VmulconvOp::verify() { return verifyFusedConvVecOp(*this); } + +void Vldsx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vldsx2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (failed(verifyVRegTypeLike(*this, 
getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high result type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high results to share one vector type"); + if (!isSupportedVldx2DistToken(getDist())) + return emitOpError("requires a supported x2 load distribution token"); + if (failed(verifyVldsx2DistWidth( + getOperation(), getDist(), + cast(getLow().getType()).getElementType()))) + return failure(); + return success(); +} + +void VstsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +template +static LogicalResult verifyVstsCommon(StoreOp op) { + if (failed(verifyVRegTypeLike(op, op.getValue().getType(), "value type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + + if (!isBufferLike(op.getDestination().getType())) + return op.emitOpError("requires a pointer-like destination"); + + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return op.emitOpError("requires a UB-backed destination"); + + if (std::optional dist = op.getDist(); + dist && !isSupportedVstsDistToken(*dist)) { + return op.emitOpError("requires a supported store distribution token"); + } + if (std::optional dist = op.getDist(); + dist && + failed(verifyVstsDistWidth( + op.getOperation(), *dist, + cast(op.getValue().getType()).getElementType()))) + return failure(); + + return success(); +} + +LogicalResult VstsOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (std::optional mode = getOptionalPostModeAttr(getOperation()); + mode && !isSupportedPostMode(*mode)) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} +void 
VstsPostOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstsPostOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (getUpdatedDestination().getType() != getDestination().getType()) + return emitOpError( + "requires updated destination result to match destination type"); + return success(); +} + +void Vstsx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLowMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getHighMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vstsx2Op::verify() { + if (failed(verifyVRegTypeLike(*this, getLow().getType(), "low value type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high value type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high values to share one vector type"); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedVstsx2DistToken(getDist())) + return emitOpError("requires a supported x2 store distribution token"); + if (failed(verifyVstsx2DistWidth( + getOperation(), getDist(), + cast(getLow().getType()).getElementType()))) + return failure(); + return success(); +} + +void VscatterOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), 
&getDestinationMutable()); +} + +LogicalResult VscatterOp::verify() { + if (failed(verifyVRegTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + auto offsetsType = dyn_cast(getOffsets().getType()); + auto valueType = dyn_cast(getValue().getType()); + if (!offsetsType || !valueType) + return emitOpError("value and offsets must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != valueType.getElementCount()) + return emitOpError("offset and value vectors must have the same element count"); + MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getActiveLanes().getType().isIndex()) + return emitOpError("active_lanes must be index"); + return success(); +} + +void VsldbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VsldbOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!getBlockStride().getType().isSignlessInteger(16)) + return emitOpError("requires block_stride to be i16"); + if (!getRepeatStride().getType().isSignlessInteger(16)) + return emitOpError("requires repeat_stride to be 
i16"); + return success(); +} + +void PstsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void PstiOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult PstiOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!matchPattern(getOffset(), m_Constant())) + return emitOpError("requires offset to be a constant index immediate"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + +LogicalResult PstsOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + +void VsstbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult 
VsstbOp::verify() { + if (failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getBlockStride().getType().isSignlessInteger(16)) + return emitOpError("requires block_stride to be i16"); + if (!getRepeatStride().getType().isSignlessInteger(16)) + return emitOpError("requires repeat_stride to be i16"); + return success(); +} + +void VstasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstasOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void VstarOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstarOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void PstuOp::getEffects( + SmallVectorImpl> + &effects) { + 
effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult PstuOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyMaskTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType()) || !isBufferLike(getBaseOut().getType())) + return emitOpError("requires pointer-like base and base_out"); + if (getBase().getType() != getBaseOut().getType()) + return emitOpError("requires base and base_out to have identical types"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + auto baseType = cast(getBase().getType()); + auto maskType = cast(getValue().getType()); + auto elemType = dyn_cast(baseType.getElementType()); + if (!elemType || elemType.isSigned() || (elemType.getWidth() != 16 && elemType.getWidth() != 32)) + return emitOpError("requires ui16/ui32 UB base type"); + if (maskType.isB16() && elemType.getWidth() != 16) + return emitOpError("requires !pto.mask to pair with !pto.ptr"); + if (maskType.isB32() && elemType.getWidth() != 32) + return emitOpError("requires !pto.mask to pair with !pto.ptr"); + return success(); +} + +void VstusOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult VstusOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), 
"align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + return success(); +} + +void VsturOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult VsturOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + if (!isSupportedPostMode(getMode())) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} + +void CopyUbufToGmOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyUbufToGmOp::verify() { + return verifyCopyUbufToGmOp(*this, false); +} diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index b82d227fe..89ee52e42 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -12,6 +12,15 @@ # See LICENSE in the root of the software repository for the full text of the License. 
add_mlir_dialect_library(PTOTransforms + HIVMIntrinsicNaming.cpp + VPTOLLVMEmitter.cpp + VPTOLLVMEmitterHelper.cpp + PTOVPTOExpandBridgeOps.cpp + PTOVPTOPtrBoundary.cpp + PTOToVPTO.cpp + PTOToVPTOLowering.cpp + PTOValidateVPTOIR.cpp + InsertSync/PTOInsertSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp @@ -46,6 +55,9 @@ add_mlir_dialect_library(PTOTransforms PTOPassesIncGen PTOOpsIncGen + LINK_COMPONENTS + Analysis + LINK_LIBS PUBLIC PTOIR MLIRIR @@ -59,7 +71,12 @@ add_mlir_dialect_library(PTOTransforms MLIRTransformUtils MLIRTransforms MLIRTensorDialect + MLIRSCFDialect MLIRSCFToEmitC + MLIRSCFToControlFlow + MLIRConvertToLLVMPass + MLIRTargetLLVMIRExport + MLIRToLLVMIRTranslationRegistration ) install(TARGETS PTOTransforms diff --git a/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp new file mode 100644 index 000000000..d87cb6867 --- /dev/null +++ b/lib/PTO/Transforms/HIVMIntrinsicNaming.cpp @@ -0,0 +1,561 @@ +//===- HIVMIntrinsicNaming.cpp - HIVM intrinsic selection -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/HIVMIntrinsicNaming.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +using namespace mlir; + +namespace mlir::pto { +namespace { + +static std::string getLocationString(Location loc) { + std::string storage; + llvm::raw_string_ostream os(storage); + loc.print(os); + return storage; +} + +static std::string sanitizeNameFragment(llvm::StringRef text) { + std::string out; + out.reserve(text.size()); + for (char c : text) { + if (std::isalnum(static_cast(c)) || c == '.' || c == '_') + out.push_back(c); + else + out.push_back('_'); + } + return out; +} + +static std::string printAttrText(Attribute attr) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << attr; + return storage; +} + +static std::string getElementTypeFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? 
"u" : "s") + std::to_string(intType.getWidth()); + return "unknown"; +} + +static std::string getVectorTypeFragment(Type type) { + auto vecType = dyn_cast(type); + if (!vecType) + return {}; + return ("v" + std::to_string(vecType.getElementCount()) + + getElementTypeFragment(vecType.getElementType())); +} + +static std::string getCopyElementFragment(Type type) { + auto ptrType = dyn_cast(type); + if (!ptrType) + return {}; + Type elementType = ptrType.getElementType(); + if (auto floatType = dyn_cast(elementType)) { + switch ((floatType.getWidth() + 7) / 8) { + case 1: + return "u8"; + case 2: + return "u16"; + case 4: + case 8: + return "u32"; + default: + return {}; + } + } + if (auto intType = dyn_cast(elementType)) { + switch ((intType.getWidth() + 7) / 8) { + case 1: + return "u8"; + case 2: + return "u16"; + case 4: + case 8: + return "u32"; + default: + return {}; + } + } + return {}; +} + +static std::string getOpMnemonic(Operation *op) { + return op->getName().stripDialect().str(); +} + +static IntrinsicSelection makeResolved(Operation *op, llvm::StringRef calleeName, + llvm::ArrayRef usedFields, + llvm::StringRef resultTypeFragment) { + IntrinsicSelection selection; + selection.resolved = true; + selection.sourceOpName = op->getName().getStringRef().str(); + selection.calleeName = calleeName.str(); + selection.usedFields.assign(usedFields.begin(), usedFields.end()); + selection.resultTypeFragment = resultTypeFragment.str(); + selection.location = getLocationString(op->getLoc()); + return selection; +} + +static IntrinsicSelection makeUnresolved(Operation *op, + llvm::StringRef familyOrOp, + llvm::StringRef candidateName, + llvm::ArrayRef usedFields, + llvm::ArrayRef missingFields, + llvm::StringRef resultTypeFragment) { + IntrinsicSelection selection; + selection.resolved = false; + selection.sourceOpName = op->getName().getStringRef().str(); + selection.candidateName = candidateName.str(); + selection.usedFields.assign(usedFields.begin(), 
usedFields.end()); + selection.missingFields.assign(missingFields.begin(), missingFields.end()); + selection.resultTypeFragment = resultTypeFragment.str(); + selection.location = getLocationString(op->getLoc()); + + std::string name = "__ptoas_hivm_unresolved."; + name += sanitizeNameFragment(familyOrOp); + if (!resultTypeFragment.empty()) { + name += "."; + name += sanitizeNameFragment(resultTypeFragment); + } + selection.placeholderName = std::move(name); + return selection; +} + +static FailureOr selectSyncLike(Operation *op) { + llvm::SmallVector usedFields; + usedFields.push_back("op=" + getOpMnemonic(op)); + + if (auto setFlag = dyn_cast(op)) { + usedFields.push_back("src_pipe=" + printAttrText(setFlag.getSrcPipe())); + usedFields.push_back("dst_pipe=" + printAttrText(setFlag.getDstPipe())); + usedFields.push_back("event=" + printAttrText(setFlag.getEventId())); + return makeResolved(op, "llvm.hivm.SET.FLAG.IMM", usedFields, ""); + } else if (auto waitFlag = dyn_cast(op)) { + usedFields.push_back("src_pipe=" + printAttrText(waitFlag.getSrcPipe())); + usedFields.push_back("dst_pipe=" + printAttrText(waitFlag.getDstPipe())); + usedFields.push_back("event=" + printAttrText(waitFlag.getEventId())); + return makeResolved(op, "llvm.hivm.WAIT.FLAG.IMM", usedFields, ""); + } else if (auto barrier = dyn_cast(op)) { + usedFields.push_back("pipe=" + printAttrText(barrier.getPipe())); + return makeResolved(op, "llvm.hivm.BARRIER", usedFields, ""); + } + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, ""); +} + +static FailureOr selectConfigLike(Operation *op) { + llvm::SmallVector usedFields = {"op=" + getOpMnemonic(op)}; + + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB", + usedFields, ""); + if (isa(op)) + return makeResolved(op, 
"llvm.hivm.SET.LOOP.SIZE.OUTTOUB", usedFields, ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT", usedFields, + ""); + if (isa(op)) + return makeResolved(op, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT", usedFields, ""); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, + ""); +} + +static FailureOr selectPredicateIntrinsic(Operation *op) { + llvm::SmallVector usedFields; + if (auto pset = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pset.getResult().getType()); + usedFields = {"family=pset", "bitwidth=8", "result=" + resultFragment, + "pattern=i32"}; + return makeResolved(op, "llvm.hivm.pset.b8", usedFields, resultFragment); + } + if (auto pset = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pset.getResult().getType()); + usedFields = {"family=pset", "bitwidth=16", "result=" + resultFragment, + "pattern=i32"}; + return makeResolved(op, "llvm.hivm.pset.b16", usedFields, resultFragment); + } + if (auto pset = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pset.getResult().getType()); + usedFields = {"family=pset", "bitwidth=32", "result=" + resultFragment, + "pattern=i32"}; + return makeResolved(op, "llvm.hivm.pset.b32", usedFields, resultFragment); + } + if (auto pge = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pge.getResult().getType()); + usedFields = {"family=pge", "bitwidth=8", "result=" + resultFragment, + "pattern=i32", "variant=i32_zero"}; + return makeResolved(op, "llvm.hivm.pge.b8", usedFields, resultFragment); + } + if (auto pge = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pge.getResult().getType()); + usedFields = {"family=pge", "bitwidth=16", "result=" + resultFragment, + "pattern=i32", 
"variant=i32_zero"}; + return makeResolved(op, "llvm.hivm.pge.b16", usedFields, resultFragment); + } + if (auto pge = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(pge.getResult().getType()); + usedFields = {"family=pge", "bitwidth=32", "result=" + resultFragment, + "pattern=i32", "variant=i32_zero"}; + return makeResolved(op, "llvm.hivm.pge.b32", usedFields, resultFragment); + } + if (auto plt = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(plt.getMask().getType()); + usedFields = {"family=plt", "bitwidth=8", "result=" + resultFragment, + "variant=v300", "scalar=i32", "scalar_out=i32"}; + return makeResolved(op, "llvm.hivm.plt.b8.v300", usedFields, resultFragment); + } + if (auto plt = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(plt.getMask().getType()); + usedFields = {"family=plt", "bitwidth=16", "result=" + resultFragment, + "variant=v300", "scalar=i32", "scalar_out=i32"}; + return makeResolved(op, "llvm.hivm.plt.b16.v300", usedFields, resultFragment); + } + if (auto plt = dyn_cast(op)) { + const std::string resultFragment = + getVectorTypeFragment(plt.getMask().getType()); + usedFields = {"family=plt", "bitwidth=32", "result=" + resultFragment, + "variant=v300", "scalar=i32", "scalar_out=i32"}; + return makeResolved(op, "llvm.hivm.plt.b32.v300", usedFields, resultFragment); + } + + return failure(); +} + +} // namespace + +FailureOr selectLoadIntrinsic(Operation *op) { + if (auto vlds = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vlds.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vldsx1", "vector=" + vecFragment, "mode=NO_POST_UPDATE"}; + if (vlds.getDistAttr()) + usedFields.push_back("dist=" + (*vlds.getDist()).str()); + + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vldsx1", usedFields, vecFragment); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + std::string candidate = 
"llvm.hivm.vldsx1"; + return makeUnresolved(op, "vldsx1", candidate, usedFields, missingFields, + vecFragment); + } + + if (auto vldsPost = dyn_cast(op)) { + const std::string vecFragment = + getVectorTypeFragment(vldsPost.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vldsx1", "variant=post", "vector=" + vecFragment, + "mode=POST_UPDATE"}; + if (vldsPost.getDistAttr()) + usedFields.push_back("dist=" + (*vldsPost.getDist()).str()); + + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vldsx1.post", usedFields, vecFragment); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + std::string candidate = "llvm.hivm.vldsx1.post"; + return makeUnresolved(op, "vldsx1.post", candidate, usedFields, + missingFields, vecFragment); + } + + return failure(); +} + +FailureOr selectUnaryIntrinsic(Operation *op) { + auto vabs = dyn_cast(op); + if (vabs) { + const std::string vecFragment = getVectorTypeFragment(vabs.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vabs", "vector=" + vecFragment, "variant=x"}; + + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vabs.v64f32.x", usedFields, vecFragment); + + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + std::string candidate = "llvm.hivm.vabs"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeUnresolved(op, "vabs", candidate, usedFields, missingFields, + vecFragment); + } + + if (auto vexp = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vexp.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vexp", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vexp"; + if (!vecFragment.empty()) + candidate += "." 
+ vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto vdup = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vdup.getResult().getType()); + const bool vectorInput = isa(vdup.getInput().getType()); + const StringRef position = vdup.getPosition().value_or("LOWEST"); + const char *family = + vectorInput ? (position == "HIGHEST" ? "vdupm" : "vdup") : "vdups"; + llvm::SmallVector usedFields = { + "family=" + std::string(family), "vector=" + vecFragment, + "variant=z"}; + if (!vectorInput && !isa(vdup.getInput().getType())) { + llvm::SmallVector missingFields = {"scalar_input_vdup_mapping"}; + return makeUnresolved(op, "vdup", "llvm.hivm.vdups", usedFields, missingFields, + vecFragment); + } + std::string candidate = "llvm.hivm."; + candidate += family; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".z"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vadd", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vadd"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vsub", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vsub"; + if (!vecFragment.empty()) + candidate += "." 
+ vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmul", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmul"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmax", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmax"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmuls", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmuls"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vadds", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vadds"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmaxs", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmaxs"; + if (!vecFragment.empty()) + candidate += "." 
+ vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vmins", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vmins"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vlrelu", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vlrelu"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vshls", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vshls"; + if (!vecFragment.empty()) + candidate += "." + vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + if (auto binary = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(binary.getResult().getType()); + llvm::SmallVector usedFields = { + "family=vshrs", "vector=" + vecFragment, "variant=x"}; + std::string candidate = "llvm.hivm.vshrs"; + if (!vecFragment.empty()) + candidate += "." 
+ vecFragment + ".x"; + return makeResolved(op, candidate, usedFields, vecFragment); + } + + return failure(); +} + +FailureOr selectStoreIntrinsic(Operation *op) { + llvm::SmallVector usedFields; + llvm::SmallVector missingFields = {"confirmed_hivm_name"}; + + if (auto vsts = dyn_cast(op)) { + const std::string vecFragment = getVectorTypeFragment(vsts.getValue().getType()); + usedFields = {"family=vstsx1", "vector=" + vecFragment, + "predicate_source=explicit_mask", "mode=NO_POST_UPDATE"}; + if (vsts.getDistAttr()) + usedFields.push_back("dist=" + (*vsts.getDist()).str()); + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vstsx1", usedFields, vecFragment); + return makeUnresolved(op, "vstsx1", "llvm.hivm.vstsx1", usedFields, missingFields, + vecFragment); + } + + if (auto vstsPost = dyn_cast(op)) { + const std::string vecFragment = + getVectorTypeFragment(vstsPost.getValue().getType()); + usedFields = {"family=vstsx1", "variant=post", "vector=" + vecFragment, + "predicate_source=explicit_mask", "mode=POST_UPDATE"}; + if (vstsPost.getDistAttr()) + usedFields.push_back("dist=" + (*vstsPost.getDist()).str()); + if (vecFragment == "v64f32") + return makeResolved(op, "llvm.hivm.vstsx1.post", usedFields, + vecFragment); + std::string candidate = "llvm.hivm.vstsx1.post"; + return makeUnresolved(op, "vstsx1.post", candidate, usedFields, + missingFields, vecFragment); + } + + if (auto copy = dyn_cast(op)) { + std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); + usedFields = {"family=copy_gm_to_ubuf"}; + if (!elemFragment.empty()) + usedFields.push_back("element=" + elemFragment); + if (elemFragment == "u8" || elemFragment == "u16" || + elemFragment == "u32" || elemFragment == "f32") { + std::string callee = "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2."; + callee += elemFragment; + callee += ".DV"; + return makeResolved(op, callee, usedFields, ""); + } + std::string candidate = "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2"; + if 
(!elemFragment.empty()) + candidate += "." + elemFragment + ".DV"; + missingFields.push_back("element_type_mapping"); + return makeUnresolved(op, "copy_gm_to_ubuf", candidate, usedFields, + missingFields, ""); + } + + if (auto copy = dyn_cast(op)) { + std::string elemFragment = getCopyElementFragment(copy.getSource().getType()); + usedFields = {"family=copy_ubuf_to_gm"}; + if (!elemFragment.empty()) + usedFields.push_back("element=" + elemFragment); + return makeResolved(op, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV", + usedFields, ""); + } + + if (isa(op)) { + usedFields = {"family=copy_ubuf_to_ubuf"}; + return makeUnresolved(op, "copy_ubuf_to_ubuf", "copy_ubuf_to_ubuf", + usedFields, missingFields, ""); + } + + return failure(); +} + +FailureOr selectIntrinsic(Operation *op) { + if (isa(op)) + return selectSyncLike(op); + + if (isa(op)) + return selectConfigLike(op); + + if (succeeded(selectLoadIntrinsic(op))) + return *selectLoadIntrinsic(op); + if (succeeded(selectUnaryIntrinsic(op))) + return *selectUnaryIntrinsic(op); + if (succeeded(selectPredicateIntrinsic(op))) + return *selectPredicateIntrinsic(op); + if (succeeded(selectStoreIntrinsic(op))) + return *selectStoreIntrinsic(op); + + llvm::SmallVector usedFields = {"op=" + getOpMnemonic(op)}; + llvm::SmallVector missingFields = {"family_mapping", + "confirmed_hivm_name"}; + return makeUnresolved(op, getOpMnemonic(op), "", usedFields, missingFields, + ""); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToVPTO.cpp b/lib/PTO/Transforms/PTOToVPTO.cpp new file mode 100644 index 000000000..1661e1fcf --- /dev/null +++ b/lib/PTO/Transforms/PTOToVPTO.cpp @@ -0,0 +1,604 @@ +//===- PTOToVPTO.cpp - PTO to VPTO pass wiring ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/Passes.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { + +#define GEN_PASS_DEF_PTOTOVPTO +#include "PTO/Transforms/Passes.h.inc" + +namespace { + + +FailureOr +parseVPTOLoweringStrategy(StringRef strategyName) { + if (strategyName == "post-update") + return VPTOLoweringStrategy::PostUpdate; + if (strategyName == "no-post-update") + return VPTOLoweringStrategy::NoPostUpdate; + return failure(); +} + +LogicalResult lowerTLOADOp(TLoadOp op, PatternRewriter &rewriter) { + return lowerTLOAD(op, rewriter); +} + +LogicalResult lowerTABSOp(TAbsOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTABS(op, rewriter, strategy); +} + +LogicalResult lowerTADDOp(TAddOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTADD(op, rewriter, strategy); +} + +LogicalResult lowerTSUBOp(TSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTSUB(op, rewriter, strategy); +} + +LogicalResult lowerTMULOp(TMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMUL(op, rewriter, strategy); +} + +LogicalResult lowerTDIVOp(TDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTDIV(op, rewriter, strategy); +} + +LogicalResult lowerTMAXOp(TMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMAX(op, rewriter, 
strategy); +} + +LogicalResult lowerTMINOp(TMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMIN(op, rewriter, strategy); +} + +LogicalResult lowerTANDOp(TAndOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTAND(op, rewriter, strategy); +} + +LogicalResult lowerTANDSOp(TAndSOp op, PatternRewriter &rewriter) { + return lowerTANDS(op, rewriter); +} + +LogicalResult lowerTOROp(TOrOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTOR(op, rewriter, strategy); +} + +LogicalResult lowerTORSOp(TOrSOp op, PatternRewriter &rewriter) { + return lowerTORS(op, rewriter); +} + +LogicalResult lowerTXOROp(TXorOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTXOR(op, rewriter, strategy); +} + +LogicalResult lowerTXORSOp(TXorSOp op, PatternRewriter &rewriter) { + return lowerTXORS(op, rewriter); +} + +LogicalResult lowerTEXPOp(TExpOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTEXP(op, rewriter, strategy); +} + +LogicalResult lowerTLOGOp(TLogOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTLOG(op, rewriter, strategy); +} + +LogicalResult lowerTSQRTOp(TSqrtOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTSQRT(op, rewriter, strategy); +} + +LogicalResult lowerTRSQRTOp(TRsqrtOp op, PatternRewriter &rewriter) { + return lowerTRSQRT(op, rewriter); +} + +LogicalResult lowerTRECIPOp(TRecipOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRECIP(op, rewriter, strategy); +} + +LogicalResult lowerTNEGOp(TNegOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTNEG(op, rewriter, strategy); +} + +LogicalResult lowerTLRELUOp(TLReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTLRELU(op, rewriter, strategy); +} + +LogicalResult lowerTCIOp(TCIOp op, 
PatternRewriter &rewriter) { + return lowerTCI(op, rewriter); +} + +LogicalResult lowerTCVTOp(TCvtOp op, PatternRewriter &rewriter) { + return lowerTCVT(op, rewriter); +} + +LogicalResult lowerTCmpOp(TCmpOp op, PatternRewriter &rewriter) { + return lowerTCmp(op, rewriter); +} + +LogicalResult lowerTCmpSOp(TCmpSOp op, PatternRewriter &rewriter) { + return lowerTCmpS(op, rewriter); +} + +LogicalResult lowerTSelOp(TSelOp op, PatternRewriter &rewriter) { + return lowerTSel(op, rewriter); +} + +LogicalResult lowerTAddCOp(TAddCOp op, PatternRewriter &rewriter) { + return lowerTAddC(op, rewriter); +} + +LogicalResult lowerTAddSOp(TAddSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTAddS(op, rewriter, strategy); +} + +LogicalResult lowerTAddSCOp(TAddSCOp op, PatternRewriter &rewriter) { + return lowerTAddSC(op, rewriter); +} + +LogicalResult lowerTMinSOp(TMinSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMinS(op, rewriter, strategy); +} + +LogicalResult lowerTSubCOp(TSubCOp op, PatternRewriter &rewriter) { + return lowerTSubC(op, rewriter); +} + +LogicalResult lowerTSubSOp(TSubSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTSubS(op, rewriter, strategy); +} + +LogicalResult lowerTSubSCOp(TSubSCOp op, PatternRewriter &rewriter) { + return lowerTSubSC(op, rewriter); +} + +LogicalResult lowerTMaxSOp(TMaxSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMaxS(op, rewriter, strategy); +} + +LogicalResult lowerTDivSOp(TDivSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTDivS(op, rewriter, strategy); +} + +LogicalResult lowerTMulSOp(TMulSOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTMulS(op, rewriter, strategy); +} + +LogicalResult lowerTSelSOp(TSelSOp op, PatternRewriter &rewriter) { + return lowerTSelS(op, rewriter); +} + +LogicalResult lowerTRELUOp(TReluOp op, 
PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRELU(op, rewriter, strategy); +} + +LogicalResult lowerTNOTOp(TNotOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTNOT(op, rewriter, strategy); +} + +LogicalResult lowerTTRANSOp(TTransOp op, PatternRewriter &rewriter) { + return lowerTTRANS(op, rewriter); +} + +LogicalResult lowerTFILLPADOp(TFillPadOp op, PatternRewriter &rewriter) { + return lowerTFILLPAD(op, rewriter); +} + +LogicalResult lowerTFILLPADExpandOp(TFillPadExpandOp op, PatternRewriter &rewriter) { + return lowerTFILLPADExpand(op, rewriter); +} + +LogicalResult lowerTRowMaxOp(TRowMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowMax(op, rewriter, strategy); +} + +LogicalResult lowerTRowMinOp(TRowMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowMin(op, rewriter, strategy); +} + +LogicalResult lowerTRowSumOp(TRowSumOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowSum(op, rewriter, strategy); +} + +LogicalResult lowerTColMaxOp(TColMaxOp op, PatternRewriter &rewriter) { + return lowerTColMax(op, rewriter); +} + +LogicalResult lowerTColMinOp(TColMinOp op, PatternRewriter &rewriter) { + return lowerTColMin(op, rewriter); +} + +LogicalResult lowerTColSumOp(TColSumOp op, PatternRewriter &rewriter) { + return lowerTColSum(op, rewriter); +} + +LogicalResult lowerTRowExpandOp(TRowExpandOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpand(op, rewriter, strategy); +} + +LogicalResult lowerTColExpandOp(TColExpandOp op, PatternRewriter &rewriter) { + return lowerTColExpand(op, rewriter); +} + +LogicalResult lowerTRowExpandMulOp(TRowExpandMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandMul(op, rewriter, strategy); +} + +LogicalResult lowerTRowExpandDivOp(TRowExpandDivOp op, PatternRewriter &rewriter, + 
VPTOLoweringStrategy strategy) { + return lowerTRowExpandDiv(op, rewriter, strategy); +} + +LogicalResult lowerTRowExpandSubOp(TRowExpandSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandSub(op, rewriter, strategy); +} + +LogicalResult lowerTPartAddOp(TPartAddOp op, PatternRewriter &rewriter) { + return lowerTPartAdd(op, rewriter); +} + +LogicalResult lowerTPartMaxOp(TPartMaxOp op, PatternRewriter &rewriter) { + return lowerTPartMax(op, rewriter); +} + +LogicalResult lowerTPartMinOp(TPartMinOp op, PatternRewriter &rewriter) { + return lowerTPartMin(op, rewriter); +} + +LogicalResult lowerTExpandSOp(TExpandsOp op, PatternRewriter &rewriter) { + return lowerTExpandS(op, rewriter); +} + +LogicalResult lowerTGatherOp(TGatherOp op, PatternRewriter &rewriter) { + return lowerTGather(op, rewriter); +} + +LogicalResult lowerTGatherBOp(TGatherBOp op, PatternRewriter &rewriter) { + return lowerTGatherB(op, rewriter); +} + +LogicalResult lowerTScatterOp(TScatterOp op, PatternRewriter &rewriter) { + return lowerTScatter(op, rewriter); +} + +LogicalResult lowerTMrgSortOp(TMrgSortOp op, PatternRewriter &rewriter) { + return lowerTMrgSort(op, rewriter); +} + +LogicalResult lowerTSort32Op(TSort32Op op, PatternRewriter &rewriter) { + return lowerTSort32(op, rewriter); +} + +LogicalResult lowerTSTOREOp(TStoreOp op, PatternRewriter &rewriter) { + return lowerTSTORE(op, rewriter); +} + +LogicalResult lowerSetFlagOp(SetFlagOp op, PatternRewriter &rewriter) { + return lowerSetFlag(op, rewriter); +} + +LogicalResult lowerWaitFlagOp(WaitFlagOp op, PatternRewriter &rewriter) { + return lowerWaitFlag(op, rewriter); +} + +LogicalResult lowerBarrierOp(BarrierOp op, PatternRewriter &rewriter) { + return lowerBarrier(op, rewriter); +} + +LogicalResult lowerGetBufOp(GetBufOp op, PatternRewriter &rewriter) { + return lowerGetBuf(op, rewriter); +} + +LogicalResult lowerRlsBufOp(RlsBufOp op, PatternRewriter &rewriter) { + return lowerRlsBuf(op, 
rewriter); +} + +LogicalResult lowerTensorPipelineOp(Operation *op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + rewriter.setInsertionPoint(op); + + LogicalResult lowered = success(); + if (auto tload = dyn_cast(op)) + lowered = lowerTLOADOp(tload, rewriter); + else if (auto tabs = dyn_cast(op)) + lowered = lowerTABSOp(tabs, rewriter, strategy); + else if (auto tadd = dyn_cast(op)) + lowered = lowerTADDOp(tadd, rewriter, strategy); + else if (auto tsub = dyn_cast(op)) + lowered = lowerTSUBOp(tsub, rewriter, strategy); + else if (auto tmul = dyn_cast(op)) + lowered = lowerTMULOp(tmul, rewriter, strategy); + else if (auto tdiv = dyn_cast(op)) + lowered = lowerTDIVOp(tdiv, rewriter, strategy); + else if (auto tmax = dyn_cast(op)) + lowered = lowerTMAXOp(tmax, rewriter, strategy); + else if (auto tmin = dyn_cast(op)) + lowered = lowerTMINOp(tmin, rewriter, strategy); + else if (auto tand = dyn_cast(op)) + lowered = lowerTANDOp(tand, rewriter, strategy); + else if (auto tands = dyn_cast(op)) + lowered = lowerTANDSOp(tands, rewriter); + else if (auto tor = dyn_cast(op)) + lowered = lowerTOROp(tor, rewriter, strategy); + else if (auto tors = dyn_cast(op)) + lowered = lowerTORSOp(tors, rewriter); + else if (auto txor = dyn_cast(op)) + lowered = lowerTXOROp(txor, rewriter, strategy); + else if (auto txors = dyn_cast(op)) + lowered = lowerTXORSOp(txors, rewriter); + else if (auto texp = dyn_cast(op)) + lowered = lowerTEXPOp(texp, rewriter, strategy); + else if (auto tlog = dyn_cast(op)) + lowered = lowerTLOGOp(tlog, rewriter, strategy); + else if (auto tsqrt = dyn_cast(op)) + lowered = lowerTSQRTOp(tsqrt, rewriter, strategy); + else if (auto trsqr = dyn_cast(op)) + lowered = lowerTRSQRTOp(trsqr, rewriter); + else if (auto trecip = dyn_cast(op)) + lowered = lowerTRECIPOp(trecip, rewriter, strategy); + else if (auto tneg = dyn_cast(op)) + lowered = lowerTNEGOp(tneg, rewriter, strategy); + else if (auto tlrelu = dyn_cast(op)) + lowered = lowerTLRELUOp(tlrelu, 
rewriter, strategy); + else if (auto tci = dyn_cast(op)) + lowered = lowerTCIOp(tci, rewriter); + else if (auto tcvt = dyn_cast(op)) + lowered = lowerTCVTOp(tcvt, rewriter); + else if (auto tcmp = dyn_cast(op)) + lowered = lowerTCmpOp(tcmp, rewriter); + else if (auto tcmps = dyn_cast(op)) + lowered = lowerTCmpSOp(tcmps, rewriter); + else if (auto tsel = dyn_cast(op)) + lowered = lowerTSelOp(tsel, rewriter); + else if (auto taddc = dyn_cast(op)) + lowered = lowerTAddCOp(taddc, rewriter); + else if (auto tadds = dyn_cast(op)) + lowered = lowerTAddSOp(tadds, rewriter, strategy); + else if (auto taddsc = dyn_cast(op)) + lowered = lowerTAddSCOp(taddsc, rewriter); + else if (auto tmins = dyn_cast(op)) + lowered = lowerTMinSOp(tmins, rewriter, strategy); + else if (auto tsubc = dyn_cast(op)) + lowered = lowerTSubCOp(tsubc, rewriter); + else if (auto tsubs = dyn_cast(op)) + lowered = lowerTSubSOp(tsubs, rewriter, strategy); + else if (auto tsubsc = dyn_cast(op)) + lowered = lowerTSubSCOp(tsubsc, rewriter); + else if (auto tmaxs = dyn_cast(op)) + lowered = lowerTMaxSOp(tmaxs, rewriter, strategy); + else if (auto tdivs = dyn_cast(op)) + lowered = lowerTDivSOp(tdivs, rewriter, strategy); + else if (auto tmuls = dyn_cast(op)) + lowered = lowerTMulSOp(tmuls, rewriter, strategy); + else if (auto tsels = dyn_cast(op)) + lowered = lowerTSelSOp(tsels, rewriter); + else if (auto trelu = dyn_cast(op)) + lowered = lowerTRELUOp(trelu, rewriter, strategy); + else if (auto tnot = dyn_cast(op)) + lowered = lowerTNOTOp(tnot, rewriter, strategy); + else if (auto ttrans = dyn_cast(op)) + lowered = lowerTTRANSOp(ttrans, rewriter); + else if (auto tfillpad = dyn_cast(op)) + lowered = lowerTFILLPADOp(tfillpad, rewriter); + else if (auto tfillpadExpand = dyn_cast(op)) + lowered = lowerTFILLPADExpandOp(tfillpadExpand, rewriter); + else if (auto trowmax = dyn_cast(op)) + lowered = lowerTRowMaxOp(trowmax, rewriter, strategy); + else if (auto trowmin = dyn_cast(op)) + lowered = 
lowerTRowMinOp(trowmin, rewriter, strategy); + else if (auto trowsum = dyn_cast(op)) + lowered = lowerTRowSumOp(trowsum, rewriter, strategy); + else if (auto tcolmax = dyn_cast(op)) + lowered = lowerTColMaxOp(tcolmax, rewriter); + else if (auto tcolmin = dyn_cast(op)) + lowered = lowerTColMinOp(tcolmin, rewriter); + else if (auto tcolsum = dyn_cast(op)) + lowered = lowerTColSumOp(tcolsum, rewriter); + else if (auto trowexpand = dyn_cast(op)) + lowered = lowerTRowExpandOp(trowexpand, rewriter, strategy); + else if (auto tcolexpand = dyn_cast(op)) + lowered = lowerTColExpandOp(tcolexpand, rewriter); + else if (auto trowexpandmul = dyn_cast(op)) + lowered = lowerTRowExpandMulOp(trowexpandmul, rewriter, strategy); + else if (auto trowexpanddiv = dyn_cast(op)) + lowered = lowerTRowExpandDivOp(trowexpanddiv, rewriter, strategy); + else if (auto trowexpandsub = dyn_cast(op)) + lowered = lowerTRowExpandSubOp(trowexpandsub, rewriter, strategy); + else if (auto tpartadd = dyn_cast(op)) + lowered = lowerTPartAddOp(tpartadd, rewriter); + else if (auto tpartmax = dyn_cast(op)) + lowered = lowerTPartMaxOp(tpartmax, rewriter); + else if (auto tpartmin = dyn_cast(op)) + lowered = lowerTPartMinOp(tpartmin, rewriter); + else if (auto texpands = dyn_cast(op)) + lowered = lowerTExpandSOp(texpands, rewriter); + else if (auto tgather = dyn_cast(op)) + lowered = lowerTGatherOp(tgather, rewriter); + else if (auto tgatherb = dyn_cast(op)) + lowered = lowerTGatherBOp(tgatherb, rewriter); + else if (auto tscatter = dyn_cast(op)) + lowered = lowerTScatterOp(tscatter, rewriter); + else if (auto tmrgsort = dyn_cast(op)) + lowered = lowerTMrgSortOp(tmrgsort, rewriter); + else if (auto tsort32 = dyn_cast(op)) + lowered = lowerTSort32Op(tsort32, rewriter); + else if (auto tstore = dyn_cast(op)) + lowered = lowerTSTOREOp(tstore, rewriter); + else + return success(); + + if (failed(lowered)) + return failure(); + + rewriter.eraseOp(op); + return success(); +} + +LogicalResult 
lowerResidualPTOOp(Operation *op, PatternRewriter &rewriter) { + rewriter.setInsertionPoint(op); + + LogicalResult lowered = success(); + if (auto setFlag = dyn_cast(op)) + lowered = lowerSetFlagOp(setFlag, rewriter); + else if (auto waitFlag = dyn_cast(op)) + lowered = lowerWaitFlagOp(waitFlag, rewriter); + else if (auto barrier = dyn_cast(op)) + lowered = lowerBarrierOp(barrier, rewriter); + else if (auto getBuf = dyn_cast(op)) + lowered = lowerGetBufOp(getBuf, rewriter); + else if (auto rlsBuf = dyn_cast(op)) + lowered = lowerRlsBufOp(rlsBuf, rewriter); + else if (isa(op) && op->use_empty()) + lowered = success(); + else + return success(); + + if (failed(lowered)) + return failure(); + + rewriter.eraseOp(op); + return success(); +} + +struct PTOToVPTOPass : public impl::PTOToVPTOBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOToVPTOPass) + + PTOToVPTOPass() = default; + + explicit PTOToVPTOPass(StringRef loweringStrategy) { + this->loweringStrategy = loweringStrategy.str(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + FailureOr loweringStrategy = + parseVPTOLoweringStrategy(this->loweringStrategy); + if (failed(loweringStrategy)) { + module.emitError() + << "unsupported pto-lowering-strategy: " << this->loweringStrategy + << " (expected post-update or no-post-update)"; + signalPassFailure(); + return; + } + SmallVector tensorPipelineOps; + SmallVector residualPTOOps; + module.walk([&](Operation *op) { + if (isa(op)) + tensorPipelineOps.push_back(op); + else if (isa(op)) + residualPTOOps.push_back(op); + }); + + PatternRewriter rewriter(&getContext()); + bool sawFailure = false; + for (Operation *op : tensorPipelineOps) { + if (!op->getBlock()) + continue; + if (failed(lowerTensorPipelineOp(op, rewriter, *loweringStrategy))) + sawFailure = true; + } + for (Operation *op : residualPTOOps) { + if (!op->getBlock()) + continue; + if (failed(lowerResidualPTOOp(op, rewriter))) + sawFailure = true; + } + + bool 
erasedDeadScaffold = true; + while (erasedDeadScaffold) { + erasedDeadScaffold = false; + SmallVector deadScaffoldOps; + module.walk([&](Operation *op) { + if ((isa(op)) && op->use_empty()) + deadScaffoldOps.push_back(op); + }); + for (Operation *op : deadScaffoldOps) { + if (!op->getBlock()) + continue; + rewriter.setInsertionPoint(op); + rewriter.eraseOp(op); + erasedDeadScaffold = true; + } + } + + // Keep the backend mainline memref-first through PTOToVPTO. Pointer ABI + // bridging belongs to the emission boundary, where text/LLVM emitters can + // materialize the required ptr-only signature on a cloned module. + + if (sawFailure) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr createLowerPTOToVPTOPass() { + return std::make_unique(); +} + +std::unique_ptr createLowerPTOToVPTOPass(StringRef loweringStrategy) { + return std::make_unique(loweringStrategy); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOToVPTOLowering.cpp b/lib/PTO/Transforms/PTOToVPTOLowering.cpp new file mode 100644 index 000000000..89d4f33e7 --- /dev/null +++ b/lib/PTO/Transforms/PTOToVPTOLowering.cpp @@ -0,0 +1,7292 @@ +//===- PTOToVPTOLowering.cpp - PTO to VPTO lowering helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLowering.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/ADT/APFloat.h" + +#include +#include + +namespace mlir { +namespace pto { + +namespace { + +constexpr StringLiteral kLoweredLoopScopeAttrName = "llvm.loop.aivector_scope"; + +struct ResolvedTensorView { + Value root; + Attribute layoutAttr; + SmallVector shape; + SmallVector strides; + OpFoldResult offsetElems; +}; + +struct VecNdTransferPlan { + Value outerCount; + Value outerSrcStrideElems; + Value outerDstStrideElems; + Value loop2Size; + Value loop1Size; + Value loop2FirstStrideBytes; + Value loop2SecondStrideBytes; + Value loop1FirstStrideBytes; + Value loop1SecondStrideBytes; + Value nBurst; + Value lenBurst; + Value firstStrideBytes; + Value secondStrideBytes; +}; + +struct VPTORowReduceContract { + StringRef family; + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef srcLayout; + StringRef dstLayout; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOColReduceContract { + StringRef family; + 
VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef srcLayout; + StringRef dstLayout; + Type elementType; + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + bool isBinary = false; + Value tmp; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOPartContract { + StringRef family; + VPTOTileDomain src0Domain = VPTOTileDomain::Vec; + VPTOTileDomain src1Domain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef src0Layout; + StringRef src1Layout; + StringRef dstLayout; + Type elementType; + Value src0ValidRowsValue; + Value src0ValidColsValue; + Value src1ValidRowsValue; + Value src1ValidColsValue; + Value dstValidRowsValue; + Value dstValidColsValue; + int64_t src0ValidRows = ShapedType::kDynamic; + int64_t src0ValidCols = ShapedType::kDynamic; + int64_t src1ValidRows = ShapedType::kDynamic; + int64_t src1ValidCols = ShapedType::kDynamic; + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + VPTOLoopScopeContract loopScope; +}; + +struct VPTOExpandContract { + StringRef family; + VPTOTileDomain srcDomain = VPTOTileDomain::Vec; + VPTOTileDomain dstDomain = VPTOTileDomain::Vec; + StringRef srcLayout; + StringRef dstLayout; + Type elementType; + Value srcValidRowsValue; + Value srcValidColsValue; + Value dstValidRowsValue; + Value dstValidColsValue; + int64_t srcValidRows = ShapedType::kDynamic; + int64_t srcValidCols = ShapedType::kDynamic; + int64_t dstValidRows = ShapedType::kDynamic; + int64_t dstValidCols = ShapedType::kDynamic; + VPTOLoopScopeContract loopScope; +}; + +StringRef inferVecTransferLayoutFromTile(StringRef explicitLayout, + StringRef tileLayout) { + if (explicitLayout != "nd") + return explicitLayout; + if (tileLayout == "col_major") + return 
"dn"; + return "nd"; +} + +int64_t getElementByteSize(Type type); +Value materializeIndexValue(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc); +Value materializeI64Value(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc); + +LogicalResult emitUnresolvedInstalledA5BaselineError(Operation *op, + StringRef family) { + return op->emitOpError() + << family + << " lowering is intentionally unresolved until the installed A5 PTO " + "helper baseline is located and traced"; +} + +std::optional getConstInt(Value value) { + if (!value) + return std::nullopt; + + if (auto constIndex = value.getDefiningOp()) + return constIndex.value(); + if (auto constInt = value.getDefiningOp()) + return constInt.value(); + if (auto constOp = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +std::optional getConstInt(OpFoldResult value) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + return std::nullopt; + } + return getConstInt(cast(value)); +} + +Value materializeIndexOfr(OpFoldResult value, PatternRewriter &rewriter, + Location loc) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return rewriter.create(loc, intAttr.getInt()); + return {}; + } + Value v = cast(value); + if (v.getType().isIndex()) + return v; + if (isa(v.getType())) + return rewriter.create(loc, rewriter.getIndexType(), v); + return {}; +} + +Value materializeI64Ofr(OpFoldResult value, PatternRewriter &rewriter, + Location loc) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) + return rewriter.create(loc, intAttr.getInt(), 64); + return {}; + } + return materializeI64Value(cast(value), ShapedType::kDynamic, rewriter, loc); +} + +Value materializeIndexBuilder(OpFoldResult value, PatternRewriter &rewriter, Location loc) { + if (auto attr = dyn_cast(value)) { + if (auto 
intAttr = dyn_cast(attr)) + return rewriter.create(loc, intAttr.getInt()); + return {}; + } + Value v = cast(value); + if (v.getType().isIndex()) + return v; + if (isa(v.getType())) + return rewriter.create(loc, rewriter.getIndexType(), v); + return {}; +} + +Value createI64Mul(Value lhs, Value rhs, PatternRewriter &rewriter, Location loc) { + if (!lhs || !rhs) + return {}; + if (std::optional lhsConst = getConstInt(lhs)) { + if (std::optional rhsConst = getConstInt(rhs)) + return rewriter.create(loc, (*lhsConst) * (*rhsConst), 64); + } + return rewriter.create(loc, lhs, rhs); +} + +Value createI64Add(Value lhs, Value rhs, PatternRewriter &rewriter, Location loc) { + if (!lhs || !rhs) + return {}; + if (std::optional lhsConst = getConstInt(lhs)) { + if (std::optional rhsConst = getConstInt(rhs)) + return rewriter.create(loc, (*lhsConst) + (*rhsConst), 64); + } + return rewriter.create(loc, lhs, rhs); +} + +OpFoldResult addOfr(OpFoldResult lhs, OpFoldResult rhs, PatternRewriter &rewriter, + Location loc) { + if (auto lhsConst = getConstInt(lhs)) { + if (auto rhsConst = getConstInt(rhs)) + return rewriter.getIndexAttr((*lhsConst) + (*rhsConst)); + } + Value lhsValue = materializeIndexBuilder(lhs, rewriter, loc); + Value rhsValue = materializeIndexBuilder(rhs, rewriter, loc); + if (!lhsValue || !rhsValue) + return {}; + return rewriter.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult multiplyOfr(OpFoldResult lhs, OpFoldResult rhs, PatternRewriter &rewriter, + Location loc) { + if (auto lhsConst = getConstInt(lhs)) { + if (auto rhsConst = getConstInt(rhs)) + return rewriter.getIndexAttr((*lhsConst) * (*rhsConst)); + } + Value lhsValue = materializeIndexBuilder(lhs, rewriter, loc); + Value rhsValue = materializeIndexBuilder(rhs, rewriter, loc); + if (!lhsValue || !rhsValue) + return {}; + return rewriter.create(loc, lhsValue, rhsValue).getResult(); +} + +bool resolveTensorView(Value value, ResolvedTensorView &info, PatternRewriter &rewriter, + Location 
loc) { + if (!value) + return false; + + if (auto part = value.getDefiningOp()) { + if (!resolveTensorView(part.getSource(), info, rewriter, loc)) + return false; + SmallVector offsets; + offsets.reserve(part.getOffsets().size()); + for (Value offset : part.getOffsets()) + offsets.push_back(offset); + if (offsets.size() != info.strides.size()) + return false; + OpFoldResult totalOffset = info.offsetElems; + for (auto [offset, stride] : llvm::zip(offsets, info.strides)) { + OpFoldResult term = multiplyOfr(offset, stride, rewriter, loc); + if (!term) + return false; + totalOffset = addOfr(totalOffset, term, rewriter, loc); + if (!totalOffset) + return false; + } + info.offsetElems = totalOffset; + info.shape.clear(); + for (Value size : part.getSizes()) + info.shape.push_back(size); + return true; + } + + if (auto source = value.getDefiningOp()) { + info.root = source.getPtr(); + info.layoutAttr = source.getLayoutAttr(); + info.shape.assign(source.getShape().begin(), source.getShape().end()); + info.strides.assign(source.getStrides().begin(), source.getStrides().end()); + info.offsetElems = rewriter.getIndexAttr(0); + return true; + } + + if (auto subview = value.getDefiningOp()) { + ResolvedTensorView parent; + Value source = subview.getSource(); + if (auto reinterpret = source.getDefiningOp()) { + Value root = reinterpret.getSource(); + while (true) { + if (auto cast = root.getDefiningOp()) { + root = cast.getSource(); + continue; + } + break; + } + parent.root = root; + if (Attribute layout = reinterpret->getAttr("layout")) + parent.layoutAttr = layout; + auto parentShapes = + getMixedValues(reinterpret.getStaticSizes(), reinterpret.getSizes(), rewriter); + auto parentStrides = + getMixedValues(reinterpret.getStaticStrides(), reinterpret.getStrides(), rewriter); + auto offsets = + getMixedValues(reinterpret.getStaticOffsets(), reinterpret.getOffsets(), rewriter); + parent.shape.assign(parentShapes.begin(), parentShapes.end()); + 
parent.strides.assign(parentStrides.begin(), parentStrides.end()); + parent.offsetElems = + offsets.empty() ? OpFoldResult(rewriter.getIndexAttr(0)) : offsets.front(); + } else if (!resolveTensorView(source, parent, rewriter, loc)) { + return false; + } + + if (parent.strides.empty()) { + auto sourceType = dyn_cast(source.getType()); + if (!sourceType) + return false; + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(sourceType, strides, offset))) { + strides.assign(sourceType.getRank(), 1); + int64_t running = 1; + for (int i = sourceType.getRank() - 1; i >= 0; --i) { + strides[i] = running; + int64_t dim = sourceType.getShape()[i]; + if (dim != ShapedType::kDynamic) + running *= dim; + } + } + for (int64_t stride : strides) + parent.strides.push_back(rewriter.getIndexAttr(stride == ShapedType::kDynamic ? 1 : stride)); + parent.offsetElems = rewriter.getIndexAttr(offset); + parent.root = source; + } + + info = parent; + if (subview.getMixedOffsets().size() != info.strides.size()) + return false; + + OpFoldResult totalOffset = info.offsetElems; + for (auto [offset, stride] : llvm::zip(subview.getMixedOffsets(), info.strides)) { + OpFoldResult term = multiplyOfr(offset, stride, rewriter, loc); + if (!term) + return false; + totalOffset = addOfr(totalOffset, term, rewriter, loc); + if (!totalOffset) + return false; + } + + SmallVector newStrides; + newStrides.reserve(info.strides.size()); + for (auto [srcStride, step] : llvm::zip(info.strides, subview.getMixedStrides())) { + OpFoldResult product = multiplyOfr(srcStride, step, rewriter, loc); + if (!product) + return false; + newStrides.push_back(product); + } + + info.offsetElems = totalOffset; + info.shape.assign(subview.getMixedSizes().begin(), subview.getMixedSizes().end()); + info.strides = std::move(newStrides); + return true; + } + + if (auto reinterpret = value.getDefiningOp()) { + Value root = reinterpret.getSource(); + while (true) { + if (auto cast = root.getDefiningOp()) { + 
root = cast.getSource(); + continue; + } + if (auto unrealized = root.getDefiningOp()) { + if (!unrealized.getInputs().empty()) { + root = unrealized.getInputs().front(); + continue; + } + } + break; + } + info.root = root; + if (Attribute layout = reinterpret->getAttr("layout")) + info.layoutAttr = layout; + auto reinterpretShapes = + getMixedValues(reinterpret.getStaticSizes(), reinterpret.getSizes(), rewriter); + auto reinterpretStrides = + getMixedValues(reinterpret.getStaticStrides(), reinterpret.getStrides(), rewriter); + auto offsets = + getMixedValues(reinterpret.getStaticOffsets(), reinterpret.getOffsets(), rewriter); + info.shape.assign(reinterpretShapes.begin(), reinterpretShapes.end()); + info.strides.assign(reinterpretStrides.begin(), reinterpretStrides.end()); + if (!offsets.empty()) { + if (offsets.size() != 1) + return false; + info.offsetElems = offsets.front(); + } else { + info.offsetElems = rewriter.getIndexAttr(0); + } + return true; + } + + if (auto cast = value.getDefiningOp()) + return resolveTensorView(cast.getSource(), info, rewriter, loc); + + if (auto memrefType = dyn_cast(value.getType())) { + info.root = value; + info.shape.clear(); + for (int64_t dim : memrefType.getShape()) + info.shape.push_back(rewriter.getIndexAttr(dim == ShapedType::kDynamic ? 1 : dim)); + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + strides.assign(memrefType.getRank(), 1); + int64_t running = 1; + for (int i = memrefType.getRank() - 1; i >= 0; --i) { + strides[i] = running; + int64_t dim = memrefType.getShape()[i]; + if (dim != ShapedType::kDynamic) + running *= dim; + } + offset = 0; + } + info.strides.clear(); + for (int64_t stride : strides) + info.strides.push_back(rewriter.getIndexAttr(stride == ShapedType::kDynamic ? 
1 : stride)); + info.offsetElems = rewriter.getIndexAttr(offset); + return true; + } + + return false; +} + +void normalizeMixedGlobalShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &globalShape, + SmallVectorImpl &globalStride, + PatternRewriter &rewriter, Location loc) { + constexpr int64_t kRank = 5; + globalShape.assign(kRank, rewriter.getIndexAttr(1)); + globalStride.assign(kRank, rewriter.getIndexAttr(1)); + + size_t rank = std::min(shape.size(), strides.size()); + rank = std::min(rank, kRank); + size_t base = kRank - rank; + for (size_t i = 0; i < rank; ++i) { + globalShape[base + i] = shape[shape.size() - rank + i]; + globalStride[base + i] = strides[strides.size() - rank + i]; + } + + for (int i = static_cast(kRank) - 2; i >= 0; --i) { + if (i >= static_cast(base)) + continue; + OpFoldResult product = multiplyOfr(globalStride[i + 1], globalShape[i + 1], rewriter, loc); + if (!product) + product = rewriter.getIndexAttr(ShapedType::kDynamic); + globalStride[i] = product; + } +} + +Value adjustPointerByElemOffset(Value ptr, Value elemOffsetI64, int64_t elemBytes, + PatternRewriter &rewriter, Location loc) { + if (!ptr || !elemOffsetI64 || elemBytes <= 0) + return {}; + + Value offset = elemOffsetI64.getType().isIndex() + ? rewriter.create( + loc, rewriter.getI64Type(), elemOffsetI64) + : elemOffsetI64; + Value byteOffset = offset; + if (elemBytes != 1) { + Value elemBytesValue = rewriter.create(loc, elemBytes, 64); + byteOffset = createI64Mul(offset, elemBytesValue, rewriter, loc); + } + if (auto ptrType = dyn_cast(ptr.getType())) { + auto bytePtrType = PtrType::get(rewriter.getContext(), rewriter.getI8Type(), + ptrType.getMemorySpace()); + Value bytePtr = ptrType == bytePtrType + ? ptr + : rewriter.create(loc, bytePtrType, ptr).getResult(); + Value byteOffsetIndex = + byteOffset.getType().isIndex() + ? 
byteOffset + : rewriter.create(loc, rewriter.getIndexType(), + byteOffset); + return rewriter.create(loc, bytePtrType, bytePtr, byteOffsetIndex); + } + return {}; +} + +Value castPtrToElementType(Value ptr, Type elementType, PatternRewriter &rewriter, + Location loc) { + auto ptrType = dyn_cast_or_null(ptr.getType()); + if (!ptrType || !elementType) + return {}; + auto targetType = + PtrType::get(rewriter.getContext(), elementType, ptrType.getMemorySpace()); + if (targetType == ptrType) + return ptr; + return rewriter.create(loc, targetType, ptr).getResult(); +} + +Type getCopyTransferElementType(Type elementType, Builder &builder) { + if (getElementByteSize(elementType) == 8) + return builder.getI32Type(); + return elementType; +} + +LogicalResult buildVecNdLoadPlan(ArrayRef shape, + ArrayRef strides, int64_t tileCols, + Value validColsValue, int64_t validCols, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (tileCols == ShapedType::kDynamic) + return failure(); + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + SmallVector globalShape; + SmallVector globalStride; + normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, rewriter, loc); + + auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; + Value gShape0 = toI64(globalShape[0]); + Value gShape1 = toI64(globalShape[1]); + Value gShape2 = toI64(globalShape[2]); + Value gShape3 = toI64(globalShape[3]); + Value gStride0 = toI64(globalStride[0]); + Value gStride1 = toI64(globalStride[1]); + Value gStride2 = toI64(globalStride[2]); + Value gStride3 = toI64(globalStride[3]); + Value validColsI64 = materializeI64Value(validColsValue, validCols, rewriter, loc); + if (!gShape0 || !gShape1 || !gShape2 || !gShape3 || !gStride0 || !gStride1 || + !gStride2 || !gStride3 || !validColsI64) + return failure(); + + Value tileColsI64 = rewriter.create(loc, tileCols, 64); + Value 
elemBytesI64 = rewriter.create(loc, elemBytes, 64); + Value dstStride2 = createI64Mul(gShape3, tileColsI64, rewriter, loc); + Value dstStride1 = createI64Mul(gShape2, dstStride2, rewriter, loc); + Value dstStride0 = createI64Mul(gShape1, dstStride1, rewriter, loc); + + plan.outerCount = gShape0; + plan.outerSrcStrideElems = gStride0; + plan.outerDstStrideElems = dstStride0; + plan.loop2Size = gShape1; + plan.loop1Size = gShape2; + plan.loop2FirstStrideBytes = createI64Mul(dstStride1, elemBytesI64, rewriter, loc); + plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); + plan.loop1FirstStrideBytes = createI64Mul(dstStride2, elemBytesI64, rewriter, loc); + plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); + plan.nBurst = gShape3; + plan.lenBurst = createI64Mul(validColsI64, elemBytesI64, rewriter, loc); + plan.firstStrideBytes = createI64Mul(gStride3, elemBytesI64, rewriter, loc); + plan.secondStrideBytes = createI64Mul(tileColsI64, elemBytesI64, rewriter, loc); + return success(); +} + +LogicalResult buildVecDnLoadPlan(ArrayRef shape, + ArrayRef strides, int64_t tileRows, + Value validRowsValue, int64_t validRows, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (tileRows == ShapedType::kDynamic) + return failure(); + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + SmallVector globalShape; + SmallVector globalStride; + normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, + rewriter, loc); + + auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; + Value gShape0 = toI64(globalShape[0]); + Value gShape1 = toI64(globalShape[1]); + Value gShape2 = toI64(globalShape[2]); + Value gShape4 = toI64(globalShape[4]); + Value gStride0 = toI64(globalStride[0]); + Value gStride1 = toI64(globalStride[1]); + Value gStride2 = toI64(globalStride[2]); + Value gStride4 = 
toI64(globalStride[4]); + Value validRowsI64 = materializeI64Value(validRowsValue, validRows, rewriter, loc); + if (!gShape0 || !gShape1 || !gShape2 || !gShape4 || !gStride0 || !gStride1 || + !gStride2 || !gStride4 || !validRowsI64) + return failure(); + + Value tileRowsI64 = rewriter.create(loc, tileRows, 64); + Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); + Value dstStride2 = createI64Mul(gShape4, tileRowsI64, rewriter, loc); + Value dstStride1 = createI64Mul(gShape2, dstStride2, rewriter, loc); + Value dstStride0 = createI64Mul(gShape1, dstStride1, rewriter, loc); + + plan.outerCount = gShape0; + plan.outerSrcStrideElems = gStride0; + plan.outerDstStrideElems = dstStride0; + plan.loop2Size = gShape1; + plan.loop1Size = gShape2; + plan.loop2FirstStrideBytes = createI64Mul(dstStride1, elemBytesI64, rewriter, loc); + plan.loop2SecondStrideBytes = createI64Mul(gStride1, elemBytesI64, rewriter, loc); + plan.loop1FirstStrideBytes = createI64Mul(dstStride2, elemBytesI64, rewriter, loc); + plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); + plan.nBurst = gShape4; + plan.lenBurst = createI64Mul(validRowsI64, elemBytesI64, rewriter, loc); + plan.firstStrideBytes = createI64Mul(gStride4, elemBytesI64, rewriter, loc); + plan.secondStrideBytes = createI64Mul(tileRowsI64, elemBytesI64, rewriter, loc); + return success(); +} + +LogicalResult buildVecNdStorePlan(ArrayRef shape, + ArrayRef strides, int64_t tileCols, + Value validColsValue, int64_t validCols, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (failed(buildVecNdLoadPlan(shape, strides, tileCols, validColsValue, validCols, + elementType, rewriter, loc, plan))) + return failure(); + std::swap(plan.outerSrcStrideElems, plan.outerDstStrideElems); + std::swap(plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); + std::swap(plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); + return success(); +} + +LogicalResult 
buildVecDnStorePlan(ArrayRef shape, + ArrayRef strides, int64_t tileRows, + Value validRowsValue, int64_t validRows, + Type elementType, PatternRewriter &rewriter, + Location loc, VecNdTransferPlan &plan) { + if (tileRows == ShapedType::kDynamic) + return failure(); + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + SmallVector globalShape; + SmallVector globalStride; + normalizeMixedGlobalShapeAndStride(shape, strides, globalShape, globalStride, + rewriter, loc); + + auto toI64 = [&](OpFoldResult ofr) { return materializeI64Ofr(ofr, rewriter, loc); }; + Value gShape0 = toI64(globalShape[0]); + Value gShape1 = toI64(globalShape[1]); + Value gShape2 = toI64(globalShape[2]); + Value gShape4 = toI64(globalShape[4]); + Value gStride0 = toI64(globalStride[0]); + Value gStride1 = toI64(globalStride[1]); + Value gStride2 = toI64(globalStride[2]); + Value gStride4 = toI64(globalStride[4]); + Value validRowsI64 = materializeI64Value(validRowsValue, validRows, rewriter, loc); + if (!gShape0 || !gShape1 || !gShape2 || !gShape4 || !gStride0 || !gStride1 || + !gStride2 || !gStride4 || !validRowsI64) + return failure(); + + Value tileRowsI64 = rewriter.create(loc, tileRows, 64); + Value elemBytesI64 = rewriter.create(loc, elemBytes, 64); + Value outerSrcStride = + createI64Mul(createI64Mul(createI64Mul(gShape1, gShape2, rewriter, loc), + gShape4, rewriter, loc), + tileRowsI64, rewriter, loc); + Value loop1SrcStride = + createI64Mul(createI64Mul(tileRowsI64, gShape4, rewriter, loc), elemBytesI64, + rewriter, loc); + Value loop2SrcStride = + createI64Mul(createI64Mul(createI64Mul(gShape2, tileRowsI64, rewriter, loc), + gShape4, rewriter, loc), + elemBytesI64, rewriter, loc); + + plan.outerCount = gShape0; + plan.outerSrcStrideElems = outerSrcStride; + plan.outerDstStrideElems = gStride0; + plan.loop2Size = gShape1; + plan.loop1Size = gShape2; + plan.loop2FirstStrideBytes = loop2SrcStride; + plan.loop2SecondStrideBytes = 
createI64Mul(gStride1, elemBytesI64, rewriter, loc); + plan.loop1FirstStrideBytes = loop1SrcStride; + plan.loop1SecondStrideBytes = createI64Mul(gStride2, elemBytesI64, rewriter, loc); + plan.nBurst = gShape4; + plan.lenBurst = createI64Mul(validRowsI64, elemBytesI64, rewriter, loc); + plan.firstStrideBytes = createI64Mul(gStride4, elemBytesI64, rewriter, loc); + plan.secondStrideBytes = createI64Mul(tileRowsI64, elemBytesI64, rewriter, loc); + return success(); +} + +StringRef stringifyTileLayout(TileBufType type) { + if (auto layoutAttr = dyn_cast_or_null(type.getBLayoutAttr())) { + switch (layoutAttr.getValue()) { + case BLayout::RowMajor: + return "row_major"; + case BLayout::ColMajor: + return "col_major"; + } + } + return "row_major"; +} + +StringRef stringifyTileLayoutConfig(TileBufConfigAttr config) { + if (!config) + return "row_major"; + if (auto layoutAttr = dyn_cast_or_null(config.getBLayout())) { + switch (layoutAttr.getValue()) { + case BLayout::RowMajor: + return "row_major"; + case BLayout::ColMajor: + return "col_major"; + } + } + return "row_major"; +} + +StringRef stringifyPadModeAttr(PadModeAttr padMode) { + if (!padMode) + return "none"; + + switch (padMode.getPadmode()) { + case PadMode::PadNull: + return "none"; + case PadMode::PadFirstElem: + return "first_elem"; + case PadMode::PadValue: + return "value"; + } + return "none"; +} + +StringRef stringifyLayoutAttr(Attribute layoutAttr) { + if (auto attr = dyn_cast_or_null(layoutAttr)) + return stringifyLayout(attr.getLayout()); + return "nd"; +} + +PipeAttr stringifyPipeAttr(PipeAttr pipe, PatternRewriter &rewriter) { + return PipeAttr::get(rewriter.getContext(), pipe.getPipe()); +} + +EventAttr stringifyEventAttr(EventAttr event, PatternRewriter &rewriter) { + return EventAttr::get(rewriter.getContext(), event.getEvent()); +} + +StringRef stringifyCmpModeAttr(CmpModeAttr cmpMode) { + if (!cmpMode) + return "eq"; + switch (cmpMode.getValue()) { + case CmpMode::EQ: + return "eq"; + case 
CmpMode::NE: + return "ne"; + case CmpMode::LT: + return "lt"; + case CmpMode::LE: + return "le"; + case CmpMode::GT: + return "gt"; + case CmpMode::GE: + return "ge"; + } + return "eq"; +} + +StringRef stringifyElementTypeFragment(Type type) { + if (!type) + return "unknown"; + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isUnsigned()) + switch (intType.getWidth()) { + case 8: + return "u8"; + case 16: + return "u16"; + case 32: + return "u32"; + case 64: + return "u64"; + default: + break; + } + switch (intType.getWidth()) { + case 8: + return "s8"; + case 16: + return "s16"; + case 32: + return "s32"; + case 64: + return "s64"; + default: + break; + } + } + return "unknown"; +} + +StringRef stringifyCopyTransferTypeFragment(Type type) { + switch (getElementByteSize(type)) { + case 1: + return "u8"; + case 2: + return "u16"; + case 4: + case 8: + return "u32"; + default: + return stringifyElementTypeFragment(type); + } +} + +static bool isSupportedPackedCmp32ElementType(Type type) { + if (!type) + return false; + if (type.isF32()) + return true; + auto intType = dyn_cast(type); + return intType && intType.getWidth() == 32; +} + +VPTOTileDomain deriveTileDomain(Attribute memorySpace) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) { + switch (addrSpace.getAddressSpace()) { + case AddressSpace::ACC: + return VPTOTileDomain::Acc; + case AddressSpace::MAT: + return VPTOTileDomain::Mat; + case AddressSpace::VEC: + default: + return VPTOTileDomain::Vec; + } + } + if (auto intAttr = dyn_cast_or_null(memorySpace)) { + switch (intAttr.getInt()) { + case static_cast(AddressSpace::ACC): + return VPTOTileDomain::Acc; + case static_cast(AddressSpace::MAT): + return VPTOTileDomain::Mat; + default: + return VPTOTileDomain::Vec; + } + } + return VPTOTileDomain::Vec; +} + +void getValidShape(TileBufType type, int64_t &rows, int64_t &cols) { + ArrayRef 
validShape = type.getValidShape(); + rows = validShape.size() > 0 ? validShape[0] : ShapedType::kDynamic; + cols = validShape.size() > 1 ? validShape[1] : ShapedType::kDynamic; +} + +static std::pair getIfResultYieldedValues(Value value) { + auto result = dyn_cast(value); + if (!result) + return {Value(), Value()}; + auto ifOp = dyn_cast(result.getOwner()); + if (!ifOp) + return {Value(), Value()}; + unsigned resultNumber = result.getResultNumber(); + auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); + auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); + if (!thenYield || !elseYield) + return {Value(), Value()}; + if (resultNumber >= thenYield.getNumOperands() || + resultNumber >= elseYield.getNumOperands()) + return {Value(), Value()}; + return {thenYield.getOperand(resultNumber), elseYield.getOperand(resultNumber)}; +} + +static bool equalOrBothNull(Value lhs, Value rhs) { + if (!lhs && !rhs) + return true; + if (!lhs || !rhs) + return false; + if (lhs == rhs) + return true; + auto lhsConst = getConstInt(lhs); + auto rhsConst = getConstInt(rhs); + return lhsConst && rhsConst && *lhsConst == *rhsConst; +} + +TileBufConfigAttr lookupTileConfig(Value value) { + if (!value) + return {}; + if (auto bind = value.getDefiningOp()) + return bind.getConfig(); + if (auto cast = value.getDefiningOp()) + return cast.getConfig().value_or(TileBufConfigAttr{}); + if (auto subview = value.getDefiningOp()) + return lookupTileConfig(subview.getSource()); + if (auto reinterpret = value.getDefiningOp()) + return lookupTileConfig(reinterpret.getSource()); + if (auto cast = value.getDefiningOp()) + return lookupTileConfig(cast.getSource()); + if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); + thenValue && elseValue) { + TileBufConfigAttr thenConfig = lookupTileConfig(thenValue); + TileBufConfigAttr elseConfig = lookupTileConfig(elseValue); + if (thenConfig && elseConfig && thenConfig == elseConfig) + return thenConfig; + } + return {}; +} + 
+bool hasStructuredTileDriver(Value value) { + if (!value) + return false; + if (isa(value.getType())) + return true; + if (value.getDefiningOp()) + return true; + if (auto subview = value.getDefiningOp()) + return hasStructuredTileDriver(subview.getSource()); + if (auto reinterpret = value.getDefiningOp()) + return hasStructuredTileDriver(reinterpret.getSource()); + if (auto cast = value.getDefiningOp()) + return hasStructuredTileDriver(cast.getSource()); + if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); + thenValue && elseValue) { + return hasStructuredTileDriver(thenValue) && hasStructuredTileDriver(elseValue); + } + return false; +} + +void lookupValidDims(Value value, Value &validRow, Value &validCol) { + if (!value) { + validRow = {}; + validCol = {}; + return; + } + if (auto bind = value.getDefiningOp()) { + validRow = bind.getValidRow(); + validCol = bind.getValidCol(); + return; + } + if (auto cast = value.getDefiningOp()) { + validRow = cast.getValidRow(); + validCol = cast.getValidCol(); + return; + } + if (auto subview = value.getDefiningOp()) { + lookupValidDims(subview.getSource(), validRow, validCol); + return; + } + if (auto reinterpret = value.getDefiningOp()) { + lookupValidDims(reinterpret.getSource(), validRow, validCol); + return; + } + if (auto cast = value.getDefiningOp()) { + lookupValidDims(cast.getSource(), validRow, validCol); + return; + } + if (auto [thenValue, elseValue] = getIfResultYieldedValues(value); + thenValue && elseValue) { + Value thenRow; + Value thenCol; + Value elseRow; + Value elseCol; + lookupValidDims(thenValue, thenRow, thenCol); + lookupValidDims(elseValue, elseRow, elseCol); + validRow = equalOrBothNull(thenRow, elseRow) ? thenRow : Value(); + validCol = equalOrBothNull(thenCol, elseCol) ? 
thenCol : Value(); + return; + } + validRow = {}; + validCol = {}; +} + +Type getElementType(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + return {}; +} + +Attribute getMemorySpace(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getMemorySpace(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getMemorySpace(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); + return {}; +} + +StringRef deriveTileLayout(Value value) { + if (auto tileType = dyn_cast(value.getType())) + return stringifyTileLayout(tileType); + return stringifyTileLayoutConfig(lookupTileConfig(value)); +} + +void deriveValidShape(Value value, int64_t &rows, int64_t &cols) { + if (auto tileType = dyn_cast(value.getType())) { + getValidShape(tileType, rows, cols); + return; + } + + Value validRow; + Value validCol; + lookupValidDims(value, validRow, validCol); + rows = getConstInt(validRow).value_or(ShapedType::kDynamic); + cols = getConstInt(validCol).value_or(ShapedType::kDynamic); + if (rows != ShapedType::kDynamic && cols != ShapedType::kDynamic) + return; + if (!hasStructuredTileDriver(value)) + return; + + auto shapedType = dyn_cast(value.getType()); + if (!shapedType || !shapedType.hasRank()) + return; + + ArrayRef shape = shapedType.getShape(); + if (shape.empty()) { + if (rows == ShapedType::kDynamic) + rows = 1; + if (cols == ShapedType::kDynamic) + cols = 1; + return; + } + if (shape.size() == 1) { + if (rows == ShapedType::kDynamic) + rows = 1; + if (cols == ShapedType::kDynamic) + cols = shape.front(); + return; + } + + if (cols == ShapedType::kDynamic) + cols = shape.back(); + if (rows == ShapedType::kDynamic) { + int64_t flatRows = 1; + for (int64_t dim : shape.drop_back()) { + 
if (dim == ShapedType::kDynamic) { + flatRows = ShapedType::kDynamic; + break; + } + flatRows *= dim; + } + rows = flatRows; + } +} + +void deriveValidShapeValues(Value value, Value &rows, Value &cols) { + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef validShape = tileType.getValidShape(); + rows = {}; + cols = {}; + (void)validShape; + lookupValidDims(value, rows, cols); + return; + } + lookupValidDims(value, rows, cols); +} + +void appendStaticSizes(ValueRange values, SmallVectorImpl &out, + bool &hasDynamic) { + out.clear(); + hasDynamic = false; + out.reserve(values.size()); + for (Value value : values) { + if (std::optional constant = getConstInt(value)) { + out.push_back(*constant); + continue; + } + out.push_back(ShapedType::kDynamic); + hasDynamic = true; + } +} + +int64_t getElementByteSize(Type type) { + if (auto floatType = dyn_cast(type)) + return (floatType.getWidth() + 7) / 8; + if (auto intType = dyn_cast(type)) + return (intType.getWidth() + 7) / 8; + return 0; +} + +Value materializeIndexValue(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc) { + if (maybeValue) + return maybeValue; + if (fallback != ShapedType::kDynamic) + return rewriter.create(loc, fallback); + return {}; +} + +Value materializeI64Value(Value maybeValue, int64_t fallback, + PatternRewriter &rewriter, Location loc) { + if (maybeValue) { + Type type = maybeValue.getType(); + if (type.isIndex()) + return rewriter.create(loc, rewriter.getI64Type(), maybeValue); + if (type.isInteger(64)) + return maybeValue; + if (auto intType = dyn_cast(type)) + return rewriter.create(loc, rewriter.getI64Type(), maybeValue); + } + if (fallback != ShapedType::kDynamic) + return rewriter.create(loc, fallback, 64); + return {}; +} + +void recordStaticValues(ValueRange values, SmallVectorImpl &out) { + out.clear(); + out.reserve(values.size()); + for (Value value : values) + out.push_back(getConstInt(value).value_or(ShapedType::kDynamic)); +} + +void 
recordStaticSizes(ArrayRef values, + SmallVectorImpl &out, bool &hasDynamic) { + out.clear(); + hasDynamic = false; + out.reserve(values.size()); + for (OpFoldResult value : values) { + if (auto attr = dyn_cast(value)) { + if (auto intAttr = dyn_cast(attr)) { + out.push_back(intAttr.getInt()); + continue; + } + } else if (std::optional constant = + getConstInt(cast(value))) { + out.push_back(*constant); + continue; + } + out.push_back(ShapedType::kDynamic); + hasDynamic = true; + } +} + +void mergeSubviewTrace(VPTOPartitionTrace &trace, ArrayRef offsets, + ArrayRef sizes, bool hasDynamicOffsets, + bool hasDynamicSizes) { + if (trace.offsets.empty()) { + trace.offsets.assign(offsets.begin(), offsets.end()); + trace.hasDynamicOffsets = hasDynamicOffsets; + } else { + size_t count = std::min(trace.offsets.size(), offsets.size()); + for (size_t i = 0; i < count; ++i) { + if (trace.offsets[i] == ShapedType::kDynamic || + offsets[i] == ShapedType::kDynamic) { + trace.offsets[i] = ShapedType::kDynamic; + trace.hasDynamicOffsets = true; + continue; + } + trace.offsets[i] += offsets[i]; + } + trace.hasDynamicOffsets = trace.hasDynamicOffsets || hasDynamicOffsets; + } + + trace.sizes.assign(sizes.begin(), sizes.end()); + trace.hasDynamicSizes = hasDynamicSizes; +} + +Value resolveTensorViewBase(Value value, Attribute &layoutAttr, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + if (!value) + return {}; + + if (auto part = value.getDefiningOp()) { + return resolveTensorViewBase(part.getSource(), layoutAttr, shape, strides); + } + + if (auto source = value.getDefiningOp()) { + layoutAttr = source.getLayoutAttr(); + auto tensorType = dyn_cast(source.getResult().getType()); + shape.assign(tensorType.getShape().begin(), tensorType.getShape().end()); + recordStaticValues(source.getStrides(), strides); + return source.getPtr(); + } + + if (auto subview = value.getDefiningOp()) { + Value base = + resolveTensorViewBase(subview.getSource(), layoutAttr, shape, strides); + if 
(shape.empty()) { + bool hasDynamicSizes = false; + recordStaticSizes(subview.getMixedSizes(), shape, hasDynamicSizes); + } + return base ? base : value; + } + + if (auto reinterpret = value.getDefiningOp()) { + if (Attribute layout = reinterpret->getAttr("layout")) + layoutAttr = layout; + if (shape.empty()) { + bool hasDynamicSizes = false; + recordStaticSizes(reinterpret.getMixedSizes(), shape, hasDynamicSizes); + } + if (strides.empty()) { + bool hasDynamicStrides = false; + recordStaticSizes(reinterpret.getMixedStrides(), strides, + hasDynamicStrides); + } + Value base = + resolveTensorViewBase(reinterpret.getSource(), layoutAttr, shape, strides); + return base ? base : value; + } + + if (auto cast = value.getDefiningOp()) { + Value base = + resolveTensorViewBase(cast.getSource(), layoutAttr, shape, strides); + return base ? base : value; + } + + if (auto memrefType = dyn_cast(value.getType())) { + if (shape.empty()) + shape.assign(memrefType.getShape().begin(), memrefType.getShape().end()); + if (strides.empty()) { + int64_t offset = 0; + if (failed(mlir::getStridesAndOffset(memrefType, strides, offset))) + strides.assign(shape.size(), ShapedType::kDynamic); + } + return value; + } + + return {}; +} + +pto::VRegType getVPTOVRegType(MLIRContext *context, Type elementType) { + unsigned bitWidth = 0; + if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + else if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + + if (bitWidth == 0 || 2048 % bitWidth != 0) + return {}; + return pto::VRegType::get(context, 2048 / bitWidth, elementType); +} + +pto::MaskType getVPTOMaskType(MLIRContext *context, StringRef granularity) { + return pto::MaskType::get(context, granularity); +} + +pto::MaskType getVPTOMaskTypeForElementType(MLIRContext *context, + Type elementType) { + unsigned bitWidth = 0; + if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + else if (auto intType = dyn_cast(elementType)) 
+ bitWidth = intType.getWidth(); + + switch (bitWidth) { + case 8: + return getVPTOMaskType(context, "b8"); + case 16: + return getVPTOMaskType(context, "b16"); + case 32: + return getVPTOMaskType(context, "b32"); + default: + return {}; + } +} + +ArrayAttr asI64ArrayAttr(Builder &builder, ArrayRef values) { + SmallVector attrs; + attrs.reserve(values.size()); + for (int64_t value : values) + attrs.push_back(builder.getI64IntegerAttr(value)); + return builder.getArrayAttr(attrs); +} + +void normalizeToPTOGlobalShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &globalShape, + SmallVectorImpl &globalStride) { + constexpr int64_t kRank = 5; + globalShape.assign(kRank, 1); + globalStride.assign(kRank, 1); + + size_t shapeRank = std::min(shape.size(), kRank); + size_t strideRank = std::min(strides.size(), kRank); + size_t rank = std::min(shapeRank, strideRank); + size_t base = kRank - rank; + + for (size_t i = 0; i < rank; ++i) { + globalShape[base + i] = shape[shape.size() - rank + i]; + globalStride[base + i] = strides[strides.size() - rank + i]; + } + + for (int i = static_cast(kRank) - 2; i >= 0; --i) { + if (i >= static_cast(base)) + continue; + if (globalStride[i + 1] == ShapedType::kDynamic || + globalShape[i + 1] == ShapedType::kDynamic) { + globalStride[i] = ShapedType::kDynamic; + continue; + } + globalStride[i] = globalStride[i + 1] * globalShape[i + 1]; + } +} + +int64_t packLoopStrideConfig(int64_t first, int64_t second) { + return (static_cast(first) << 40) | static_cast(second); +} + +int64_t packLoopSizeConfig(int64_t loop2, int64_t loop1) { + return (static_cast(loop2) << 21) | static_cast(loop1); +} + +LogicalResult deriveVecNDTransferConfig(ArrayRef shape, + ArrayRef strides, + StringRef tileLayout, Type elementType, + int64_t validRows, int64_t validCols, + SmallVectorImpl &globalShape, + SmallVectorImpl &globalStride, + int64_t &nBurst, int64_t &lenBurst, + int64_t &gmStrideBytes, + int64_t &ubStrideBytes, + int64_t &loop1Size, + 
int64_t &loop2Size, + int64_t &loop1FirstStrideBytes, + int64_t &loop1SecondStrideBytes, + int64_t &loop2FirstStrideBytes, + int64_t &loop2SecondStrideBytes) { + if (tileLayout != "row_major") + return failure(); + + int64_t elemBytes = getElementByteSize(elementType); + if (elemBytes <= 0) + return failure(); + + normalizeToPTOGlobalShapeAndStride(shape, strides, globalShape, globalStride); + if (globalShape.size() != 5 || globalStride.size() != 5) + return failure(); + if (llvm::any_of(globalShape, [](int64_t v) { return v == ShapedType::kDynamic; }) || + llvm::any_of(globalStride, [](int64_t v) { return v == ShapedType::kDynamic; })) + return failure(); + nBurst = globalShape[3]; + lenBurst = (validCols == ShapedType::kDynamic) ? ShapedType::kDynamic + : validCols * elemBytes; + gmStrideBytes = globalStride[3] * elemBytes; + ubStrideBytes = globalShape[4] * elemBytes; + + int64_t dstStride2 = globalShape[3] * validCols; + int64_t dstStride1 = globalShape[2] * dstStride2; + + loop2Size = globalShape[1]; + loop1Size = globalShape[2]; + loop2FirstStrideBytes = dstStride1 * elemBytes; + loop2SecondStrideBytes = globalStride[1] * elemBytes; + loop1FirstStrideBytes = dstStride2 * elemBytes; + loop1SecondStrideBytes = globalStride[2] * elemBytes; + return success(); +} + +std::pair getStaticTileRowsCols(Value value) { + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (shape.size() >= 2) + return {shape[shape.size() - 2], shape[shape.size() - 1]}; + } + return {ShapedType::kDynamic, ShapedType::kDynamic}; +} + +Value materializeStaticOrDynamicDimAsIndex(Value value, int64_t dim, + unsigned dimPos, + PatternRewriter &rewriter, + Location loc) { + if (dim != ShapedType::kDynamic) + return rewriter.create(loc, dim); + if (isa(value.getType())) + return rewriter.create(loc, value, dimPos); + return {}; +} + +LogicalResult materializeShapeBackedValidShapeValues(Value value, Value &rows, + Value &cols, + PatternRewriter 
&rewriter, + Location loc) { + rows = {}; + cols = {}; + + auto shapedType = dyn_cast(value.getType()); + if (!shapedType || !shapedType.hasRank() || !hasStructuredTileDriver(value)) + return failure(); + + ArrayRef shape = shapedType.getShape(); + if (shape.empty()) { + rows = rewriter.create(loc, 1); + cols = rewriter.create(loc, 1); + return success(); + } + if (shape.size() == 1) { + rows = rewriter.create(loc, 1); + cols = materializeStaticOrDynamicDimAsIndex(value, shape.front(), 0, rewriter, loc); + return success(cols != nullptr); + } + + cols = materializeStaticOrDynamicDimAsIndex(value, shape.back(), shape.size() - 1, + rewriter, loc); + if (!cols) + return failure(); + + Value flatRows = rewriter.create(loc, 1); + for (auto [idx, dim] : llvm::enumerate(shape.drop_back())) { + Value dimValue = + materializeStaticOrDynamicDimAsIndex(value, dim, idx, rewriter, loc); + if (!dimValue) + return failure(); + flatRows = rewriter.create(loc, flatRows, dimValue); + } + rows = flatRows; + return success(); +} + +LogicalResult resolveExecutionValidShape(Value carrier, Value &rowsValue, + Value &colsValue, int64_t &rows, + int64_t &cols, + PatternRewriter &rewriter, + Location loc) { + rowsValue = materializeIndexValue(rowsValue, rows, rewriter, loc); + colsValue = materializeIndexValue(colsValue, cols, rewriter, loc); + if (rowsValue && colsValue) + return success(); + + if (succeeded(materializeShapeBackedValidShapeValues(carrier, rowsValue, colsValue, + rewriter, loc))) { + deriveValidShape(carrier, rows, cols); + return success(rowsValue && colsValue); + } + return failure(); +} + +Attribute getGmMemorySpace(MLIRContext *context) { + return AddressSpaceAttr::get(context, AddressSpace::GM); +} + +AddressSpaceAttr getNormalizedPtrMemorySpace(Attribute memorySpace, + MLIRContext *context) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return AddressSpaceAttr::get(context, + 
static_cast(intAttr.getInt())); + return AddressSpaceAttr::get(context, AddressSpace::GM); +} + +Value materializeMemRefView(Value value, ArrayRef shape, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + auto memrefType = + MemRefType::get(shape, elementType, AffineMap(), memorySpace); + if (value.getType() == memrefType) + return value; + return rewriter + .create( + loc, TypeRange(ArrayRef{memrefType}), value) + .getResult(0); +} + +Value materializeTileBufferView(Value value, PatternRewriter &rewriter, + Location loc) { + if (auto memrefType = dyn_cast(value.getType())) + return value; + + auto tileType = dyn_cast(value.getType()); + if (!tileType) + return {}; + + return materializeMemRefView(value, tileType.getShape(), tileType.getElementType(), + tileType.getMemorySpace(), rewriter, loc); +} + +} // namespace + +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + if (!value) + return {}; + + auto ptrMemorySpace = + getNormalizedPtrMemorySpace(memorySpace, rewriter.getContext()); + auto ptrType = PtrType::get(rewriter.getContext(), elementType, ptrMemorySpace); + + if (value.getType() == ptrType) + return value; + + if (auto bind = value.getDefiningOp()) + return materializeBufferPointer(bind.getSource(), elementType, memorySpace, + rewriter, loc); + + if (auto cast = value.getDefiningOp()) { + if (cast.getAddrs().empty()) + return {}; + return rewriter.create(loc, ptrType, cast.getAddrs().front()) + .getResult(); + } + + Value memrefValue = materializeTileBufferView(value, rewriter, loc); + auto memrefType = dyn_cast_or_null(memrefValue.getType()); + if (!memrefValue || !memrefType) + return {}; + return rewriter.create(loc, ptrType, memrefValue).getResult(); +} + +namespace { + +Value materializeBufferLikeAddress(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + if (!value) + return {}; 
+ + if (auto bind = value.getDefiningOp()) + return materializeBufferLikeAddress(bind.getSource(), elementType, memorySpace, + rewriter, loc); + + // Keep memref semantics through the VPTO mainline whenever possible. + Value memrefValue = materializeTileBufferView(value, rewriter, loc); + if (memrefValue && isa(memrefValue.getType())) + return memrefValue; + + return materializeBufferPointer(value, elementType, memorySpace, rewriter, loc); +} + +Value offsetBufferPointer(Value basePtr, Type elementType, Value elementOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + if (auto ptrType = dyn_cast(basePtr.getType())) { + Value offsetIndex = + elementOffset.getType().isIndex() + ? elementOffset + : rewriter.create(loc, + rewriter.getIndexType(), + elementOffset); + return rewriter.create(loc, ptrType, basePtr, offsetIndex); + } + return {}; +} + +Value buildPackedCountI64(PatternRewriter &rewriter, Location loc, + ArrayRef counts) { + Value packed = rewriter.create(loc, 0, 64); + for (auto [idx, count] : llvm::enumerate(counts)) { + Value countI64 = count.getType().isIndex() + ? 
rewriter.create( + loc, rewriter.getI64Type(), count) + : count; + if (idx != 0) { + Value shift = rewriter.create(loc, idx * 16, 64); + countI64 = rewriter.create(loc, countI64, shift); + } + packed = rewriter.create(loc, packed, countI64); + } + return packed; +} + +Value buildCeilDivPositiveI64(PatternRewriter &rewriter, Location loc, Value lhs, + int64_t rhs) { + Value rhsValue = rewriter.create(loc, rhs, 64); + Value rhsMinusOne = rewriter.create(loc, rhs - 1, 64); + Value biased = rewriter.create(loc, lhs, rhsMinusOne); + return rewriter.create(loc, biased, rhsValue); +} + +VPTOPartitionTrace extractPartitionTrace(Value value) { + VPTOPartitionTrace trace; + if (auto part = value.getDefiningOp()) { + appendStaticSizes(part.getOffsets(), trace.offsets, trace.hasDynamicOffsets); + appendStaticSizes(part.getSizes(), trace.sizes, trace.hasDynamicSizes); + return trace; + } + if (auto subview = value.getDefiningOp()) { + trace = extractPartitionTrace(subview.getSource()); + SmallVector offsets; + SmallVector sizes; + bool hasDynamicOffsets = false; + bool hasDynamicSizes = false; + recordStaticSizes(subview.getMixedOffsets(), offsets, hasDynamicOffsets); + recordStaticSizes(subview.getMixedSizes(), sizes, hasDynamicSizes); + mergeSubviewTrace(trace, offsets, sizes, hasDynamicOffsets, hasDynamicSizes); + return trace; + } + if (auto reinterpret = value.getDefiningOp()) + return extractPartitionTrace(reinterpret.getSource()); + if (auto cast = value.getDefiningOp()) + return extractPartitionTrace(cast.getSource()); + if (auto unrealized = value.getDefiningOp()) { + if (!unrealized.getInputs().empty()) + return extractPartitionTrace(unrealized.getInputs().front()); + } + return trace; +} + +VPTOLoadContract extractTLoadContract(TLoadOp op) { + VPTOLoadContract contract; + contract.trace = extractPartitionTrace(op.getSrc()); + contract.elementType = getElementType(op.getDst()); + + Attribute layoutAttr; + Value base = resolveTensorViewBase(op.getSrc(), layoutAttr, 
contract.sourceShape, + contract.sourceStrides); + (void)base; + contract.sourceLayout = stringifyLayoutAttr(layoutAttr); + + contract.tileLayout = deriveTileLayout(op.getDst()); + contract.tileDomain = deriveTileDomain(getMemorySpace(op.getDst())); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + contract.padMode = stringifyPadModeAttr(op.getPadModeAttr()); + contract.padValue = op.getPadValue(); + contract.leftPaddingNum = op.getLeftPaddingNum(); + contract.rightPaddingNum = op.getRightPaddingNum(); + contract.initOutBuffer = op.getInitOutBuffer(); + contract.initCondition = op.getInitCondition(); + return contract; +} + +VPTOUnaryContract extractTAbsContract(TAbsOp op) { + VPTOUnaryContract contract; + contract.family = "abs"; + contract.tileDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + contract.tileLayout = deriveTileLayout(op.getSrc()); + deriveValidShapeValues(op.getSrc(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getSrc(), contract.validRows, contract.validCols); + contract.elementType = getElementType(op.getSrc()); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOBinaryContract buildBinaryContract(StringRef family, Value src0) { + VPTOBinaryContract contract; + contract.family = family; + contract.tileDomain = deriveTileDomain(getMemorySpace(src0)); + contract.tileLayout = deriveTileLayout(src0); + deriveValidShapeValues(src0, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src0, contract.validRows, contract.validCols); + contract.elementType = getElementType(src0); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; 
+} + +VPTOBinaryContract extractTAddContract(TAddOp op) { + return buildBinaryContract("add", op.getSrc0()); +} + +VPTOBinaryContract extractTSubContract(TSubOp op) { + return buildBinaryContract("sub", op.getSrc0()); +} + +VPTOBinaryContract extractTMulContract(TMulOp op) { + return buildBinaryContract("mul", op.getSrc0()); +} + +VPTOBinaryContract extractTDivContract(TDivOp op) { + return buildBinaryContract("div", op.getSrc0()); +} + +VPTOUnaryContract buildUnaryContract(StringRef family, Value src) { + VPTOUnaryContract contract; + contract.family = family; + contract.tileDomain = deriveTileDomain(getMemorySpace(src)); + contract.tileLayout = deriveTileLayout(src); + deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src, contract.validRows, contract.validCols); + contract.elementType = getElementType(src); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOUnaryContract extractTExpContract(TExpOp op) { + return buildUnaryContract("exp", op.getSrc()); +} + +VPTOUnaryContract extractTLogContract(TLogOp op) { + return buildUnaryContract("log", op.getSrc()); +} + +VPTOUnaryContract extractTSqrtContract(TSqrtOp op) { + return buildUnaryContract("sqrt", op.getSrc()); +} + +VPTOUnaryContract extractTRecipContract(TRecipOp op) { + return buildUnaryContract("recip", op.getSrc()); +} + +VPTOUnaryContract extractTReluContract(TReluOp op) { + return buildUnaryContract("relu", op.getSrc()); +} + +VPTOUnaryContract extractTNotContract(TNotOp op) { + return buildUnaryContract("not", op.getSrc()); +} + +static FailureOr stringifyA5RoundMode(TCvtOp op, + PatternRewriter &rewriter) { + switch (op.getRmode()) { + case RoundMode::NONE: + case RoundMode::RINT: + case RoundMode::CAST_RINT: + return rewriter.getStringAttr("ROUND_R"); + case RoundMode::ROUND: + return 
rewriter.getStringAttr("ROUND_A"); + case RoundMode::FLOOR: + return rewriter.getStringAttr("ROUND_F"); + case RoundMode::CEIL: + return rewriter.getStringAttr("ROUND_C"); + case RoundMode::TRUNC: + return rewriter.getStringAttr("ROUND_Z"); + case RoundMode::ODD: + return rewriter.getStringAttr("ROUND_O"); + } + return failure(); +} + +enum class VPTOCvtLoweringKind { + Vtrc, + F32ToBF16, + F16ToF32, + BF16ToF32, +}; + +static FailureOr classifyA5CvtLowering(Type srcElemType, + Type dstElemType) { + if (srcElemType.isF32() && dstElemType.isF32()) + return VPTOCvtLoweringKind::Vtrc; + if (srcElemType.isF32() && dstElemType.isBF16()) + return VPTOCvtLoweringKind::F32ToBF16; + if (srcElemType.isF16() && dstElemType.isF32()) + return VPTOCvtLoweringKind::F16ToF32; + if (srcElemType.isBF16() && dstElemType.isF32()) + return VPTOCvtLoweringKind::BF16ToF32; + return failure(); +} + +VPTOUnaryContract extractTExpandSContract(TExpandsOp op) { + VPTOUnaryContract contract; + contract.family = "expands"; + contract.tileDomain = deriveTileDomain(getMemorySpace(op.getDst())); + contract.tileLayout = deriveTileLayout(op.getDst()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, + contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + contract.elementType = getElementType(op.getDst()); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOExpandContract extractTRowExpandContract(TRowExpandOp op) { + VPTOExpandContract contract; + contract.family = "rowexpand"; + contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + contract.dstDomain = deriveTileDomain(getMemorySpace(op.getDst())); + contract.srcLayout = deriveTileLayout(op.getSrc()); + contract.dstLayout = deriveTileLayout(op.getDst()); + contract.elementType = getElementType(op.getSrc()); + 
deriveValidShapeValues(op.getSrc(), contract.srcValidRowsValue, + contract.srcValidColsValue); + deriveValidShape(op.getSrc(), contract.srcValidRows, contract.srcValidCols); + deriveValidShapeValues(op.getDst(), contract.dstValidRowsValue, + contract.dstValidColsValue); + deriveValidShape(op.getDst(), contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOExpandContract extractTColExpandContract(TColExpandOp op) { + VPTOExpandContract contract; + contract.family = "colexpand"; + contract.srcDomain = deriveTileDomain(getMemorySpace(op.getSrc())); + contract.dstDomain = deriveTileDomain(getMemorySpace(op.getDst())); + contract.srcLayout = deriveTileLayout(op.getSrc()); + contract.dstLayout = deriveTileLayout(op.getDst()); + contract.elementType = getElementType(op.getSrc()); + deriveValidShapeValues(op.getSrc(), contract.srcValidRowsValue, + contract.srcValidColsValue); + deriveValidShape(op.getSrc(), contract.srcValidRows, contract.srcValidCols); + deriveValidShapeValues(op.getDst(), contract.dstValidRowsValue, + contract.dstValidColsValue); + deriveValidShape(op.getDst(), contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTORowReduceContract extractTRowReduceContract(Value src, Value dst, + StringRef family) { + VPTORowReduceContract contract; + contract.family = family; + contract.srcDomain = deriveTileDomain(getMemorySpace(src)); + contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); + contract.srcLayout = deriveTileLayout(src); + contract.dstLayout = deriveTileLayout(dst); + contract.elementType = getElementType(src); + deriveValidShapeValues(src, contract.validRowsValue, 
contract.validColsValue); + deriveValidShape(src, contract.validRows, contract.validCols); + int64_t dstRows = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTORowReduceContract extractTRowMaxContract(TRowMaxOp op) { + return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowmax"); +} + +VPTORowReduceContract extractTRowMinContract(TRowMinOp op) { + return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowmin"); +} + +VPTORowReduceContract extractTRowSumContract(TRowSumOp op) { + return extractTRowReduceContract(op.getSrc(), op.getDst(), "rowsum"); +} + +VPTOColReduceContract extractTColReduceContract(Value src, Value dst, + StringRef family) { + VPTOColReduceContract contract; + contract.family = family; + contract.srcDomain = deriveTileDomain(getMemorySpace(src)); + contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); + contract.srcLayout = deriveTileLayout(src); + contract.dstLayout = deriveTileLayout(dst); + contract.elementType = getElementType(src); + deriveValidShapeValues(src, contract.validRowsValue, contract.validColsValue); + deriveValidShape(src, contract.validRows, contract.validCols); + deriveValidShape(dst, contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOColReduceContract extractTColMaxContract(TColMaxOp op) { + return extractTColReduceContract(op.getSrc(), op.getDst(), "colmax"); +} + +VPTOColReduceContract extractTColMinContract(TColMinOp op) { + return extractTColReduceContract(op.getSrc(), op.getDst(), "colmin"); +} + +VPTOColReduceContract extractTColSumContract(TColSumOp op) { + VPTOColReduceContract contract = + 
extractTColReduceContract(op.getSrc(), op.getDst(), "colsum"); + contract.isBinary = op.getIsBinary(); + contract.tmp = op.getTmp(); + return contract; +} + +VPTOPartContract extractTPartContract(Value src0, Value src1, Value dst, + StringRef family) { + VPTOPartContract contract; + contract.family = family; + contract.src0Domain = deriveTileDomain(getMemorySpace(src0)); + contract.src1Domain = deriveTileDomain(getMemorySpace(src1)); + contract.dstDomain = deriveTileDomain(getMemorySpace(dst)); + contract.src0Layout = deriveTileLayout(src0); + contract.src1Layout = deriveTileLayout(src1); + contract.dstLayout = deriveTileLayout(dst); + contract.elementType = getElementType(dst); + deriveValidShapeValues(src0, contract.src0ValidRowsValue, contract.src0ValidColsValue); + deriveValidShapeValues(src1, contract.src1ValidRowsValue, contract.src1ValidColsValue); + deriveValidShapeValues(dst, contract.dstValidRowsValue, contract.dstValidColsValue); + deriveValidShape(src0, contract.src0ValidRows, contract.src0ValidCols); + deriveValidShape(src1, contract.src1ValidRows, contract.src1ValidCols); + deriveValidShape(dst, contract.dstValidRows, contract.dstValidCols); + contract.loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + contract.loopScope.loweredAttr = kLoweredLoopScopeAttrName; + contract.loopScope.loopDepth = 0; + return contract; +} + +VPTOPartContract extractTPartAddContract(TPartAddOp op) { + return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partadd"); +} + +VPTOPartContract extractTPartMaxContract(TPartMaxOp op) { + return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partmax"); +} + +VPTOPartContract extractTPartMinContract(TPartMinOp op) { + return extractTPartContract(op.getSrc0(), op.getSrc1(), op.getDst(), "partmin"); +} + +VPTOStoreContract extractTStoreContract(TStoreOp op) { + VPTOStoreContract contract; + contract.trace = extractPartitionTrace(op.getDst()); + + contract.srcDomain = 
deriveTileDomain(getMemorySpace(op.getSrc())); + deriveValidShapeValues(op.getSrc(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getSrc(), contract.validRows, contract.validCols); + contract.elementType = getElementType(op.getSrc()); + + Attribute layoutAttr; + Value base = resolveTensorViewBase(op.getDst(), layoutAttr, + contract.destinationShape, + contract.destinationStrides); + (void)base; + contract.destinationLayout = stringifyLayoutAttr(layoutAttr); + return contract; +} + +void attachLoadContractAttrs(Operation *op, const VPTOLoadContract &contract) { + Builder builder(op->getContext()); + SmallVector globalShape; + SmallVector globalStride; + normalizeToPTOGlobalShapeAndStride(contract.sourceShape, contract.sourceStrides, + globalShape, globalStride); + op->setAttr("g_shape", asI64ArrayAttr(builder, globalShape)); + op->setAttr("g_strides", asI64ArrayAttr(builder, globalStride)); +} + +void attachStoreContractAttrs(Operation *op, const VPTOStoreContract &contract) { + Builder builder(op->getContext()); + SmallVector globalShape; + SmallVector globalStride; + normalizeToPTOGlobalShapeAndStride(contract.destinationShape, + contract.destinationStrides, globalShape, + globalStride); + op->setAttr("g_shape", asI64ArrayAttr(builder, globalShape)); + op->setAttr("g_strides", asI64ArrayAttr(builder, globalStride)); +} + +LogicalResult lowerUnsupportedAccStore(Location loc) { + emitError(loc) << "TSTORE ACC lowering TODO for vpto backend"; + return failure(); +} + +LogicalResult lowerUnsupportedMatStore(Location loc) { + emitError(loc) << "TSTORE MAT lowering TODO for vpto backend"; + return failure(); +} + +} // namespace + +FailureOr +createLoopScopeRegion(Location loc, const VPTOLoopScopeContract &contract, + PatternRewriter &rewriter) { + if (contract.kind == VPTOLoopScopeKind::None) + return failure(); + if (contract.kind != VPTOLoopScopeKind::AIVVectorScope) + return failure(); + + auto vecScope = rewriter.create(loc); + 
vecScope.getBody().push_back(new Block()); + return vecScope; +} + +void set_loop2_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder) { + copyOp->setAttr("pto.set_loop2_stride_outtoub", + builder.getI64IntegerAttr( + packLoopStrideConfig(dstStride, srcStride))); +} + +void set_loop1_stride_outtoub(Operation *copyOp, int64_t dstStride, + int64_t srcStride, Builder &builder) { + copyOp->setAttr("pto.set_loop1_stride_outtoub", + builder.getI64IntegerAttr( + packLoopStrideConfig(dstStride, srcStride))); +} + +void set_loop_size_outtoub(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder) { + copyOp->setAttr("pto.set_loop_size_outtoub", + builder.getI64IntegerAttr(packLoopSizeConfig(loop2, loop1))); +} + +void set_loop2_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder) { + copyOp->setAttr("pto.set_loop2_stride_ubtoout", + builder.getI64IntegerAttr( + packLoopStrideConfig(srcStride, dstStride))); +} + +void set_loop1_stride_ubtoout(Operation *copyOp, int64_t srcStride, + int64_t dstStride, Builder &builder) { + copyOp->setAttr("pto.set_loop1_stride_ubtoout", + builder.getI64IntegerAttr( + packLoopStrideConfig(srcStride, dstStride))); +} + +void set_loop_size_ubtoout(Operation *copyOp, int64_t loop2, int64_t loop1, + Builder &builder) { + copyOp->setAttr("pto.set_loop_size_ubtoout", + builder.getI64IntegerAttr(packLoopSizeConfig(loop2, loop1))); +} + +LogicalResult programCopyGmToUbLoops(Operation *copyOp, + const VPTOLoadContract &contract, + Builder &builder) { + SmallVector globalShape; + SmallVector globalStride; + int64_t nBurst = 0, lenBurst = 0, gmStrideBytes = 0, ubStrideBytes = 0; + int64_t loop1Size = 0, loop2Size = 0; + int64_t loop1DstStrideBytes = 0, loop1SrcStrideBytes = 0; + int64_t loop2DstStrideBytes = 0, loop2SrcStrideBytes = 0; + if (failed(deriveVecNDTransferConfig(contract.sourceShape, contract.sourceStrides, + contract.tileLayout, contract.elementType, 
+ contract.validRows, contract.validCols, + globalShape, globalStride, nBurst, lenBurst, + gmStrideBytes, ubStrideBytes, loop1Size, + loop2Size, loop1DstStrideBytes, + loop1SrcStrideBytes, loop2DstStrideBytes, + loop2SrcStrideBytes))) + return failure(); + + set_loop2_stride_outtoub(copyOp, loop2DstStrideBytes, loop2SrcStrideBytes, builder); + set_loop1_stride_outtoub(copyOp, loop1DstStrideBytes, loop1SrcStrideBytes, builder); + set_loop_size_outtoub(copyOp, loop2Size, loop1Size, builder); + return success(); +} + +LogicalResult programCopyUbToGmLoops(Operation *copyOp, + const VPTOStoreContract &contract, + Builder &builder) { + SmallVector globalShape; + SmallVector globalStride; + int64_t nBurst = 0, lenBurst = 0, burstDstStrideBytes = 0, burstSrcStrideBytes = 0; + int64_t loop1Size = 0, loop2Size = 0; + int64_t loop1SrcStrideBytes = 0, loop1DstStrideBytes = 0; + int64_t loop2SrcStrideBytes = 0, loop2DstStrideBytes = 0; + if (failed(deriveVecNDTransferConfig(contract.destinationShape, + contract.destinationStrides, + "row_major", contract.elementType, + contract.validRows, contract.validCols, + globalShape, globalStride, nBurst, lenBurst, + burstDstStrideBytes, burstSrcStrideBytes, + loop1Size, loop2Size, loop1SrcStrideBytes, + loop1DstStrideBytes, loop2SrcStrideBytes, + loop2DstStrideBytes))) + return failure(); + + set_loop_size_ubtoout(copyOp, loop2Size, loop1Size, builder); + set_loop1_stride_ubtoout(copyOp, loop1SrcStrideBytes, loop1DstStrideBytes, builder); + set_loop2_stride_ubtoout(copyOp, loop2SrcStrideBytes, loop2DstStrideBytes, builder); + return success(); +} + +int64_t deriveStaticRowStride(Value value) { + StringRef layout = deriveTileLayout(value); + if (layout == "col_major") + return 1; + + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef shape = tileType.getShape(); + if (shape.size() >= 2) + return shape[shape.size() - 1]; + } + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if 
(shape.size() >= 2) + return shape[shape.size() - 1]; + } + return ShapedType::kDynamic; +} + +int64_t deriveStaticShapeDim(Value value, unsigned dim) { + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef shape = tileType.getShape(); + if (dim < shape.size()) + return shape[dim]; + } + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (dim < shape.size()) + return shape[dim]; + } + return ShapedType::kDynamic; +} + +int64_t deriveStaticTileCols(Value value) { + if (auto tileType = dyn_cast(value.getType())) { + ArrayRef shape = tileType.getShape(); + if (!shape.empty()) + return shape.back(); + } + if (auto shapedType = dyn_cast(value.getType())) { + ArrayRef shape = shapedType.getShape(); + if (!shape.empty()) + return shape.back(); + } + return ShapedType::kDynamic; +} + +Value buildFullWidthColsCondition(ArrayRef tileCols, + Value validColsValue, + PatternRewriter &rewriter, Location loc) { + Value condition; + for (int64_t tileCol : tileCols) { + if (tileCol == ShapedType::kDynamic) + return {}; + Value tileColValue = rewriter.create(loc, tileCol); + Value isFullWidth = rewriter.create( + loc, arith::CmpIPredicate::eq, validColsValue, tileColValue); + condition = condition ? 
rewriter.create(loc, condition, isFullWidth) + : isFullWidth; + } + return condition; +} + +Value buildMinIndexValue(PatternRewriter &rewriter, Location loc, Value lhs, + Value rhs) { + auto lhsLtRhs = rewriter.create(loc, arith::CmpIPredicate::slt, + lhs, rhs); + return rewriter.create(loc, lhsLtRhs, lhs, rhs); +} + +struct PredicateMaterialization { + Value mask; + Value nextScalar; +}; + +PredicateMaterialization buildPredicateForLaneCount(PatternRewriter &rewriter, + Location loc, + Type elementType, + Value laneCount) { + auto maskType = getVPTOMaskTypeForElementType(rewriter.getContext(), elementType); + Value laneCountI32 = laneCount; + if (laneCount.getType().isIndex()) { + laneCountI32 = + rewriter.create(loc, rewriter.getI32Type(), laneCount); + } else if (auto intType = dyn_cast(laneCount.getType())) { + if (intType.getWidth() < 32) + laneCountI32 = rewriter.create(loc, rewriter.getI32Type(), laneCount); + else if (intType.getWidth() > 32) + laneCountI32 = + rewriter.create(loc, rewriter.getI32Type(), laneCount); + } + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 8) { + auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), + laneCountI32); + return {plt.getMask(), plt.getScalarOut()}; + } + if (bitWidth == 16) { + auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), + laneCountI32); + return {plt.getMask(), plt.getScalarOut()}; + } + if (bitWidth == 32) { + auto plt = rewriter.create(loc, maskType, rewriter.getI32Type(), + laneCountI32); + return {plt.getMask(), plt.getScalarOut()}; + } + llvm_unreachable("unsupported element type for predicate lane-count lowering"); +} + +Value buildPredicateMaskForLaneCount(PatternRewriter &rewriter, Location loc, + Type elementType, Value laneCount) { + return buildPredicateForLaneCount(rewriter, loc, elementType, laneCount).mask; +} + 
+Value buildAllPredicateMask(PatternRewriter &rewriter, Location loc, + Type elementType) { + auto maskType = getVPTOMaskTypeForElementType(rewriter.getContext(), elementType); + StringAttr allPattern = rewriter.getStringAttr("PAT_ALL"); + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 8) + return rewriter.create(loc, maskType, allPattern).getResult(); + if (bitWidth == 16) + return rewriter.create(loc, maskType, allPattern).getResult(); + if (bitWidth == 32) + return rewriter.create(loc, maskType, allPattern).getResult(); + llvm_unreachable("unsupported element type for full predicate mask lowering"); +} + +LogicalResult buildMaskedVectorStore(PatternRewriter &rewriter, Location loc, + Value value, Value dstBuffer, + Value dstOffset, Value activeLanes, + int64_t vectorWidth) { + auto vecType = cast(value.getType()); + Value mask = buildPredicateMaskForLaneCount(rewriter, loc, + vecType.getElementType(), + activeLanes); + rewriter.create(loc, value, dstBuffer, dstOffset, StringAttr(), + mask); + return success(); +} + +Attribute buildRowReduceInitValue(Type elementType, StringRef family, + Builder &builder) { + if (!isa(elementType)) + return {}; + + if (family == "rowsum") + return builder.getFloatAttr(elementType, 0.0); + + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + bool negative = family == "rowmax"; + return builder.getFloatAttr(elementType, llvm::APFloat::getInf(semantics, negative)); +} + +Attribute buildPartPadValue(Type elementType, StringRef family, Builder &builder) { + if (family == "partadd") + return builder.getZeroAttr(elementType); + if (isa(elementType)) { + const llvm::fltSemantics &semantics 
= [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + bool negative = family == "partmax"; + return builder.getFloatAttr(elementType, llvm::APFloat::getInf(semantics, negative)); + } + if (auto intType = dyn_cast(elementType)) { + unsigned width = intType.getWidth(); + if (intType.isUnsigned()) { + if (family == "partmax") + return builder.getIntegerAttr(elementType, 0); + return builder.getIntegerAttr(elementType, llvm::APInt::getAllOnes(width)); + } + if (family == "partmax") + return builder.getIntegerAttr(elementType, llvm::APInt::getSignedMinValue(width)); + return builder.getIntegerAttr(elementType, llvm::APInt::getSignedMaxValue(width)); + } + return {}; +} + +Attribute buildFillPadValue(Type elementType, PadValueAttr padAttr, Builder &builder) { + if (!padAttr) + return {}; + + switch (padAttr.getValue()) { + case PadValue::Null: + return {}; + case PadValue::Zero: + return builder.getZeroAttr(elementType); + case PadValue::Max: + if (isa(elementType)) { + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + return builder.getFloatAttr(elementType, + llvm::APFloat::getLargest(semantics)); + } + if (auto intType = dyn_cast(elementType)) { + unsigned width = intType.getWidth(); + return intType.isUnsigned() + ? 
builder.getIntegerAttr(elementType, + llvm::APInt::getMaxValue(width)) + : builder.getIntegerAttr(elementType, + llvm::APInt::getSignedMaxValue(width)); + } + return {}; + case PadValue::Min: + if (isa(elementType)) { + const llvm::fltSemantics &semantics = [&]() -> const llvm::fltSemantics & { + if (elementType.isF16()) + return llvm::APFloat::IEEEhalf(); + if (elementType.isBF16()) + return llvm::APFloat::BFloat(); + return llvm::APFloat::IEEEsingle(); + }(); + auto value = llvm::APFloat::getLargest(semantics); + value.changeSign(); + return builder.getFloatAttr(elementType, value); + } + if (auto intType = dyn_cast(elementType)) { + unsigned width = intType.getWidth(); + return intType.isUnsigned() + ? builder.getIntegerAttr(elementType, llvm::APInt(width, 0)) + : builder.getIntegerAttr(elementType, + llvm::APInt::getSignedMinValue(width)); + } + return {}; + } + return {}; +} + +LogicalResult buildRowReduceVecScope(StringRef family, + const VPTORowReduceContract &contract, + VPTOLoweringStrategy strategy, Value src, + Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO row-reduce element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for row-reduce lowering"; + + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return emitError(loc) << family << " lowering currently requires static valid rows and cols"; + + int64_t srcRowStride = deriveStaticRowStride(src); + int64_t dstRowStride = deriveStaticRowStride(dst); + if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic) + 
return emitError(loc) << family << " lowering requires static row strides"; + + Attribute initValue = buildRowReduceInitValue(contract.elementType, family, rewriter); + if (!initValue) + return emitError(loc) << family << " lowering supports only f16 and f32 element types"; + + auto getRowReduceStoreDist = [&]() -> StringAttr { + if (contract.elementType.isF16() || contract.elementType.isBF16()) + return rewriter.getStringAttr("1PT"); + if (contract.elementType.isF32()) + return rewriter.getStringAttr("1PT"); + return {}; + }; + StringAttr storeDist = getRowReduceStoreDist(); + if (!storeDist) + return emitError(loc) << family << " lowering supports only f16 and f32 row-reduce stores"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(contract.validCols, vectorWidth); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, contract.validRows); + Value srcRowStrideValue = rewriter.create(loc, srcRowStride); + Value dstRowStrideValue = rewriter.create(loc, dstRowStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value initScalar = rewriter.create(loc, cast(initValue)); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + Value dstPredicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, c1); + Value validColsValue = + rewriter.create(loc, contract.validCols); + + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto rowLoop = + rewriter.create(loc, c0, rowsUpper, c1, ValueRange{dstBuffer}); + + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value dstPtr = 
rowLoop.getRegionIterArgs().front(); + Value rowBase = rewriter.create(loc, row, srcRowStrideValue); + Value srcPtr = + adjustPointerByElemOffset(srcBuffer, rowBase, getElementByteSize(contract.elementType), + rewriter, loc); + Value acc = rewriter.create(loc, vecType, initScalar); + Value remainingCols = rewriter.create( + loc, contract.validCols, 32); + for (int64_t repeatIndex = 0; repeatIndex < repeatTimes; ++repeatIndex) { + auto predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remainingCols); + Value srcPredicate = predicateState.mask; + auto srcVecOp = rewriter.create( + loc, TypeRange{vecType, srcPtr.getType()}, srcPtr, vectorWidthValue, + rewriter.getStringAttr("NORM")); + Value srcVec = srcVecOp.getResult(); + srcPtr = srcVecOp.getUpdatedSource(); + + Value reduced; + if (family == "rowsum") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmax") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmin") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else + return emitError(loc) << "unsupported VPTO row-reduce family: " << family; + + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + if (family == "rowsum") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else if (family == "rowmax") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + remainingCols = predicateState.nextScalar; + } + + auto storeOp = rewriter.create(loc, dstPtr.getType(), acc, dstPtr, + dstRowStrideValue, storeDist, + dstPredicate); + Value nextDstPtr = storeOp.getUpdatedDestination(); + rewriter.create(loc, nextDstPtr); + return success(); + } + + auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = 
rowLoop.getInductionVar(); + Value rowBase = rewriter.create(loc, row, srcRowStrideValue); + Value acc = rewriter.create(loc, vecType, initScalar); + for (int64_t repeatIndex = 0; repeatIndex < repeatTimes; ++repeatIndex) { + Value repeat = rewriter.create(loc, repeatIndex); + Value repeatBase = + rewriter.create(loc, repeat, vectorWidthValue); + Value srcOffset = + rewriter.create(loc, rowBase, repeatBase); + Value remainingCols = + rewriter.create(loc, validColsValue, repeatBase); + Value srcPredicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, remainingCols); + Value srcVec = + rewriter.create(loc, vecType, srcBuffer, srcOffset, + StringAttr()) + .getResult(); + + Value reduced; + if (family == "rowsum") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmax") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else if (family == "rowmin") + reduced = rewriter.create(loc, vecType, srcVec, srcPredicate); + else + return emitError(loc) << "unsupported VPTO row-reduce family: " << family; + + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + if (family == "rowsum") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else if (family == "rowmax") + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + else + acc = rewriter.create(loc, vecType, acc, reduced, fullMask); + } + + Value dstOffset = rewriter.create(loc, row, dstRowStrideValue); + rewriter.create(loc, acc, dstBuffer, dstOffset, storeDist, + dstPredicate); + return success(); +} + +LogicalResult buildColReduceVecScope(StringRef family, + const VPTOColReduceContract &contract, + Value src, Value dst, Value tmp, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO col-reduce element type"; + + Value srcBuffer = materializeBufferPointer(src, 
contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for col-reduce lowering"; + + Value tmpBuffer; + if (contract.isBinary) { + tmpBuffer = materializeBufferPointer(tmp, contract.elementType, getMemorySpace(tmp), + rewriter, loc); + if (!tmpBuffer) + return emitError(loc) << "binary colsum lowering requires pointer-backed tmp tile"; + } + + int64_t srcRowStride = deriveStaticRowStride(src); + int64_t dstRowStride = deriveStaticRowStride(dst); + int64_t tmpRowStride = + contract.isBinary ? deriveStaticRowStride(tmp) : ShapedType::kDynamic; + if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic || + (contract.isBinary && tmpRowStride == ShapedType::kDynamic)) + return emitError(loc) << family << " lowering requires static row strides"; + + Attribute initValue = buildRowReduceInitValue(contract.elementType, family, rewriter); + if (!initValue) + return emitError(loc) << family << " lowering supports only f16 and f32 element types"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(contract.validCols, vectorWidth); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value rowUpper = rewriter.create(loc, contract.validRows); + Value srcRowStrideValue = rewriter.create(loc, srcRowStride); + Value dstRowStrideValue = rewriter.create(loc, dstRowStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value initScalar = rewriter.create(loc, cast(initValue)); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + 
rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value chunk = chunkLoop.getInductionVar(); + Value chunkOffset = rewriter.create(loc, chunk, vectorWidthValue); + + if (!contract.isBinary) { + Value firstRowOffset = chunkOffset; + Value acc0 = + rewriter.create(loc, vecType, srcBuffer, firstRowOffset, StringAttr()).getResult(); + auto rowLoop = rewriter.create(loc, c1, rowUpper, c1, ValueRange{acc0}); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value acc = rowLoop.getRegionIterArgs().front(); + Value rowBase = rewriter.create(loc, row, srcRowStrideValue); + Value srcOffset = rewriter.create(loc, rowBase, chunkOffset); + Value srcVec = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()).getResult(); + Value nextAcc; + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + if (family == "colmax") + nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); + else if (family == "colmin") + nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); + else + nextAcc = rewriter.create(loc, vecType, acc, srcVec, fullMask); + rewriter.create(loc, nextAcc); + + rewriter.setInsertionPointAfter(rowLoop); + Value dstOffset = chunkOffset; + rewriter.create( + loc, rowLoop.getResult(0), dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, contract.elementType)); + return success(); + } + + Value tmpRowStrideValue = rewriter.create(loc, tmpRowStride); + auto reducePair = [&](Value lhs, Value rhs) -> Value { + return rewriter.create( + loc, vecType, lhs, rhs, buildAllPredicateMask(rewriter, loc, contract.elementType)) + .getResult(); + }; + + int64_t nLoopStatic = contract.validRows / 2; + bool remainStatic = 
(contract.validRows % 2) != 0; + Value pairUpper = rewriter.create(loc, nLoopStatic); + auto pairLoop = rewriter.create(loc, c0, pairUpper, c1); + { + OpBuilder::InsertionGuard pairGuard(rewriter); + rewriter.setInsertionPointToStart(pairLoop.getBody()); + Value pair = pairLoop.getInductionVar(); + Value row0 = rewriter.create( + loc, rewriter.create(loc, pair, rewriter.create(loc, 2)), + srcRowStrideValue); + Value row1 = rewriter.create( + loc, rewriter.create(loc, + rewriter.create(loc, pair, rewriter.create(loc, 2)), + c1), + srcRowStrideValue); + Value src0Offset = rewriter.create(loc, row0, chunkOffset); + Value src1Offset = rewriter.create(loc, row1, chunkOffset); + Value lhs = rewriter.create(loc, vecType, srcBuffer, src0Offset, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, srcBuffer, src1Offset, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + Value tmpOffset = rewriter.create(loc, pair, tmpRowStrideValue); + rewriter.create(loc, sum, tmpBuffer, tmpOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + } + + if (remainStatic && nLoopStatic > 0) { + Value lastRowOffset = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, contract.validRows - 1), + srcRowStrideValue), + chunkOffset); + Value tmpOffset = rewriter.create( + loc, rewriter.create(loc, nLoopStatic - 1), tmpRowStrideValue); + Value lhs = rewriter.create(loc, vecType, srcBuffer, lastRowOffset, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, tmpBuffer, tmpOffset, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + rewriter.create(loc, sum, tmpBuffer, tmpOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + } + + int64_t currentRows = nLoopStatic; + while (currentRows > 1) { + int64_t nextRows = currentRows / 2; + bool remain = (currentRows % 2) != 0; + Value nextUpper = rewriter.create(loc, nextRows); + auto foldLoop = 
rewriter.create(loc, c0, nextUpper, c1); + OpBuilder::InsertionGuard foldGuard(rewriter); + rewriter.setInsertionPointToStart(foldLoop.getBody()); + Value pair = foldLoop.getInductionVar(); + Value idx2 = rewriter.create( + loc, pair, rewriter.create(loc, 2)); + Value idx2p1 = rewriter.create(loc, idx2, c1); + Value lhsOff = rewriter.create(loc, idx2, tmpRowStrideValue); + Value rhsOff = rewriter.create(loc, idx2p1, tmpRowStrideValue); + Value lhs = rewriter.create(loc, vecType, tmpBuffer, lhsOff, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, tmpBuffer, rhsOff, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + Value outOff = rewriter.create(loc, pair, tmpRowStrideValue); + rewriter.create(loc, sum, tmpBuffer, outOff, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + + rewriter.setInsertionPointAfter(foldLoop); + if (remain && nextRows > 0) { + Value lhsOff = rewriter.create( + loc, rewriter.create(loc, nextRows - 1), tmpRowStrideValue); + Value rhsOff = rewriter.create( + loc, rewriter.create(loc, 2 * nextRows), tmpRowStrideValue); + Value lhs = rewriter.create(loc, vecType, tmpBuffer, lhsOff, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, tmpBuffer, rhsOff, StringAttr()).getResult(); + Value sum = reducePair(lhs, rhs); + rewriter.create(loc, sum, tmpBuffer, lhsOff, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + } + currentRows = nextRows; + } + + Value finalVec; + if (currentRows == 0) { + finalVec = rewriter.create(loc, vecType, initScalar).getResult(); + } else { + finalVec = rewriter.create(loc, vecType, tmpBuffer, c0, StringAttr()).getResult(); + } + Value dstOffset = chunkOffset; + rewriter.create(loc, finalVec, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + return success(); +} + +LogicalResult buildPartFill(StringRef family, const VPTOPartContract &contract, + Value 
dstBuffer, pto::VRegType vecType, + int64_t dstStride, PatternRewriter &rewriter, + Location loc) { + Attribute initValue = buildPartPadValue(contract.elementType, family, rewriter); + if (!initValue) + return emitError(loc) << "unsupported pad value for " << family; + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(contract.dstValidCols, vectorWidth); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, contract.dstValidRows); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value initScalar = rewriter.create(loc, cast(initValue)); + Value initVec = rewriter.create(loc, vecType, initScalar); + auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value chunk = chunkLoop.getInductionVar(); + Value rowBase = rewriter.create(loc, row, dstStrideValue); + Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); + Value dstOffset = rewriter.create(loc, rowBase, chunkBase); + rewriter.create(loc, initVec, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + vecType.getElementType())); + rewriter.setInsertionPointAfter(chunkLoop); + return success(); +} + +LogicalResult buildPartCopyRegion(Value srcBuffer, Value dstBuffer, pto::VRegType vecType, + int64_t srcStride, int64_t dstStride, + int64_t startRow, int64_t validRows, + int64_t validCols, PatternRewriter &rewriter, + Location loc) { + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(validCols, 
vectorWidth); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, validRows); + Value repeatUpper = rewriter.create(loc, repeatTimes); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + Value startRowValue = rewriter.create(loc, startRow); + auto rowLoop = rewriter.create(loc, startRowValue, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value chunk = chunkLoop.getInductionVar(); + Value rowSrc = rewriter.create(loc, row, srcStrideValue); + Value rowDst = rewriter.create(loc, row, dstStrideValue); + Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); + Value srcOffset = rewriter.create(loc, rowSrc, chunkBase); + Value dstOffset = rewriter.create(loc, rowDst, chunkBase); + Value vec = rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()).getResult(); + rewriter.create(loc, vec, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + vecType.getElementType())); + rewriter.setInsertionPointAfter(chunkLoop); + return success(); +} + +LogicalResult buildPartBinaryRegion(StringRef family, Value src0Buffer, Value src1Buffer, + Value dstBuffer, pto::VRegType vecType, + int64_t src0Stride, int64_t src1Stride, + int64_t dstStride, int64_t validRows, + int64_t validCols, PatternRewriter &rewriter, + Location loc) { + int64_t vectorWidth = vecType.getElementCount(); + int64_t repeatTimes = llvm::divideCeil(validCols, vectorWidth); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value rowsUpper = rewriter.create(loc, validRows); + 
Value repeatUpper = rewriter.create(loc, repeatTimes); + Value src0StrideValue = rewriter.create(loc, src0Stride); + Value src1StrideValue = rewriter.create(loc, src1Stride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value vectorWidthValue = rewriter.create(loc, vectorWidth); + auto rowLoop = rewriter.create(loc, c0, rowsUpper, c1); + OpBuilder::InsertionGuard rowGuard(rewriter); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value chunk = chunkLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, chunk, vectorWidthValue); + Value rowSrc0 = rewriter.create(loc, row, src0StrideValue); + Value rowSrc1 = rewriter.create(loc, row, src1StrideValue); + Value rowDst = rewriter.create(loc, row, dstStrideValue); + Value src0Offset = rewriter.create(loc, rowSrc0, chunkBase); + Value src1Offset = rewriter.create(loc, rowSrc1, chunkBase); + Value dstOffset = rewriter.create(loc, rowDst, chunkBase); + Value lhs = rewriter.create(loc, vecType, src0Buffer, src0Offset, StringAttr()).getResult(); + Value rhs = rewriter.create(loc, vecType, src1Buffer, src1Offset, StringAttr()).getResult(); + Value fullMask = buildAllPredicateMask(rewriter, loc, vecType.getElementType()); + Value out; + if (family == "partadd") + out = rewriter.create(loc, vecType, lhs, rhs, fullMask); + else if (family == "partmax") + out = rewriter.create(loc, vecType, lhs, rhs, fullMask); + else if (family == "partmin") + out = rewriter.create(loc, vecType, lhs, rhs, fullMask); + else + return emitError(loc) << "unsupported part family: " << family; + rewriter.create(loc, out, dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, loc, + vecType.getElementType())); + rewriter.setInsertionPointAfter(chunkLoop); + return success(); +} + +LogicalResult 
buildPartVecScope(StringRef family, const VPTOPartContract &contract, + Value src0, Value src1, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO part element type"; + Value src0Buffer = materializeBufferLikeAddress(src0, contract.elementType, + getMemorySpace(src0), rewriter, loc); + Value src1Buffer = materializeBufferLikeAddress(src1, contract.elementType, + getMemorySpace(src1), rewriter, loc); + Value dstBuffer = materializeBufferLikeAddress(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for part lowering"; + int64_t src0Stride = deriveStaticRowStride(src0); + int64_t src1Stride = deriveStaticRowStride(src1); + int64_t dstStride = deriveStaticRowStride(dst); + if (src0Stride == ShapedType::kDynamic || src1Stride == ShapedType::kDynamic || + dstStride == ShapedType::kDynamic) + return emitError(loc) << family << " lowering requires static row strides"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + auto condSrc0EqDst = contract.src0ValidRows == contract.dstValidRows && + contract.src0ValidCols == contract.dstValidCols; + auto condSrc0RowLtDst = contract.src0ValidRows < contract.dstValidRows && + contract.src0ValidCols == contract.dstValidCols; + auto condSrc0ColLtDst = contract.src0ValidRows <= contract.dstValidRows && + contract.src0ValidCols < contract.dstValidCols; + auto condSrc1EqDst = contract.src1ValidRows == contract.dstValidRows && + contract.src1ValidCols 
== contract.dstValidCols; + auto condSrc1RowLtDst = contract.src1ValidRows < contract.dstValidRows && + contract.src1ValidCols == contract.dstValidCols; + auto condSrc1ColLtDst = contract.src1ValidRows <= contract.dstValidRows && + contract.src1ValidCols < contract.dstValidCols; + + if (family == "partadd") { + if (condSrc0EqDst && condSrc1EqDst) + return buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, + src0Stride, src1Stride, dstStride, + contract.dstValidRows, contract.dstValidCols, + rewriter, loc); + if (condSrc0ColLtDst && condSrc1EqDst) { + if (failed(buildPartCopyRegion(src1Buffer, dstBuffer, vecType, src1Stride, dstStride, + 0, contract.src1ValidRows, contract.dstValidCols, + rewriter, loc))) + return failure(); + if (contract.src0ValidCols != 0) + return buildPartBinaryRegion(family, src0Buffer, dstBuffer, dstBuffer, vecType, + src0Stride, dstStride, dstStride, + contract.src0ValidRows, contract.src0ValidCols, + rewriter, loc); + return success(); + } + if (condSrc0RowLtDst && condSrc1EqDst) { + if (contract.src0ValidRows != 0 && + failed(buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, + src0Stride, src1Stride, dstStride, + contract.src0ValidRows, contract.src0ValidCols, + rewriter, loc))) + return failure(); + return buildPartCopyRegion(src1Buffer, dstBuffer, vecType, src1Stride, dstStride, + contract.src0ValidRows, contract.src1ValidRows, + contract.dstValidCols, rewriter, loc); + } + if (condSrc1ColLtDst && condSrc0EqDst) { + if (failed(buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, + 0, contract.src0ValidRows, contract.dstValidCols, + rewriter, loc))) + return failure(); + if (contract.src1ValidCols != 0) + return buildPartBinaryRegion(family, src1Buffer, dstBuffer, dstBuffer, vecType, + src1Stride, dstStride, dstStride, + contract.src1ValidRows, contract.src1ValidCols, + rewriter, loc); + return success(); + } + if (condSrc1RowLtDst && condSrc0EqDst) { + if 
(contract.src1ValidRows != 0 && + failed(buildPartBinaryRegion(family, src0Buffer, src1Buffer, dstBuffer, vecType, + src0Stride, src1Stride, dstStride, + contract.src1ValidRows, contract.src1ValidCols, + rewriter, loc))) + return failure(); + return buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, + contract.src1ValidRows, contract.src0ValidRows, + contract.dstValidCols, rewriter, loc); + } + return emitError(loc) << "partadd lowering only supports PTO-covered destination-equality/extension cases"; + } + + bool condDstGeSrc = contract.src0ValidRows <= contract.dstValidRows && + contract.src0ValidCols <= contract.dstValidCols && + contract.src1ValidRows <= contract.dstValidRows && + contract.src1ValidCols <= contract.dstValidCols; + if (!condDstGeSrc) + return emitError(loc) << family << " lowering only supports dst >= src0/src1 shape relation"; + if (failed(buildPartFill(family, contract, dstBuffer, vecType, dstStride, rewriter, loc))) + return failure(); + if (failed(buildPartCopyRegion(src0Buffer, dstBuffer, vecType, src0Stride, dstStride, + 0, contract.src0ValidRows, contract.src0ValidCols, + rewriter, loc))) + return failure(); + return buildPartBinaryRegion(family, dstBuffer, src1Buffer, dstBuffer, vecType, + dstStride, src1Stride, dstStride, + contract.src1ValidRows, contract.src1ValidCols, + rewriter, loc); +} + +LogicalResult buildUnaryVecScope(StringRef family, + const VPTOUnaryContract &contract, + VPTOLoweringStrategy strategy, Value src, + Value dst, PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO unary element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) 
<< "requires pointer-backed tile buffers for unary lowering"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(dst, validRowsValue, validColsValue); + deriveValidShape(dst, validRows, validCols); + if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, + validCols, rewriter, loc))) + return emitError(loc) << "unary lowering requires valid rows and cols"; + + int64_t srcStride = deriveStaticRowStride(src); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t srcCols = deriveStaticTileCols(src); + int64_t dstCols = deriveStaticTileCols(dst); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) << "unary lowering requires static row strides and cols"; + + auto buildUnaryValue = [&](Value loaded, Value predicate) -> FailureOr { + if (family == "abs") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "exp") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "log") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "sqrt") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "relu") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + if (family == "not") + return rewriter.create(loc, vecType, loaded, predicate).getResult(); + return failure(); + }; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value 
scalarInit = rewriter.create(loc, rewriter.getI32Type(), + totalElementsValue); + Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), + validColsValue); + Value fullWidthCond = + buildFullWidthColsCondition({srcCols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << "unary lowering could not materialize full-width selector"; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + scf::ForOp chunkLoop; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{srcBuffer, dstBuffer, scalarInit}); + } else { + chunkLoop = rewriter.create(loc, c0, totalElementsValue, + vectorStepValue, + ValueRange{scalarInit}); + } + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value remaining = chunkLoop.getRegionIterArgs().back(); + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value loadBase = srcBuffer; + Value storeBase = dstBuffer; + Value loadOffset = chunkLoop.getInductionVar(); + Value storeOffset = chunkLoop.getInductionVar(); + if (strategy == VPTOLoweringStrategy::PostUpdate) { + loadBase = chunkLoop.getRegionIterArgs()[0]; + storeBase = chunkLoop.getRegionIterArgs()[1]; + loadOffset = vectorStepValue; + storeOffset = vectorStepValue; + } + Value loaded; + Value nextSrc = {}; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto vlds = rewriter.create(loc, vecType, loadBase.getType(), + loadBase, loadOffset, StringAttr()); + loaded = 
vlds.getResult(); + nextSrc = vlds.getUpdatedSource(); + } else { + auto vlds = + rewriter.create(loc, vecType, loadBase, loadOffset, StringAttr()); + loaded = vlds.getResult(); + } + FailureOr computed = buildUnaryValue(loaded, predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO unary family: " << family; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto vsts = rewriter.create(loc, storeBase.getType(), *computed, + storeBase, storeOffset, StringAttr(), + predicateState.mask); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); + } else { + rewriter.create(loc, *computed, storeBase, storeOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + } + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcRowBase = rewriter.create(loc, row, srcStrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{rowScalarInit}); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value remaining = repeatLoop.getRegionIterArgs()[0]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value chunkBase = + rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); + Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + FailureOr computed = + buildUnaryValue(loaded.getResult(), predicateState.mask); 
+ if (failed(computed)) + return emitError(loc) << "unsupported VPTO unary family: " << family; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + } + rewriter.setInsertionPointAfter(ifOp); + + return success(); +} + +LogicalResult buildBinaryVecScope(StringRef family, + const VPTOBinaryContract &contract, + VPTOLoweringStrategy strategy, Value src0, + Value src1, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO binary element type"; + + Value src0Buffer = materializeBufferPointer(src0, contract.elementType, + getMemorySpace(src0), rewriter, loc); + Value src1Buffer = materializeBufferPointer(src1, contract.elementType, + getMemorySpace(src1), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffers for binary lowering"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue = contract.validRowsValue; + Value validColsValue = contract.validColsValue; + int64_t validRows = contract.validRows; + int64_t validCols = contract.validCols; + if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, + validCols, rewriter, loc))) + return emitError(loc) << "binary lowering requires valid rows and cols"; + + int64_t src0Stride = deriveStaticRowStride(src0); + int64_t src1Stride = deriveStaticRowStride(src1); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t src0Cols = deriveStaticTileCols(src0); + int64_t src1Cols = deriveStaticTileCols(src1); + int64_t dstCols = deriveStaticTileCols(dst); + if (src0Stride == ShapedType::kDynamic || src1Stride == ShapedType::kDynamic || + dstStride == 
ShapedType::kDynamic || src0Cols == ShapedType::kDynamic || + src1Cols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) << "binary lowering requires static row strides and cols"; + + auto buildBinaryValue = [&](Value lhs, Value rhs, Value mask) -> FailureOr { + if (family == "add") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "sub") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "mul") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "div") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "max") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "min") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "and") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "or") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + if (family == "xor") + return rewriter.create(loc, vecType, lhs, rhs, mask).getResult(); + return failure(); + }; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value src0StrideValue = rewriter.create(loc, src0Stride); + Value src1StrideValue = rewriter.create(loc, src1Stride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value scalarInit = rewriter.create(loc, rewriter.getI32Type(), + totalElementsValue); + Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), + validColsValue); + bool sameShapeLinearPath = src0Stride == dstStride && src1Stride == dstStride && + src0Cols == dstCols && src1Cols == dstCols; + Value fullWidthCond = buildFullWidthColsCondition( + {src0Cols, src1Cols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return 
emitError(loc) << "binary lowering could not materialize full-width selector"; + Value use1DCond = sameShapeLinearPath ? fullWidthCond : Value(); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto emit1DBody = [&]() -> LogicalResult { + scf::ForOp chunkLoop; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{src0Buffer, src1Buffer, dstBuffer, scalarInit}); + } else { + chunkLoop = rewriter.create(loc, c0, totalElementsValue, + vectorStepValue, + ValueRange{scalarInit}); + } + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value remaining = chunkLoop.getRegionIterArgs().back(); + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value lhsBase = src0Buffer; + Value rhsBase = src1Buffer; + Value dstBase = dstBuffer; + Value loadOffset = chunkLoop.getInductionVar(); + Value storeOffset = chunkLoop.getInductionVar(); + if (strategy == VPTOLoweringStrategy::PostUpdate) { + lhsBase = chunkLoop.getRegionIterArgs()[0]; + rhsBase = chunkLoop.getRegionIterArgs()[1]; + dstBase = chunkLoop.getRegionIterArgs()[2]; + loadOffset = vectorStepValue; + storeOffset = vectorStepValue; + } + Value lhsValue; + Value rhsValue; + Value nextSrc0 = {}; + Value nextSrc1 = {}; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto lhs = rewriter.create(loc, vecType, lhsBase.getType(), + lhsBase, loadOffset, StringAttr()); + auto rhs = rewriter.create(loc, vecType, rhsBase.getType(), + rhsBase, loadOffset, StringAttr()); + lhsValue = lhs.getResult(); + rhsValue = rhs.getResult(); + nextSrc0 = lhs.getUpdatedSource(); + nextSrc1 = rhs.getUpdatedSource(); + } else { 
+ auto lhs = + rewriter.create(loc, vecType, lhsBase, loadOffset, StringAttr()); + auto rhs = + rewriter.create(loc, vecType, rhsBase, loadOffset, StringAttr()); + lhsValue = lhs.getResult(); + rhsValue = rhs.getResult(); + } + FailureOr computed = buildBinaryValue(lhsValue, rhsValue, predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO binary family: " << family; + if (strategy == VPTOLoweringStrategy::PostUpdate) { + auto vsts = rewriter.create(loc, dstBase.getType(), *computed, + dstBase, storeOffset, StringAttr(), + predicateState.mask); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, + ValueRange{nextSrc0, nextSrc1, nextDst, predicateState.nextScalar}); + } else { + rewriter.create(loc, *computed, dstBase, storeOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + } + return success(); + }; + + auto emit2DBody = [&]() -> LogicalResult { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value src0RowBase = rewriter.create(loc, row, src0StrideValue); + Value src1RowBase = rewriter.create(loc, row, src1StrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{rowScalarInit}); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value remaining = repeatLoop.getRegionIterArgs()[0]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value chunkBase = + rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); + Value src0Offset = rewriter.create(loc, src0RowBase, chunkBase); + Value src1Offset = rewriter.create(loc, 
src1RowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + auto lhs = rewriter.create(loc, vecType, src0Buffer, src0Offset, + StringAttr()); + auto rhs = rewriter.create(loc, vecType, src1Buffer, src1Offset, + StringAttr()); + FailureOr computed = + buildBinaryValue(lhs.getResult(), rhs.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO binary family: " << family; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + return success(); + } + + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value chunkBase = + rewriter.create(loc, repeatLoop.getInductionVar(), vectorStepValue); + Value src0Offset = rewriter.create(loc, src0RowBase, chunkBase); + Value src1Offset = rewriter.create(loc, src1RowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + Value nextChunk = rewriter.create(loc, chunkBase, vectorStepValue); + Value exceeds = + rewriter.create(loc, arith::CmpIPredicate::sge, nextChunk, validColsValue); + Value tailCount = rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + rewriter.create(loc, exceeds, tailCount, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount(rewriter, loc, + contract.elementType, activeLanes); + auto lhs = + rewriter.create(loc, vecType, src0Buffer, src0Offset, StringAttr()); + auto rhs = + rewriter.create(loc, vecType, src1Buffer, src1Offset, StringAttr()); + FailureOr computed = buildBinaryValue(lhs.getResult(), rhs.getResult(), predicate); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO binary family: " << family; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicate); + return success(); + }; + + if (use1DCond) { + auto ifOp = rewriter.create(loc, TypeRange{}, 
use1DCond, + /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + if (failed(emit1DBody())) + return failure(); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + if (failed(emit2DBody())) + return failure(); + rewriter.setInsertionPointAfter(ifOp); + } else { + if (failed(emit2DBody())) + return failure(); + } + return success(); +} + +LogicalResult buildExpandScalarVecScope(const VPTOUnaryContract &contract, + Value scalar, Value dst, + PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO expands element type"; + + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!dstBuffer) + return emitError(loc) << "requires pointer-backed tile buffer for expands lowering"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "expands lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t dstCols = deriveStaticTileCols(dst); + if (dstStride == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) << "expands lowering requires static destination row stride and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value fullWidthCond = + buildFullWidthColsCondition({dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + 
return emitError(loc) << "expands lowering could not materialize full-width selector"; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + Value scalarInit = rewriter.create( + loc, rewriter.getI32Type(), totalElementsValue); + auto chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{dstBuffer, scalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value dstPtr = chunkLoop.getRegionIterArgs()[0]; + Value remaining = chunkLoop.getRegionIterArgs()[1]; + PredicateMaterialization predicateState = buildPredicateForLaneCount( + rewriter, loc, contract.elementType, remaining); + Value computed = + rewriter.create(loc, vecType, scalar, predicateState.mask, StringAttr()); + auto vsts = rewriter.create(loc, dstPtr.getType(), computed, dstPtr, + vectorStepValue, StringAttr(), + predicateState.mask); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextDst, predicateState.nextScalar}); + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value rowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value repeat = repeatLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, repeat, 
vectorStepValue); + Value dstOffset = rewriter.create(loc, rowBase, chunkBase); + Value remainingCols = + rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + Value computed = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + rewriter.create(loc, computed, dstBuffer, dstOffset, + StringAttr(), predicate); + } + + rewriter.setInsertionPointAfter(ifOp); + return success(); +} + +LogicalResult buildScalarUnaryVecScope(StringRef family, + const VPTOUnaryContract &contract, + VPTOLoweringStrategy strategy, + Value src, Value scalar, Value dst, + PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO scalar-unary element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for scalar-unary lowering"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << family << " lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t srcStride = deriveStaticRowStride(src); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t srcCols = deriveStaticTileCols(src); + int64_t dstCols = deriveStaticTileCols(dst); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + 
srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) + << family << " lowering requires static src/dst row stride and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value fullWidthCond = buildFullWidthColsCondition( + {srcCols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << family << " lowering could not materialize full-width selector"; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + auto emitComputed = [&](Value loadedVec, Value predicate) -> FailureOr { + if (family == "adds") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "maxs") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "mins") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "muls") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + if (family == "lrelu") + return rewriter.create(loc, vecType, loadedVec, scalar, predicate).getResult(); + return failure(); + }; + + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto chunkLoop = + rewriter.create(loc, c0, totalElementsValue, vectorStepValue); + 
rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value remaining = rewriter.create(loc, totalElementsValue, offset); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remaining, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, offset, StringAttr()); + FailureOr computed = emitComputed(loaded.getResult(), predicate); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; + rewriter.create(loc, *computed, dstBuffer, offset, StringAttr(), + predicate); + } else { + Value scalarInit = rewriter.create( + loc, rewriter.getI32Type(), totalElementsValue); + auto chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{srcBuffer, dstBuffer, scalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value srcPtr = chunkLoop.getRegionIterArgs()[0]; + Value dstPtr = chunkLoop.getRegionIterArgs()[1]; + Value remaining = chunkLoop.getRegionIterArgs()[2]; + PredicateMaterialization predicateState = buildPredicateForLaneCount( + rewriter, loc, contract.elementType, remaining); + auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, + vectorStepValue, StringAttr()); + FailureOr computed = emitComputed(loaded.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; + auto vsts = rewriter.create(loc, dstPtr.getType(), *computed, dstPtr, + vectorStepValue, StringAttr(), + predicateState.mask); + Value nextSrc = loaded.getUpdatedSource(); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); + } + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, 
validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcRowBase = rewriter.create(loc, row, srcStrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value repeat = repeatLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); + Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + Value predicate; + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + predicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, validColsValue); + } else { + Value remainingCols = + rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); + predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + } + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + Value computed; + if (family == "adds") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "maxs") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "mins") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "muls") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else if (family == "lrelu") + computed = rewriter.create(loc, vecType, loaded.getResult(), scalar, predicate); + else + return emitError(loc) << "unsupported VPTO scalar-unary family: " << family; + rewriter.create(loc, computed, dstBuffer, dstOffset, + StringAttr(), predicate); + } + + 
rewriter.setInsertionPointAfter(ifOp); + return success(); +} + +LogicalResult buildScalarBitwiseVecScope(StringRef family, + const VPTOUnaryContract &contract, + Value src, Value scalar, Value dst, + PatternRewriter &rewriter, + Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO scalar-bitwise element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for scalar-bitwise lowering"; + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(dst, validRowsValue, validColsValue); + deriveValidShape(dst, validRows, validCols); + if (failed(resolveExecutionValidShape(dst, validRowsValue, validColsValue, validRows, + validCols, rewriter, loc))) + return emitError(loc) << family << " lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value vectorWidthValue = + rewriter.create(loc, vectorWidth); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto chunkLoop = + rewriter.create(loc, c0, totalElementsValue, vectorStepValue); + + OpBuilder::InsertionGuard chunkGuard(rewriter); + 
rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value remaining = rewriter.create(loc, totalElementsValue, offset); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remaining, vectorWidthValue); + Value predicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, activeLanes); + Value scalarVec = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + auto loaded = rewriter.create(loc, vecType, srcBuffer, offset, + StringAttr()); + + Value computed; + if (family == "ands") + computed = + rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); + else if (family == "ors") + computed = + rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); + else if (family == "xors") + computed = + rewriter.create(loc, vecType, loaded.getResult(), scalarVec, predicate); + else + return emitError(loc) << "unsupported VPTO scalar-bitwise family: " << family; + rewriter.create(loc, computed, dstBuffer, offset, StringAttr(), + predicate); + return success(); +} + +static bool isVPTOShapedLikeValue(Value value) { + Type type = value.getType(); + return isa(type); +} + +LogicalResult buildScalarDivVecScope(const VPTOUnaryContract &contract, + VPTOLoweringStrategy strategy, + Value src, Value scalar, Value dst, + bool scalarFirst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO divs element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for divs lowering"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + 
contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "divs lowering requires valid rows and cols"; + + int64_t vectorWidth = vecType.getElementCount(); + int64_t srcStride = deriveStaticRowStride(src); + int64_t dstStride = deriveStaticRowStride(dst); + int64_t srcCols = deriveStaticTileCols(src); + int64_t dstCols = deriveStaticTileCols(dst); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic || + srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) + << "divs lowering requires static src/dst row stride and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value totalElementsValue = + rewriter.create(loc, validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value srcStrideValue = rewriter.create(loc, srcStride); + Value dstStrideValue = rewriter.create(loc, dstStride); + Value fullWidthCond = buildFullWidthColsCondition( + {srcCols, dstCols}, validColsValue, rewriter, loc); + if (!fullWidthCond) + return emitError(loc) << "divs lowering could not materialize full-width selector"; + + auto buildDivValue = [&](Value loaded, Value predicate) -> FailureOr { + if (contract.elementType.isF32()) { + if (scalarFirst) { + Value scalarVec = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + return rewriter.create(loc, vecType, scalarVec, loaded, predicate) + .getResult(); + } + Value one = rewriter.create( + loc, contract.elementType, + rewriter.getFloatAttr(contract.elementType, 1.0)); + Value reciprocal = rewriter.create(loc, one, scalar); + return rewriter.create(loc, vecType, loaded, reciprocal, predicate).getResult(); + } + if (contract.elementType.isF16()) { + Value scalarVec = + rewriter.create(loc, vecType, scalar, predicate, StringAttr()); + 
return scalarFirst + ? rewriter.create(loc, vecType, scalarVec, loaded, predicate) + .getResult() + : rewriter.create(loc, vecType, loaded, scalarVec, predicate) + .getResult(); + } + return failure(); + }; + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto ifOp = rewriter.create(loc, TypeRange{}, fullWidthCond, + /*withElseRegion=*/true); + + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + { + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto chunkLoop = + rewriter.create(loc, c0, totalElementsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + Value remaining = rewriter.create(loc, totalElementsValue, offset); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remaining, vectorStepValue); + Value predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, offset, StringAttr()); + FailureOr computed = buildDivValue(loaded.getResult(), predicate); + if (failed(computed)) + return emitError(loc) + << "divs lowering currently supports only f16 and f32 element types"; + rewriter.create(loc, *computed, dstBuffer, offset, StringAttr(), + predicate); + } else { + Value scalarInit = rewriter.create( + loc, rewriter.getI32Type(), totalElementsValue); + auto chunkLoop = rewriter.create( + loc, c0, totalElementsValue, vectorStepValue, + ValueRange{srcBuffer, dstBuffer, scalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value srcPtr = chunkLoop.getRegionIterArgs()[0]; + Value dstPtr = chunkLoop.getRegionIterArgs()[1]; + Value remaining = chunkLoop.getRegionIterArgs()[2]; + 
PredicateMaterialization predicateState = buildPredicateForLaneCount( + rewriter, loc, contract.elementType, remaining); + auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, + vectorStepValue, StringAttr()); + FailureOr computed = buildDivValue(loaded.getResult(), predicateState.mask); + if (failed(computed)) + return emitError(loc) + << "divs lowering currently supports only f16 and f32 element types"; + auto vsts = rewriter.create(loc, dstPtr.getType(), *computed, dstPtr, + vectorStepValue, StringAttr(), + predicateState.mask); + Value nextSrc = loaded.getUpdatedSource(); + Value nextDst = vsts.getUpdatedDestination(); + rewriter.create( + loc, ValueRange{nextSrc, nextDst, predicateState.nextScalar}); + } + } + + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + { + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcRowBase = rewriter.create(loc, row, srcStrideValue); + Value dstRowBase = rewriter.create(loc, row, dstStrideValue); + auto repeatLoop = rewriter.create(loc, c0, repeatUpper, c1); + rewriter.setInsertionPointToStart(repeatLoop.getBody()); + Value repeat = repeatLoop.getInductionVar(); + Value chunkBase = rewriter.create(loc, repeat, vectorStepValue); + Value srcOffset = rewriter.create(loc, srcRowBase, chunkBase); + Value dstOffset = rewriter.create(loc, dstRowBase, chunkBase); + Value predicate; + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + predicate = + buildPredicateMaskForLaneCount(rewriter, loc, contract.elementType, validColsValue); + } else { + Value remainingCols = + rewriter.create(loc, validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, loc, remainingCols, vectorStepValue); + predicate = buildPredicateMaskForLaneCount( + rewriter, loc, contract.elementType, activeLanes); 
+ } + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + FailureOr computed = buildDivValue(loaded.getResult(), predicate); + if (failed(computed)) + return emitError(loc) + << "divs lowering currently supports only f16 and f32 element types"; + rewriter.create(loc, *computed, dstBuffer, dstOffset, + StringAttr(), predicate); + } + + rewriter.setInsertionPointAfter(ifOp); + return success(); +} + +LogicalResult checkExpandContract(Operation *op, + const VPTOExpandContract &contract) { + bool hasPrecheckFailure = false; + if (contract.srcDomain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family + << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.srcLayout != "row_major" || contract.dstLayout != "row_major") { + op->emitOpError() << contract.family + << " lowering requires row-major source and destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || + (!contract.elementType.isF16() && !contract.elementType.isF32())) { + op->emitOpError() << contract.family + << " lowering currently supports only f16 and f32 element types"; + hasPrecheckFailure = true; + } + auto isStatic = [](int64_t value) { return value != ShapedType::kDynamic; }; + if (!isStatic(contract.srcValidRows) || !isStatic(contract.srcValidCols) || + !isStatic(contract.dstValidRows) || !isStatic(contract.dstValidCols)) { + op->emitOpError() << contract.family + << " lowering currently requires static source and destination valid shapes"; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult buildRowExpandVecScope(const VPTOExpandContract &contract, + VPTOLoweringStrategy strategy, Value src, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO rowexpand 
element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for rowexpand lowering"; + + auto [srcRows, srcCols] = getStaticTileRowsCols(src); + auto [dstRows, dstCols] = getStaticTileRowsCols(dst); + if (srcCols == ShapedType::kDynamic || dstCols == ShapedType::kDynamic || + srcRows == ShapedType::kDynamic || dstRows == ShapedType::kDynamic) + return emitError(loc) << "rowexpand lowering requires static physical tile shape"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue = materializeIndexValue( + contract.dstValidRowsValue, contract.dstValidRows, rewriter, loc); + Value validColsValue = materializeIndexValue( + contract.dstValidColsValue, contract.dstValidCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << "rowexpand lowering requires valid rows and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value srcStrideValue = rewriter.create(loc, srcCols); + Value dstStrideValue = rewriter.create(loc, dstCols); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + Value rowScalarInit = rewriter.create(loc, rewriter.getI32Type(), + validColsValue); + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + Value repeatUpper = rewriter.create(loc, validColsValue, + vectorStepValue); + if (strategy == VPTOLoweringStrategy::NoPostUpdate) { + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + 
rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value srcOffset = rewriter.create(loc, row, srcStrideValue); + Value dstBase = rewriter.create(loc, row, dstStrideValue); + auto loaded = + rewriter.create(loc, vecType, srcBuffer, srcOffset, StringAttr()); + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + Value expanded = rewriter.create( + loc, vecType, loaded.getResult(), fullMask, rewriter.getStringAttr("LOWEST")); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{rowScalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value remaining = chunkLoop.getRegionIterArgs()[0]; + PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + Value chunkBase = + rewriter.create(loc, chunkLoop.getInductionVar(), vectorStepValue); + Value dstOffset = rewriter.create(loc, dstBase, chunkBase); + rewriter.create(loc, expanded, dstBuffer, dstOffset, StringAttr(), + predicateState.mask); + rewriter.create(loc, ValueRange{predicateState.nextScalar}); + return success(); + } + + auto rowLoop = + rewriter.create(loc, c0, validRowsValue, c1, ValueRange{srcBuffer, dstBuffer}); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value srcPtr = rowLoop.getRegionIterArgs()[0]; + Value dstPtr = rowLoop.getRegionIterArgs()[1]; + auto loaded = rewriter.create(loc, vecType, srcPtr.getType(), srcPtr, + srcStrideValue, StringAttr()); + Value fullMask = buildAllPredicateMask(rewriter, loc, contract.elementType); + Value expanded = rewriter.create( + loc, vecType, loaded.getResult(), fullMask, rewriter.getStringAttr("LOWEST")); + auto chunkLoop = rewriter.create(loc, c0, repeatUpper, c1, + ValueRange{dstPtr, rowScalarInit}); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value dstChunkPtr = chunkLoop.getRegionIterArgs()[0]; + Value remaining = chunkLoop.getRegionIterArgs()[1]; + 
PredicateMaterialization predicateState = + buildPredicateForLaneCount(rewriter, loc, contract.elementType, remaining); + auto vsts = rewriter.create(loc, dstChunkPtr.getType(), expanded, + dstChunkPtr, vectorStepValue, StringAttr(), + predicateState.mask); + Value nextDstChunkPtr = vsts.getUpdatedDestination(); + rewriter.create(loc, ValueRange{nextDstChunkPtr, predicateState.nextScalar}); + + rewriter.setInsertionPointAfter(chunkLoop); + Value rowAdvance = rewriter.create(loc, repeatUpper, vectorStepValue); + Value dstPad = rewriter.create(loc, dstStrideValue, rowAdvance); + Value nextDstPtr = + offsetBufferPointer(dstPtr, contract.elementType, dstPad, rewriter, loc); + Value nextSrcPtr = loaded.getUpdatedSource(); + rewriter.create(loc, ValueRange{nextSrcPtr, nextDstPtr}); + return success(); +} + +LogicalResult buildColExpandVecScope(const VPTOExpandContract &contract, + Value src, Value dst, + PatternRewriter &rewriter, Location loc) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << "unsupported VPTO colexpand element type"; + + Value srcBuffer = materializeBufferPointer(src, contract.elementType, + getMemorySpace(src), rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, contract.elementType, + getMemorySpace(dst), rewriter, loc); + if (!srcBuffer || !dstBuffer) + return emitError(loc) + << "requires pointer-backed tile buffers for colexpand lowering"; + + auto [dstRows, dstCols] = getStaticTileRowsCols(dst); + if (dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return emitError(loc) + << "colexpand lowering requires static physical destination tile shape"; + + int64_t vectorWidth = vecType.getElementCount(); + Value validRowsValue = materializeIndexValue( + contract.dstValidRowsValue, contract.dstValidRows, rewriter, loc); + Value validColsValue = materializeIndexValue( + contract.dstValidColsValue, contract.dstValidCols, rewriter, loc); + if 
(!validRowsValue || !validColsValue) + return emitError(loc) << "colexpand lowering requires valid rows and cols"; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value dstStrideValue = rewriter.create(loc, dstCols); + Value vectorStepValue = + rewriter.create(loc, vectorWidth); + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(loc, c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(loc, c0, validColsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value dstBase = + rewriter.create(loc, rowLoop.getInductionVar(), dstStrideValue); + Value dstOffset = + rewriter.create(loc, dstBase, chunkLoop.getInductionVar()); + auto loaded = rewriter.create( + loc, vecType, srcBuffer, chunkLoop.getInductionVar(), StringAttr()); + rewriter.create(loc, loaded.getResult(), dstBuffer, dstOffset, + StringAttr(), + buildAllPredicateMask(rewriter, loc, + contract.elementType)); + return success(); +} + +LogicalResult checkGenericUnaryContract(Operation *op, + const VPTOUnaryContract &contract, + Value dst, + function_ref typePredicate, + StringRef supportedTypeText) { + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, dstCols); + StringRef dstLayout = deriveTileLayout(dst); + VPTOTileDomain dstDomain = deriveTileDomain(getMemorySpace(dst)); + + bool hasPrecheckFailure = false; + if (contract.tileDomain != VPTOTileDomain::Vec || dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires tile domain vec"; + hasPrecheckFailure = true; + } + if (contract.tileLayout != "row_major" 
|| dstLayout != "row_major") { + op->emitOpError() << contract.family << " lowering requires row-major tile layout"; + hasPrecheckFailure = true; + } + if (contract.validRows != ShapedType::kDynamic && + dstRows != ShapedType::kDynamic && dstRows > contract.validRows) { + op->emitOpError() << contract.family + << " lowering requires destination valid rows not to exceed source"; + hasPrecheckFailure = true; + } + if (contract.validCols != ShapedType::kDynamic && + dstCols != ShapedType::kDynamic && dstCols > contract.validCols) { + op->emitOpError() << contract.family + << " lowering requires destination valid cols not to exceed source"; + hasPrecheckFailure = true; + } + if (!contract.elementType || !typePredicate(contract.elementType)) { + op->emitOpError() + << contract.family << " lowering supports only " << supportedTypeText; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkGenericBinaryContract( + Operation *op, const VPTOBinaryContract &contract, Value src1, Value dst, + function_ref typePredicate, StringRef supportedTypeText) { + StringRef src1Layout = deriveTileLayout(src1); + StringRef dstLayout = deriveTileLayout(dst); + VPTOTileDomain src1Domain = deriveTileDomain(getMemorySpace(src1)); + VPTOTileDomain dstDomain = deriveTileDomain(getMemorySpace(dst)); + + bool hasPrecheckFailure = false; + if (contract.tileDomain != VPTOTileDomain::Vec || src1Domain != VPTOTileDomain::Vec || + dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires tile domain vec"; + hasPrecheckFailure = true; + } + if (contract.tileLayout != "row_major" || src1Layout != "row_major" || + dstLayout != "row_major") { + op->emitOpError() << contract.family << " lowering requires row-major tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || !typePredicate(contract.elementType)) { + op->emitOpError() + << contract.family << " lowering supports only " << supportedTypeText; + 
hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkRowReduceContract(Operation *op, + const VPTORowReduceContract &contract, + Value dst) { + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, dstCols); + + bool hasPrecheckFailure = false; + if (contract.srcDomain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.srcLayout != "row_major") { + op->emitOpError() << contract.family << " lowering requires row-major source tile layout"; + hasPrecheckFailure = true; + } + if (contract.dstLayout != "row_major" && contract.dstLayout != "col_major") { + op->emitOpError() << contract.family + << " lowering requires row-major or col-major destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || (!contract.elementType.isF16() && !contract.elementType.isF32())) { + op->emitOpError() << contract.family << " lowering supports only f16 and f32 element types"; + hasPrecheckFailure = true; + } + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) { + op->emitOpError() << contract.family + << " lowering currently requires static source valid rows and cols"; + hasPrecheckFailure = true; + } + if (contract.validRows != dstRows) { + op->emitOpError() << contract.family + << " lowering requires destination valid rows to match source valid rows"; + hasPrecheckFailure = true; + } + if (dstCols != 1) { + op->emitOpError() << contract.family + << " lowering requires destination valid cols to equal 1"; + hasPrecheckFailure = true; + } + if (contract.dstLayout == "col_major") { + auto [dstRowsPhysical, dstColsPhysical] = getStaticTileRowsCols(dst); + (void)dstRowsPhysical; + if (dstColsPhysical != 1) { + op->emitOpError() << contract.family + << " 
lowering requires col-major destinations to use physical cols == 1"; + hasPrecheckFailure = true; + } + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkColReduceContract(Operation *op, + const VPTOColReduceContract &contract, + Value dst) { + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(dst, dstRows, dstCols); + + bool hasPrecheckFailure = false; + if (contract.srcDomain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.srcLayout != "row_major" || contract.dstLayout != "row_major") { + op->emitOpError() << contract.family + << " lowering requires row-major source and destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType || + (!contract.elementType.isF16() && !contract.elementType.isF32())) { + op->emitOpError() << contract.family << " lowering supports only f16 and f32 element types"; + hasPrecheckFailure = true; + } + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) { + op->emitOpError() << contract.family + << " lowering currently requires static source valid rows and cols"; + hasPrecheckFailure = true; + } + if (dstRows != 1) { + op->emitOpError() << contract.family + << " lowering requires destination valid rows to equal 1"; + hasPrecheckFailure = true; + } + if (dstCols != contract.validCols) { + op->emitOpError() << contract.family + << " lowering requires destination valid cols to match source valid cols"; + hasPrecheckFailure = true; + } + if (contract.isBinary && !contract.tmp) { + op->emitOpError() << contract.family << " lowering requires tmp for binary path"; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult checkPartContract(Operation *op, const VPTOPartContract &contract) { + bool hasPrecheckFailure = 
false; + if (contract.src0Domain != VPTOTileDomain::Vec || + contract.src1Domain != VPTOTileDomain::Vec || + contract.dstDomain != VPTOTileDomain::Vec) { + op->emitOpError() << contract.family << " lowering requires vec source and destination"; + hasPrecheckFailure = true; + } + if (contract.src0Layout != "row_major" || contract.src1Layout != "row_major" || + contract.dstLayout != "row_major") { + op->emitOpError() << contract.family + << " lowering requires row-major source and destination tile layout"; + hasPrecheckFailure = true; + } + if (!contract.elementType) + hasPrecheckFailure = true; + else if (contract.family == "partadd") { + bool ok = contract.elementType.isF16() || contract.elementType.isF32() || + contract.elementType.isBF16(); + if (auto intType = dyn_cast(contract.elementType)) + ok = intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + if (!ok) { + op->emitOpError() << contract.family + << " lowering supports f16, f32, bf16, and 8/16/32-bit integers"; + hasPrecheckFailure = true; + } + } else { + bool ok = contract.elementType.isF16() || contract.elementType.isF32() || + contract.elementType.isBF16(); + if (auto intType = dyn_cast(contract.elementType)) + ok = intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + if (!ok) { + op->emitOpError() << contract.family + << " lowering supports f16, f32, bf16, and 8/16/32-bit integers"; + hasPrecheckFailure = true; + } + } + auto allStatic = [&](int64_t a, int64_t b) { + return a != ShapedType::kDynamic && b != ShapedType::kDynamic; + }; + if (!allStatic(contract.src0ValidRows, contract.src0ValidCols) || + !allStatic(contract.src1ValidRows, contract.src1ValidCols) || + !allStatic(contract.dstValidRows, contract.dstValidCols)) { + op->emitOpError() << contract.family + << " lowering currently requires static source and destination valid shapes"; + hasPrecheckFailure = true; + } + return failure(hasPrecheckFailure); +} + +LogicalResult 
lowerTLOAD(TLoadOp op, PatternRewriter &rewriter) { + VPTOLoadContract contract = extractTLoadContract(op); + if (contract.tileDomain != VPTOTileDomain::Vec) + return op.emitOpError("currently supports only VEC TLOAD lowering"); + + ResolvedTensorView sourceView; + if (!resolveTensorView(op.getSrc(), sourceView, rewriter, op.getLoc())) + return op.emitOpError("requires a recoverable source tensor view for VPTO lowering"); + + StringRef sourceLayout = + inferVecTransferLayoutFromTile(stringifyLayoutAttr(sourceView.layoutAttr), + contract.tileLayout); + bool isNdLoad = contract.tileLayout == "row_major" && sourceLayout == "nd"; + bool isDnLoad = contract.tileLayout == "col_major" && sourceLayout == "dn"; + if (!isNdLoad && !isDnLoad) + return op.emitOpError("currently supports only ND row_major or DN col_major vec TLOAD lowering"); + + Value sourceBuffer = + materializeBufferPointer(sourceView.root, getElementType(sourceView.root), + getGmMemorySpace(rewriter.getContext()), rewriter, + op.getLoc()); + Value destinationBuffer = + materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, op.getLoc()); + if (!sourceBuffer || !destinationBuffer) + return op.emitOpError("requires A5-compatible source and destination buffers"); + + auto [tileRows, tileCols] = getStaticTileRowsCols(op.getDst()); + (void)tileRows; + bool ubPad = contract.padMode != "none" || contract.padValue || + contract.leftPaddingNum || contract.rightPaddingNum; + Value validRowsValue = + materializeI64Value(contract.validRowsValue, contract.validRows, rewriter, + op.getLoc()); + Value validColsValue = + materializeI64Value(contract.validColsValue, contract.validCols, rewriter, + op.getLoc()); + Value sidValue = rewriter.create(op.getLoc(), 0, 64); + int64_t elemBytes = getElementByteSize(contract.elementType); + if ((isNdLoad && tileCols == ShapedType::kDynamic) || + (isDnLoad && tileRows == ShapedType::kDynamic) || elemBytes <= 0) + return 
op.emitOpError("requires static tile shape for A5-compatible transfer arguments"); + VecNdTransferPlan plan; + LogicalResult planResult = + isNdLoad ? buildVecNdLoadPlan(sourceView.shape, sourceView.strides, tileCols, + contract.validColsValue, contract.validCols, + contract.elementType, rewriter, op.getLoc(), plan) + : buildVecDnLoadPlan(sourceView.shape, sourceView.strides, tileRows, + contract.validRowsValue, contract.validRows, + contract.elementType, rewriter, op.getLoc(), plan); + if (failed(planResult)) + return op.emitOpError("requires PTO-compatible vec copy_gm_to_ubuf arguments"); + Value leftPaddingValue = rewriter.create(op.getLoc(), 0, 64); + Value rightPaddingValue = rewriter.create(op.getLoc(), 0, 64); + Value cacheCtlValue = rewriter.create(op.getLoc(), 0, 64); + if (!validRowsValue || !validColsValue) + return op.emitOpError("requires valid rows and cols for A5-compatible transfer arguments"); + Value sourceOffset = + materializeI64Ofr(sourceView.offsetElems, rewriter, op.getLoc()); + if (!sourceOffset) + return op.emitOpError("requires a materializable source offset for VPTO lowering"); + Value sourceBase = adjustPointerByElemOffset(sourceBuffer, sourceOffset, elemBytes, + rewriter, op.getLoc()); + if (!sourceBase) + return op.emitOpError("failed to materialize source base pointer"); + + rewriter.create( + op.getLoc(), plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); + rewriter.create( + op.getLoc(), plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); + rewriter.create(op.getLoc(), plan.loop2Size, + plan.loop1Size); + + auto emitCopy = [&](Value srcPtr, Value dstPtr) { + Type transferElementType = + getCopyTransferElementType(contract.elementType, rewriter); + Value typedSrcPtr = + castPtrToElementType(srcPtr, transferElementType, rewriter, op.getLoc()); + Value typedDstPtr = + castPtrToElementType(dstPtr, transferElementType, rewriter, op.getLoc()); + if (!typedSrcPtr || !typedDstPtr) + return failure(); + Value dataSelectBitValue 
= + rewriter.create(op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(ubPad)); + rewriter.create( + op.getLoc(), typedSrcPtr, typedDstPtr, sidValue, plan.nBurst, + plan.lenBurst, leftPaddingValue, rightPaddingValue, dataSelectBitValue, + cacheCtlValue, plan.firstStrideBytes, plan.secondStrideBytes); + return success(); + }; + + if (std::optional outerConst = getConstInt(plan.outerCount); outerConst && *outerConst == 1) { + return emitCopy(sourceBase, destinationBuffer); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value outerUpper = + rewriter.create(op.getLoc(), rewriter.getIndexType(), + plan.outerCount); + auto outerLoop = rewriter.create(op.getLoc(), c0, outerUpper, c1); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + Value ivI64 = rewriter.create(op.getLoc(), rewriter.getI64Type(), + outerLoop.getInductionVar()); + Value srcStep = createI64Mul(ivI64, plan.outerSrcStrideElems, rewriter, op.getLoc()); + Value dstStep = createI64Mul(ivI64, plan.outerDstStrideElems, rewriter, op.getLoc()); + Value iterSrc = adjustPointerByElemOffset(sourceBase, srcStep, elemBytes, rewriter, + op.getLoc()); + Value iterDst = adjustPointerByElemOffset(destinationBuffer, dstStep, elemBytes, rewriter, + op.getLoc()); + return emitCopy(iterSrc, iterDst); +} + +LogicalResult lowerTABS(TAbsOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTAbsContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + + return buildUnaryVecScope("abs", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTADD(TAddOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTAddContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, 
contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("add", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTSUB(TSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTSubContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("sub", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTMUL(TMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTMulContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + 
return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("mul", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTDIV(TDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = extractTDivContract(op); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 16 || intType.getWidth() == 32; + return false; + }, + "f16, f32, and 16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("div", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTMAX(TMaxOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("max", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("max", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult 
lowerTMIN(TMinOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("min", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("min", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTAND(TAndOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("and", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("and", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTANDS(TAndSOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "tands"); +} + +LogicalResult lowerTOR(TOrOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("or", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, 
contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("or", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTORS(TOrSOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "tors"); +} + +LogicalResult lowerTXOR(TXorOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOBinaryContract contract = buildBinaryContract("xor", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + if (failed(checkGenericBinaryContract( + op, contract, op.getSrc1(), op.getDst(), + [](Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "8/16/32-bit integer element types"))) + return failure(); + return buildBinaryVecScope("xor", contract, strategy, op.getSrc0(), + op.getSrc1(), op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTXORS(TXorSOp op, PatternRewriter &rewriter) { + return emitUnresolvedInstalledA5BaselineError(op, "txors"); +} + +LogicalResult lowerTEXP(TExpOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTExpContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("exp", contract, strategy, op.getSrc(), op.getDst(), + 
rewriter, op.getLoc()); +} + +LogicalResult lowerTLOG(TLogOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTLogContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("log", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTSQRT(TSqrtOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTSqrtContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("sqrt", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTRSQRT(TRsqrtOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = buildUnaryContract("rsqrt", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return op.emitOpError("trsqrt lowering requires a supported VPTO vector element type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("trsqrt lowering requires pointer-backed tile buffers"); + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, op.getLoc()); + Value validColsValue = 
materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, op.getLoc()); + if (!validRowsValue || !validColsValue) + return op.emitOpError("trsqrt lowering requires valid rows and cols"); + + int64_t srcRowStride = deriveStaticRowStride(op.getSrc()); + int64_t dstRowStride = deriveStaticRowStride(op.getDst()); + if (srcRowStride == ShapedType::kDynamic || dstRowStride == ShapedType::kDynamic) + return op.emitOpError("trsqrt lowering requires static row-major row strides"); + + int64_t vectorWidth = vecType.getElementCount(); + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value srcRowStrideValue = + rewriter.create(op.getLoc(), srcRowStride); + Value dstRowStrideValue = + rewriter.create(op.getLoc(), dstRowStride); + Value vectorStepValue = + rewriter.create(op.getLoc(), vectorWidth); + TypedAttr oneAttr = FloatAttr::get(contract.elementType, 1.0); + Value one = rewriter.create(op.getLoc(), oneAttr); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), vecType.getElementType()); + auto ones = + rewriter.create(op.getLoc(), vecType, one, fullMask, StringAttr()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, vectorStepValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value srcRowBase = rewriter.create( + op.getLoc(), rowLoop.getInductionVar(), srcRowStrideValue); + Value dstRowBase = rewriter.create( + op.getLoc(), rowLoop.getInductionVar(), dstRowStrideValue); + Value chunkOffset = chunkLoop.getInductionVar(); + Value 
srcOffset = + rewriter.create(op.getLoc(), srcRowBase, chunkOffset); + Value dstOffset = + rewriter.create(op.getLoc(), dstRowBase, chunkOffset); + Value remaining = rewriter.create(op.getLoc(), validColsValue, chunkOffset); + Value predicate = + buildPredicateMaskForLaneCount(rewriter, op.getLoc(), contract.elementType, remaining); + auto loaded = rewriter.create(op.getLoc(), vecType, srcBuffer, + srcOffset, StringAttr()); + auto sqrt = rewriter.create(op.getLoc(), vecType, loaded.getResult(), + predicate); + auto result = rewriter.create(op.getLoc(), vecType, ones.getResult(), + sqrt.getResult(), predicate); + rewriter.create( + op.getLoc(), result.getResult(), dstBuffer, dstOffset, StringAttr(), predicate); + return success(); +} + +LogicalResult lowerTRECIP(TRecipOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTRecipContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, "f16 and f32 element types"))) + return failure(); + return buildUnaryVecScope("recip", contract, strategy, op.getSrc(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTNEG(TNegOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("muls", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 16 || intType.getWidth() == 32; + return false; + }, + "f16, f32, and 16/32-bit integer element types"))) + return failure(); + + TypedAttr negOneAttr; + if (contract.elementType.isF16()) + negOneAttr = FloatAttr::get(contract.elementType, -1.0); + else if (contract.elementType.isF32()) + negOneAttr = FloatAttr::get(contract.elementType, -1.0); + else if (auto intType = dyn_cast(contract.elementType)) + negOneAttr = 
IntegerAttr::get(intType, -1); + else + return op.emitOpError("tneg lowering requires scalar element type"); + + Value negOne = rewriter.create(op.getLoc(), negOneAttr); + return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc(), negOne, + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTLRELU(TLReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = buildUnaryContract("lrelu", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32(); }, + "f16 and f32 element types"))) + return failure(); + if (op.getSlope().getType() != contract.elementType) + return op.emitOpError("tlrelu lowering requires slope type to match source element type"); + return buildScalarUnaryVecScope("lrelu", contract, strategy, op.getSrc(), op.getSlope(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTCVT(TCvtOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = buildUnaryContract("cvt", op.getSrc()); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { return type.isF16() || type.isF32() || type.isBF16(); }, + "f16, f32, or bf16 element type"))) + return failure(); + + Type dstElementType = getElementType(op.getDst()); + FailureOr loweringKind = + classifyA5CvtLowering(contract.elementType, dstElementType); + if (failed(loweringKind)) + return op.emitOpError( + "current tcvt lowering supports only f32->f32, f32->bf16, f16->f32, and bf16->f32"); + + FailureOr roundMode = stringifyA5RoundMode(op, rewriter); + if (failed(roundMode)) + return op.emitOpError("tcvt lowering does not recognize the requested round mode"); + + auto srcVecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + auto dstVecType = getVPTOVRegType(rewriter.getContext(), dstElementType); + if (!srcVecType || !dstVecType) + return op.emitOpError("tcvt lowering requires legal VPTO vector types"); 
+ + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dstElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tcvt lowering requires pointer-backed tile buffers"); + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, + op.getLoc()); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, + op.getLoc()); + if (!validRowsValue || !validColsValue) + return op.emitOpError("tcvt lowering requires valid rows and cols"); + + int64_t vectorWidth = dstVecType.getElementCount(); + if (contract.validRows != ShapedType::kDynamic && + contract.validCols != ShapedType::kDynamic) { + int64_t totalElements = contract.validRows * contract.validCols; + if (totalElements % vectorWidth != 0) + return op.emitOpError( + "tcvt lowering requires total valid elements divisible by vector width"); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value totalElementsValue = + rewriter.create(op.getLoc(), validRowsValue, validColsValue); + Value vectorStepValue = + rewriter.create(op.getLoc(), vectorWidth); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, totalElementsValue, vectorStepValue); + OpBuilder::InsertionGuard chunkGuard(rewriter); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + Value offset = chunkLoop.getInductionVar(); + switch (*loweringKind) { + case VPTOCvtLoweringKind::Vtrc: { + auto loaded = + 
rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); + Value mask = buildAllPredicateMask(rewriter, op.getLoc(), dstElementType); + Value converted = rewriter.create(op.getLoc(), dstVecType, + loaded.getResult(), mask, + *roundMode); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + case VPTOCvtLoweringKind::F32ToBF16: { + Value halfStep = rewriter.create( + op.getLoc(), srcVecType.getElementCount()); + Value upperOffset = + rewriter.create(op.getLoc(), offset, halfStep); + auto lower = + rewriter.create(op.getLoc(), srcVecType, srcBuffer, offset, StringAttr()); + auto upper = rewriter.create(op.getLoc(), srcVecType, srcBuffer, + upperOffset, StringAttr()); + Value odd = rewriter.create( + op.getLoc(), dstVecType, upper.getResult(), *roundMode, + rewriter.getStringAttr("RS_ENABLE"), rewriter.getStringAttr("PART_ODD")); + Value even = rewriter.create( + op.getLoc(), dstVecType, lower.getResult(), *roundMode, + rewriter.getStringAttr("RS_ENABLE"), rewriter.getStringAttr("PART_EVEN")); + Value merged = + rewriter.create( + op.getLoc(), dstVecType, even, odd, + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + rewriter.create( + op.getLoc(), merged, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + case VPTOCvtLoweringKind::F16ToF32: { + auto loaded = rewriter.create( + op.getLoc(), srcVecType, srcBuffer, offset, rewriter.getStringAttr("UNPK_B16")); + Value converted = rewriter.create( + op.getLoc(), dstVecType, loaded.getResult(), StringAttr(), + StringAttr(), rewriter.getStringAttr("PART_EVEN")); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + case VPTOCvtLoweringKind::BF16ToF32: { + auto loaded = rewriter.create( + op.getLoc(), srcVecType, srcBuffer, 
offset, rewriter.getStringAttr("UNPK_B16")); + Value converted = rewriter.create( + op.getLoc(), dstVecType, loaded.getResult(), StringAttr(), + StringAttr(), rewriter.getStringAttr("PART_EVEN")); + rewriter.create( + op.getLoc(), converted, dstBuffer, offset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), dstElementType)); + break; + } + } + return success(); +} + +template +LogicalResult buildPackedCmp32VecScope(StringRef family, + const VPTOBinaryContract &contract, + Value dst, Value dstBuffer, + PatternRewriter &rewriter, Location loc, + CompareEmitter emitCompare) { + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return emitError(loc) << family << " lowering requires a supported vector element type"; + + Value validRowsValue = materializeIndexValue(contract.validRowsValue, + contract.validRows, rewriter, loc); + Value validColsValue = materializeIndexValue(contract.validColsValue, + contract.validCols, rewriter, loc); + if (!validRowsValue || !validColsValue) + return emitError(loc) << family << " lowering requires valid rows and cols"; + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return emitError(loc) << family << " lowering currently requires static valid rows and cols"; + + int64_t totalElements = contract.validRows * contract.validCols; + constexpr int64_t repeatElem = 64; + int64_t repeatTimes = (totalElements + repeatElem - 1) / repeatElem; + int64_t pairedRepeats = repeatTimes / 2; + int64_t remainRepeats = repeatTimes % 2; + + auto compareMaskType = + getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType); + auto packedMaskType = getVPTOMaskType(rewriter.getContext(), "b8"); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value pairUpper = rewriter.create(loc, pairedRepeats); + Value repeatStep = rewriter.create(loc, repeatElem); + Value pairSrcStride = rewriter.create(loc, repeatElem * 
2); + Value pairDstStride = rewriter.create(loc, 4); + Value laneCount = rewriter.create(loc, repeatElem, 32); + Value totalRemaining = rewriter.create(loc, totalElements, 32); + + FailureOr vecScope = + createLoopScopeRegion(loc, contract.loopScope, rewriter); + if (failed(vecScope)) + return emitError(loc) << "failed to create AIV vector scope region"; + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto pairLoop = + rewriter.create(loc, c0, pairUpper, c1, ValueRange{totalRemaining}); + rewriter.setInsertionPointToStart(pairLoop.getBody()); + Value remaining = pairLoop.getRegionIterArgs().front(); + Value pairBase = rewriter.create(loc, pairLoop.getInductionVar(), + pairSrcStride); + Value pairNext = rewriter.create(loc, pairBase, repeatStep); + Value dstOffset = rewriter.create(loc, pairLoop.getInductionVar(), + pairDstStride); + Value dstBase = adjustPointerByElemOffset(dstBuffer, dstOffset, 4, rewriter, loc); + Value dstZero = rewriter.create(loc, 0); + auto pairMask0 = rewriter.create(loc, compareMaskType, + rewriter.getI32Type(), + remaining); + auto pairMask1 = rewriter.create(loc, compareMaskType, + rewriter.getI32Type(), + pairMask0.getScalarOut()); + Value cmp0 = emitCompare(rewriter, loc, pairBase, pairMask0.getMask()); + Value cmp1 = emitCompare(rewriter, loc, pairNext, pairMask1.getMask()); + Value packedCmp0 = rewriter + .create(loc, packedMaskType, cmp0, + rewriter.getStringAttr("LOWER")) + .getResult(); + Value packedCmp1 = rewriter + .create(loc, packedMaskType, cmp1, + rewriter.getStringAttr("LOWER")) + .getResult(); + auto interleaved = rewriter.create( + loc, packedMaskType, packedMaskType, packedCmp0, packedCmp1); + rewriter.create(loc, interleaved.getLow(), dstBase, dstZero, + "NORM"); + rewriter.create(loc, pairMask1.getScalarOut()); + + if (remainRepeats == 0) + return success(); + + rewriter.setInsertionPointAfter(pairLoop); + Value tailBase = rewriter.create(loc, 
pairedRepeats * repeatElem * 2); + Value tailDst = rewriter.create(loc, pairedRepeats * 4); + Value tailDstBase = adjustPointerByElemOffset(dstBuffer, tailDst, 4, rewriter, loc); + Value tailDstZero = rewriter.create(loc, 0); + auto tailMask = rewriter.create(loc, compareMaskType, + rewriter.getI32Type(), + pairLoop.getResult(0)); + Value tailCmp = emitCompare(rewriter, loc, tailBase, tailMask.getMask()); + Value packedTail = rewriter + .create(loc, packedMaskType, tailCmp, + rewriter.getStringAttr("LOWER")) + .getResult(); + rewriter.create(loc, packedTail, tailDstBase, tailDstZero, + "NORM"); + return success(); +} + +LogicalResult lowerTCmpS(TCmpSOp op, PatternRewriter &rewriter) { + VPTOBinaryContract contract = buildBinaryContract("cmps", op.getSrc()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("tcmps lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("tcmps lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return op.emitOpError("tcmps lowering requires static valid shape"); + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstRows, dstCols); + if (contract.validRows != dstRows || contract.validCols != dstCols) + return op.emitOpError("tcmps lowering requires matching source and destination valid region"); + if (!isSupportedPackedCmp32ElementType(contract.elementType)) + return op.emitOpError("tcmps lowering currently supports only 32-bit source tiles"); + auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); + if (!dstElemType || 
!dstElemType.isUnsignedInteger(8)) + return op.emitOpError("tcmps lowering currently requires ui8 destination tiles"); + if (op.getScalar().getType() != contract.elementType) + return op.emitOpError("tcmps lowering requires scalar type to match source element type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), getElementType(op.getDst()), + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tcmps lowering requires pointer-backed tile buffers"); + + StringAttr cmpMode = rewriter.getStringAttr(stringifyCmpModeAttr(op.getCmpModeAttr())); + return buildPackedCmp32VecScope( + "tcmps", contract, op.getDst(), dstBuffer, rewriter, op.getLoc(), + [&](PatternRewriter &nestedRewriter, Location nestedLoc, Value offset, + Value mask) -> Value { + auto vecType = + getVPTOVRegType(nestedRewriter.getContext(), contract.elementType); + auto loaded = + nestedRewriter.create(nestedLoc, vecType, srcBuffer, offset, StringAttr()); + return nestedRewriter + .create(nestedLoc, + getVPTOMaskTypeForElementType( + nestedRewriter.getContext(), + contract.elementType), + loaded.getResult(), op.getScalar(), mask, cmpMode) + .getResult(); + }); +} + +LogicalResult lowerTCmp(TCmpOp op, PatternRewriter &rewriter) { + VPTOBinaryContract contract = buildBinaryContract("cmp", op.getSrc0()); + deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue); + deriveValidShape(op.getDst(), contract.validRows, contract.validCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("tcmp lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getSrc1()) != 
"row_major" || + deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("tcmp lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || + contract.validCols == ShapedType::kDynamic) + return op.emitOpError("tcmp lowering requires static valid shape"); + int64_t src1Rows = ShapedType::kDynamic; + int64_t src1Cols = ShapedType::kDynamic; + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getSrc1(), src1Rows, src1Cols); + deriveValidShape(op.getDst(), dstRows, dstCols); + if (contract.validRows != src1Rows || contract.validCols != src1Cols || + contract.validRows != dstRows || contract.validCols != dstCols) + return op.emitOpError("tcmp lowering requires matching source and destination valid region"); + if (!isSupportedPackedCmp32ElementType(contract.elementType)) + return op.emitOpError("tcmp lowering currently supports only 32-bit source tiles"); + if (getElementType(op.getSrc1()) != contract.elementType) + return op.emitOpError("tcmp lowering requires src1 element type to match src0"); + auto dstElemType = dyn_cast_or_null(getElementType(op.getDst())); + if (!dstElemType || !dstElemType.isUnsignedInteger(8)) + return op.emitOpError("tcmp lowering currently requires ui8 destination tiles"); + + Value src0Buffer = materializeBufferPointer(op.getSrc0(), contract.elementType, + getMemorySpace(op.getSrc0()), rewriter, + op.getLoc()); + Value src1Buffer = materializeBufferPointer(op.getSrc1(), contract.elementType, + getMemorySpace(op.getSrc1()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), getElementType(op.getDst()), + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!src0Buffer || !src1Buffer || !dstBuffer) + return op.emitOpError("tcmp lowering requires pointer-backed tile buffers"); + + StringAttr cmpMode = rewriter.getStringAttr(stringifyCmpModeAttr(op.getCmpModeAttr())); + return buildPackedCmp32VecScope( + 
"tcmp", contract, op.getDst(), dstBuffer, rewriter, op.getLoc(), + [&](PatternRewriter &nestedRewriter, Location nestedLoc, Value offset, + Value mask) -> Value { + auto vecType = + getVPTOVRegType(nestedRewriter.getContext(), contract.elementType); + auto lhs = + nestedRewriter.create(nestedLoc, vecType, src0Buffer, offset, StringAttr()); + auto rhs = + nestedRewriter.create(nestedLoc, vecType, src1Buffer, offset, StringAttr()); + return nestedRewriter + .create(nestedLoc, + getVPTOMaskTypeForElementType( + nestedRewriter.getContext(), + contract.elementType), + lhs.getResult(), rhs.getResult(), mask, cmpMode) + .getResult(); + }); +} + +LogicalResult lowerTCI(TCIOp op, PatternRewriter &rewriter) { + Type elementType = getElementType(op.getDst()); + auto intType = dyn_cast_or_null(elementType); + if (!intType || (intType.getWidth() != 16 && intType.getWidth() != 32)) + return op.emitOpError("tci lowering requires i16 or i32 destination element type"); + if (deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("tci lowering requires tile domain vec"); + if (deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("tci lowering requires row-major tile layout"); + + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + Value validRowsValue; + Value validColsValue; + deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); + deriveValidShape(op.getDst(), validRows, validCols); + if (validRows != 1) + return op.emitOpError("tci lowering currently requires valid rows == 1"); + + Value dstBuffer = materializeBufferPointer(op.getDst(), elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!dstBuffer) + return op.emitOpError("tci lowering requires pointer-backed destination tile buffer"); + + Value upperBound = materializeIndexValue(validColsValue, validCols, rewriter, op.getLoc()); + if (!upperBound) + return op.emitOpError("tci lowering requires valid 
cols"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + auto loop = rewriter.create(op.getLoc(), c0, upperBound, c1); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(loop.getBody()); + Value iv = loop.getInductionVar(); + Value ivAsElem = rewriter.create(op.getLoc(), intType, iv); + Value stored = + op.getDescending() + ? rewriter.create(op.getLoc(), op.getS(), ivAsElem).getResult() + : rewriter.create(op.getLoc(), op.getS(), ivAsElem).getResult(); + rewriter.create(op.getLoc(), dstBuffer, iv, stored); + return success(); +} + +LogicalResult lowerTRELU(TReluOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTReluContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + return type.isF16() || type.isF32() || + (isa(type) && cast(type).getWidth() == 32); + }, + "f16, f32, and i32 element types"))) + return failure(); + return buildUnaryVecScope("relu", contract, strategy, op.getSrc(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTNOT(TNotOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + VPTOUnaryContract contract = extractTNotContract(op); + if (failed(checkGenericUnaryContract( + op, contract, op.getDst(), + [](Type type) { + if (type.isF16() || type.isF32() || type.isBF16()) + return true; + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return false; + }, + "f16, f32, bf16, and 8/16/32-bit integer element types"))) + return failure(); + return buildUnaryVecScope("not", contract, strategy, op.getSrc(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTTRANS(TTransOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = buildUnaryContract("trans", op.getSrc()); + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = 
ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstRows, dstCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("ttrans lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("ttrans lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || contract.validCols == ShapedType::kDynamic || + dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return op.emitOpError("ttrans lowering requires static valid shape"); + if (contract.validRows != dstCols || contract.validCols != dstRows) + return op.emitOpError("ttrans lowering requires transposed source/destination valid shape"); + if (contract.elementType != getElementType(op.getDst())) + return op.emitOpError("ttrans lowering requires matching source/destination element type"); + + int64_t elemBytes = getElementByteSize(contract.elementType); + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = deriveStaticRowStride(op.getDst()); + if (elemBytes != 4) + return op.emitOpError("ttrans lowering currently supports only b32 element types"); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) + return op.emitOpError("ttrans lowering requires static source/destination row stride"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + auto indexElemType = rewriter.getIntegerType(32); + auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElemType); + if (!dataVecType || !indexVecType) + return op.emitOpError("ttrans lowering requires supported VPTO vector types"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, 
+ getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("ttrans lowering requires pointer-backed tile buffers"); + + constexpr int64_t repeatBytes = 256; + constexpr int64_t blockBytes = 32; + int64_t elementsPerRepeat = repeatBytes / elemBytes; + int64_t blockSizeElem = blockBytes / elemBytes; + int64_t alignedRows = + llvm::divideCeil(contract.validRows, blockSizeElem) * blockSizeElem; + int64_t repeatTimes = llvm::divideCeil(alignedRows, elementsPerRepeat); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value colsUpper = rewriter.create(op.getLoc(), contract.validCols); + Value chunkUpper = rewriter.create(op.getLoc(), repeatTimes); + Value elementsPerRepeatValue = + rewriter.create(op.getLoc(), elementsPerRepeat); + Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); + Value srcStrideI32 = rewriter.create(op.getLoc(), srcStride, 32); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto colLoop = rewriter.create(op.getLoc(), c0, colsUpper, c1); + rewriter.setInsertionPointToStart(colLoop.getBody()); + auto chunkLoop = rewriter.create(op.getLoc(), c0, chunkUpper, c1); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value chunkBase = rewriter.create(op.getLoc(), chunkLoop.getInductionVar(), + elementsPerRepeatValue); + Value colI32 = rewriter.create(op.getLoc(), indexElemType, + colLoop.getInductionVar()); + Value chunkBaseI32 = + rewriter.create(op.getLoc(), indexElemType, chunkBase); + auto indices = + rewriter.create(op.getLoc(), indexVecType, chunkBaseI32, + rewriter.getStringAttr("INC_ORDER")); + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), indexElemType); + 
auto scaled = rewriter.create(op.getLoc(), indexVecType, + indices.getResult(), srcStrideI32, fullMask); + auto offsets = rewriter.create(op.getLoc(), indexVecType, + scaled.getResult(), colI32, fullMask); + Value fullActiveLanes = + rewriter.create(op.getLoc(), + dataVecType.getElementCount()); + auto gathered = + rewriter.create(op.getLoc(), dataVecType, srcBuffer, + offsets.getResult(), fullActiveLanes); + Value dstBase = + rewriter.create(op.getLoc(), colLoop.getInductionVar(), dstStrideValue); + Value dstOffset = rewriter.create(op.getLoc(), dstBase, chunkBase); + rewriter.create( + op.getLoc(), gathered.getResult(), dstBuffer, dstOffset, StringAttr(), + buildAllPredicateMask(rewriter, op.getLoc(), contract.elementType)); + return success(); +} + +template +LogicalResult lowerTFillPadCommon(FillPadOpTy op, PatternRewriter &rewriter, + bool allowDstExpand) { + VPTOUnaryContract contract = buildUnaryContract("fillpad", op.getSrc()); + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShape(op.getDst(), dstRows, dstCols); + + if (contract.tileDomain != VPTOTileDomain::Vec || + deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec) + return op.emitOpError("fillpad lowering requires tile domain vec"); + if (contract.tileLayout != "row_major" || deriveTileLayout(op.getDst()) != "row_major") + return op.emitOpError("fillpad lowering requires row-major tile layout"); + if (contract.validRows == ShapedType::kDynamic || contract.validCols == ShapedType::kDynamic || + dstRows == ShapedType::kDynamic || dstCols == ShapedType::kDynamic) + return op.emitOpError("fillpad lowering requires static valid shape"); + if (!allowDstExpand && (contract.validRows != dstRows || contract.validCols != dstCols)) + return op.emitOpError("tfillpad lowering requires matching source/destination valid shape"); + if (allowDstExpand && (dstRows < contract.validRows || dstCols < contract.validCols)) + return 
op.emitOpError("tfillpad_expand lowering requires dst shape >= src shape"); + if (contract.elementType != getElementType(op.getDst())) + return op.emitOpError("fillpad lowering requires matching source/destination element type"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = deriveStaticRowStride(op.getDst()); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) + return op.emitOpError("fillpad lowering requires static source/destination row stride"); + + auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType); + if (!vecType) + return op.emitOpError("fillpad lowering requires supported VPTO vector element type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("fillpad lowering requires pointer-backed tile buffers"); + + auto config = lookupTileConfig(op.getDst()); + PadValueAttr padAttr = config ? 
dyn_cast(config.getPad()) : PadValueAttr{}; + Attribute padValueAttr = buildFillPadValue(contract.elementType, padAttr, rewriter); + if (!padValueAttr) + return op.emitOpError("fillpad lowering requires a concrete non-null dst pad value"); + Value padScalar = rewriter.create(op.getLoc(), cast(padValueAttr)); + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), vecType.getElementType()); + auto padVec = + rewriter.create(op.getLoc(), vecType, padScalar, fullMask, StringAttr()); + + int64_t vectorWidth = vecType.getElementCount(); + int64_t padCols = dstCols - contract.validCols; + int64_t padRows = dstRows - contract.validRows; + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value srcRowsUpper = rewriter.create(op.getLoc(), contract.validRows); + Value srcColsUpper = rewriter.create(op.getLoc(), contract.validCols); + Value dstRowsUpper = rewriter.create(op.getLoc(), dstRows); + Value vectorStep = rewriter.create(op.getLoc(), vectorWidth); + Value srcStrideValue = rewriter.create(op.getLoc(), srcStride); + Value dstStrideValue = rewriter.create(op.getLoc(), dstStride); + Value validColsValue = rewriter.create(op.getLoc(), contract.validCols); + Value dstColsValue = rewriter.create(op.getLoc(), dstCols); + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + auto rowLoop = rewriter.create(op.getLoc(), c0, srcRowsUpper, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value srcRowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), + srcStrideValue); + Value dstRowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), + dstStrideValue); + + auto copyChunkLoop = + rewriter.create(op.getLoc(), c0, srcColsUpper, vectorStep); + 
rewriter.setInsertionPointToStart(copyChunkLoop.getBody()); + Value copyOffset = + rewriter.create(op.getLoc(), srcRowBase, copyChunkLoop.getInductionVar()); + auto loaded = rewriter.create(op.getLoc(), vecType, srcBuffer, + copyOffset, StringAttr()); + Value copyDstOffset = + rewriter.create(op.getLoc(), dstRowBase, copyChunkLoop.getInductionVar()); + Value copyRemaining = + rewriter.create(op.getLoc(), validColsValue, copyChunkLoop.getInductionVar()); + auto copyNeedsClamp = rewriter.create(op.getLoc(), arith::CmpIPredicate::slt, + copyRemaining, vectorStep); + Value copyActiveLanes = + rewriter.create(op.getLoc(), copyNeedsClamp, copyRemaining, vectorStep); + Value copyMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, copyActiveLanes); + rewriter.create(op.getLoc(), loaded.getResult(), dstBuffer, + copyDstOffset, StringAttr(), copyMask); + + rewriter.setInsertionPointAfter(copyChunkLoop); + if (padCols > 0) { + Value padColsUpper = rewriter.create(op.getLoc(), padCols); + auto padColLoop = rewriter.create(op.getLoc(), c0, padColsUpper, vectorStep); + rewriter.setInsertionPointToStart(padColLoop.getBody()); + Value padDstStart = rewriter.create(op.getLoc(), dstRowBase, validColsValue); + Value padDstOffset = rewriter.create(op.getLoc(), padDstStart, + padColLoop.getInductionVar()); + Value padRemaining = + rewriter.create(op.getLoc(), padColsUpper, padColLoop.getInductionVar()); + auto padNeedsClamp = rewriter.create(op.getLoc(), arith::CmpIPredicate::slt, + padRemaining, vectorStep); + Value padActiveLanes = + rewriter.create(op.getLoc(), padNeedsClamp, padRemaining, vectorStep); + Value padMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, padActiveLanes); + rewriter.create(op.getLoc(), padVec.getResult(), dstBuffer, + padDstOffset, StringAttr(), padMask); + } + + rewriter.setInsertionPointAfter(rowLoop); + if (padRows > 0) { + Value bottomStart = rewriter.create(op.getLoc(), srcRowsUpper, 
dstStrideValue); + Value bottomElements = + rewriter.create(op.getLoc(), + rewriter.create(op.getLoc(), dstRowsUpper, + dstColsValue), + bottomStart); + auto bottomLoop = rewriter.create(op.getLoc(), c0, bottomElements, vectorStep); + rewriter.setInsertionPointToStart(bottomLoop.getBody()); + Value bottomDstOffset = + rewriter.create(op.getLoc(), bottomStart, bottomLoop.getInductionVar()); + Value bottomRemaining = + rewriter.create(op.getLoc(), bottomElements, bottomLoop.getInductionVar()); + auto bottomNeedsClamp = rewriter.create( + op.getLoc(), arith::CmpIPredicate::slt, bottomRemaining, vectorStep); + Value bottomActiveLanes = rewriter.create( + op.getLoc(), bottomNeedsClamp, bottomRemaining, vectorStep); + Value bottomMask = buildPredicateMaskForLaneCount( + rewriter, op.getLoc(), contract.elementType, bottomActiveLanes); + rewriter.create(op.getLoc(), padVec.getResult(), dstBuffer, + bottomDstOffset, StringAttr(), bottomMask); + } + + return success(); +} + +LogicalResult lowerTFILLPAD(TFillPadOp op, PatternRewriter &rewriter) { + return lowerTFillPadCommon(op, rewriter, /*allowDstExpand=*/false); +} + +LogicalResult lowerTFILLPADExpand(TFillPadExpandOp op, PatternRewriter &rewriter) { + return lowerTFillPadCommon(op, rewriter, /*allowDstExpand=*/true); +} + +LogicalResult lowerTExpandS(TExpandsOp op, PatternRewriter &rewriter) { + VPTOUnaryContract contract = extractTExpandSContract(op); + if (contract.tileDomain != VPTOTileDomain::Vec) + return op.emitOpError("expands lowering requires tile domain vec"); + if (contract.tileLayout != "row_major") + return op.emitOpError("expands lowering requires row-major tile layout"); + if (!contract.elementType) + return op.emitOpError("expands lowering requires a concrete element type"); + + Type scalarType = op.getScalar().getType(); + if (scalarType != contract.elementType) + return op.emitOpError("expands lowering requires scalar type to match destination element type"); + + if (!(contract.elementType.isF16() || 
contract.elementType.isF32() || + contract.elementType.isBF16())) { + if (auto intType = dyn_cast(contract.elementType)) { + unsigned width = intType.getWidth(); + if (width != 8 && width != 16 && width != 32) + return op.emitOpError("expands lowering supports only f16, f32, bf16, and 8/16/32-bit integer element types"); + } else { + return op.emitOpError("expands lowering supports only scalar integer or floating-point element types"); + } + } + + return buildExpandScalarVecScope(contract, op.getScalar(), op.getDst(), + rewriter, op.getLoc()); +} + +LogicalResult lowerTGather(TGatherOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tgather lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tgather lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getDst(), "dst"))) + return failure(); + + Type dataElementType = getElementType(op.getSrc()); + if (dataElementType != getElementType(op.getDst())) + return op.emitOpError("tgather lowering requires matching src/dst element type"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); + if (!dataVecType) + return op.emitOpError("tgather lowering requires supported VPTO data type"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tgather lowering requires pointer-backed tile buffers"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t dstStride = 
deriveStaticRowStride(op.getDst()); + if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic) + return op.emitOpError("tgather lowering requires static row stride"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + if (Value indices = op.getIndices()) { + if (failed(requireVecRowMajor(indices, "indices"))) + return failure(); + + Type indexElementType = getElementType(indices); + auto indexIntegerType = dyn_cast(indexElementType); + auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElementType); + if (!indexIntegerType || !indexVecType) + return op.emitOpError("tgather index lowering requires integer indices with supported VPTO vector type"); + if (indexVecType.getElementCount() != dataVecType.getElementCount()) + return op.emitOpError("tgather index lowering currently requires matching data/index vector widths"); + + Value indexBuffer = materializeBufferPointer(indices, indexElementType, + getMemorySpace(indices), rewriter, + op.getLoc()); + if (!indexBuffer) + return op.emitOpError("tgather index lowering requires pointer-backed indices tile"); + + int64_t indexStride = deriveStaticRowStride(indices); + if (indexStride == ShapedType::kDynamic) + return op.emitOpError("tgather index lowering requires static index row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getDst(), validRowsValue, 
validColsValue); + deriveValidShape(op.getDst(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getDst(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tgather index lowering requires valid dst shape"); + + int64_t chunkWidth = indexVecType.getElementCount(); + Value chunkStep = rewriter.create(op.getLoc(), chunkWidth); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstStride); + Value indexStrideValue = + rewriter.create(op.getLoc(), indexStride); + + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); + + Value dstRowBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value indexRowBase = + rewriter.create(op.getLoc(), row, indexStrideValue); + Value indexOffset = + rewriter.create(op.getLoc(), indexRowBase, chunkBase); + auto offsetVector = rewriter.create(op.getLoc(), indexVecType, + indexBuffer, indexOffset, + StringAttr()); + auto gathered = rewriter.create( + op.getLoc(), dataVecType, srcBuffer, offsetVector.getResult(), activeLanes); + Value dstOffset = + rewriter.create(op.getLoc(), dstRowBase, chunkBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), + dstBuffer, dstOffset, activeLanes, chunkWidth); + } + + auto maskPattern = op.getMaskPatternAttr(); + if (!maskPattern) + return op.emitOpError("tgather lowering requires indices or maskPattern"); + if (maskPattern.getValue() != MaskPattern::P1111) + return op.emitOpError("tgather mask lowering currently supports 
only maskPattern=P1111"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getSrc(), validRowsValue, validColsValue); + deriveValidShape(op.getSrc(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getSrc(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tgather mask lowering requires valid src shape"); + + int64_t chunkWidth = dataVecType.getElementCount(); + Value chunkStep = rewriter.create(op.getLoc(), chunkWidth); + Value srcStrideValue = + rewriter.create(op.getLoc(), srcStride); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstStride); + + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); + + Value srcRowBase = + rewriter.create(op.getLoc(), row, srcStrideValue); + Value dstRowBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value srcOffset = + rewriter.create(op.getLoc(), srcRowBase, chunkBase); + auto loaded = rewriter.create(op.getLoc(), dataVecType, srcBuffer, + srcOffset, StringAttr()); + Value dstOffset = + rewriter.create(op.getLoc(), dstRowBase, chunkBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), loaded.getResult(), dstBuffer, + dstOffset, activeLanes, chunkWidth); +} + +LogicalResult lowerTGatherB(TGatherBOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if 
(deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tgatherb lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tgatherb lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getOffsets(), "offsets")) || + failed(requireVecRowMajor(op.getDst(), "dst"))) + return failure(); + + Type dataElementType = getElementType(op.getDst()); + if (getElementType(op.getSrc()) != dataElementType) + return op.emitOpError("tgatherb lowering requires matching src/dst element type"); + + auto offsetIntegerType = dyn_cast(getElementType(op.getOffsets())); + if (!offsetIntegerType || offsetIntegerType.getWidth() != 32 || + !offsetIntegerType.isUnsigned()) + return op.emitOpError("tgatherb lowering currently requires unsigned 32-bit offsets"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); + auto offsetVecType = + getVPTOVRegType(rewriter.getContext(), getElementType(op.getOffsets())); + if (!dataVecType || !offsetVecType) + return op.emitOpError("tgatherb lowering requires supported VPTO vector types"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + Value offsetBuffer = + materializeBufferPointer(op.getOffsets(), getElementType(op.getOffsets()), + getMemorySpace(op.getOffsets()), rewriter, op.getLoc()); + if (!srcBuffer || !dstBuffer || !offsetBuffer) + return op.emitOpError("tgatherb lowering requires pointer-backed tile buffers"); + + int64_t dstStride = deriveStaticRowStride(op.getDst()); + int64_t offsetStride = deriveStaticRowStride(op.getOffsets()); + int64_t staticRows = deriveStaticShapeDim(op.getDst(), 
0); + int64_t staticCols = deriveStaticShapeDim(op.getDst(), 1); + if (dstStride == ShapedType::kDynamic || offsetStride == ShapedType::kDynamic || + staticRows == ShapedType::kDynamic || staticCols == ShapedType::kDynamic) + return op.emitOpError("tgatherb lowering requires static tile shape and row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getDst(), validRowsValue, validColsValue); + deriveValidShape(op.getDst(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getDst(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tgatherb lowering requires valid dst shape"); + + unsigned elemBytes = dataElementType.getIntOrFloatBitWidth() / 8; + int64_t elementsPerRepeat = 256 / elemBytes; + int64_t blockSizeElem = 32 / elemBytes; + int64_t staticRepeatTimes = llvm::divideCeil(staticCols, elementsPerRepeat); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value elementsPerRepeatValue = + rewriter.create(op.getLoc(), elementsPerRepeat); + Value blockSizeElemValue = + rewriter.create(op.getLoc(), blockSizeElem); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstStride); + Value offsetStrideValue = + rewriter.create(op.getLoc(), offsetStride); + + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + + if (staticRepeatTimes > staticRows) { + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + 
rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = rewriter.create(op.getLoc(), c0, validColsValue, + elementsPerRepeatValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, + elementsPerRepeatValue); + Value rowOffsetBase = + rewriter.create(op.getLoc(), row, offsetStrideValue); + Value rowDstBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value offsetChunkBase = + rewriter.create(op.getLoc(), chunkBase, + blockSizeElemValue); + Value offsetLoadOffset = + rewriter.create(op.getLoc(), rowOffsetBase, offsetChunkBase); + auto offsets = rewriter.create(op.getLoc(), offsetVecType, + offsetBuffer, offsetLoadOffset, + StringAttr()); + auto gathered = rewriter.create( + op.getLoc(), dataVecType, srcBuffer, offsets.getResult(), activeLanes); + Value dstOffset = + rewriter.create(op.getLoc(), rowDstBase, chunkBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), + dstBuffer, dstOffset, activeLanes, + dataVecType.getElementCount()); + } + + auto chunkLoop = rewriter.create(op.getLoc(), c0, validColsValue, + elementsPerRepeatValue); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + + Value chunkBase = chunkLoop.getInductionVar(); + Value row = rowLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = buildMinIndexValue(rewriter, op.getLoc(), remaining, + elementsPerRepeatValue); + Value rowOffsetBase = + rewriter.create(op.getLoc(), row, offsetStrideValue); + Value rowDstBase = + rewriter.create(op.getLoc(), row, dstStrideValue); + Value offsetChunkBase = 
+ rewriter.create(op.getLoc(), chunkBase, + blockSizeElemValue); + Value offsetLoadOffset = + rewriter.create(op.getLoc(), rowOffsetBase, offsetChunkBase); + auto offsets = rewriter.create(op.getLoc(), offsetVecType, offsetBuffer, + offsetLoadOffset, StringAttr()); + auto gathered = rewriter.create( + op.getLoc(), dataVecType, srcBuffer, offsets.getResult(), activeLanes); + Value dstOffset = + rewriter.create(op.getLoc(), chunkBase, rowDstBase); + return buildMaskedVectorStore(rewriter, op.getLoc(), gathered.getResult(), + dstBuffer, dstOffset, activeLanes, + dataVecType.getElementCount()); +} + +LogicalResult lowerTScatter(TScatterOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tscatter lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tscatter lowering requires row-major layout for " + << role; + return success(); + }; + + if (failed(requireVecRowMajor(op.getSrc(), "src")) || + failed(requireVecRowMajor(op.getIndexes(), "indexes")) || + failed(requireVecRowMajor(op.getDst(), "dst"))) + return failure(); + + Type dataElementType = getElementType(op.getSrc()); + if (dataElementType != getElementType(op.getDst())) + return op.emitOpError("tscatter lowering requires matching src/dst element type"); + + Type indexElementType = getElementType(op.getIndexes()); + auto indexIntegerType = dyn_cast(indexElementType); + if (!indexIntegerType || indexIntegerType.getWidth() != 32) + return op.emitOpError("tscatter lowering currently requires 32-bit integer indexes"); + + auto dataVecType = getVPTOVRegType(rewriter.getContext(), dataElementType); + auto indexVecType = getVPTOVRegType(rewriter.getContext(), indexElementType); + if (!dataVecType || !indexVecType || + dataVecType.getElementCount() != indexVecType.getElementCount()) + 
return op.emitOpError("tscatter lowering currently requires matching data/index vector widths"); + + Value srcBuffer = materializeBufferPointer(op.getSrc(), dataElementType, + getMemorySpace(op.getSrc()), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), dataElementType, + getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + Value indexBuffer = materializeBufferPointer(op.getIndexes(), indexElementType, + getMemorySpace(op.getIndexes()), rewriter, + op.getLoc()); + if (!srcBuffer || !dstBuffer || !indexBuffer) + return op.emitOpError("tscatter lowering requires pointer-backed tile buffers"); + + int64_t srcStride = deriveStaticRowStride(op.getSrc()); + int64_t indexStride = deriveStaticRowStride(op.getIndexes()); + if (srcStride == ShapedType::kDynamic || indexStride == ShapedType::kDynamic) + return op.emitOpError("tscatter lowering requires static src/index row stride"); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(op.getIndexes(), validRowsValue, validColsValue); + deriveValidShape(op.getIndexes(), validRows, validCols); + if (failed(resolveExecutionValidShape(op.getIndexes(), validRowsValue, validColsValue, + validRows, validCols, rewriter, op.getLoc()))) + return op.emitOpError("tscatter lowering requires valid index shape"); + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value chunkStep = rewriter.create( + op.getLoc(), indexVecType.getElementCount()); + Value srcStrideValue = + rewriter.create(op.getLoc(), srcStride); + Value indexStrideValue = + rewriter.create(op.getLoc(), indexStride); + + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if 
(failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + auto chunkLoop = + rewriter.create(op.getLoc(), c0, validColsValue, chunkStep); + rewriter.setInsertionPointToStart(chunkLoop.getBody()); + + Value row = rowLoop.getInductionVar(); + Value chunkBase = chunkLoop.getInductionVar(); + Value remaining = + rewriter.create(op.getLoc(), validColsValue, chunkBase); + Value activeLanes = + buildMinIndexValue(rewriter, op.getLoc(), remaining, chunkStep); + + Value srcRowBase = + rewriter.create(op.getLoc(), row, srcStrideValue); + Value indexRowBase = + rewriter.create(op.getLoc(), row, indexStrideValue); + Value srcOffset = + rewriter.create(op.getLoc(), srcRowBase, chunkBase); + Value indexOffset = + rewriter.create(op.getLoc(), indexRowBase, chunkBase); + auto srcVector = rewriter.create(op.getLoc(), dataVecType, srcBuffer, + srcOffset, StringAttr()); + auto indexVector = rewriter.create(op.getLoc(), indexVecType, indexBuffer, + indexOffset, StringAttr()); + rewriter.create(op.getLoc(), srcVector.getResult(), dstBuffer, + indexVector.getResult(), activeLanes); + return success(); +} + +LogicalResult lowerTMrgSort(TMrgSortOp op, PatternRewriter &rewriter) { + auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult { + if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec) + return op.emitOpError() << "tmrgsort lowering requires vec tile domain for " + << role; + if (deriveTileLayout(value) != "row_major") + return op.emitOpError() << "tmrgsort lowering requires row-major layout for " + << role; + return success(); + }; + auto requireOneRow = [&](Value value, StringRef role) -> LogicalResult { + if (deriveStaticShapeDim(value, 0) != 1) + return op.emitOpError() 
<< "tmrgsort lowering requires rows==1 for " << role; + return success(); + }; + + Location loc = op.getLoc(); + if (op.isFormat1()) { + Value src = op.getSrcs().front(); + Value dst = op.getDsts().front(); + if (failed(requireVecRowMajor(src, "src")) || failed(requireVecRowMajor(dst, "dst")) || + failed(requireOneRow(src, "src")) || failed(requireOneRow(dst, "dst"))) + return failure(); + + Type elementType = getElementType(src); + if (elementType != getElementType(dst)) + return op.emitOpError("tmrgsort format1 requires matching src/dst element type"); + if (!(elementType.isF16() || elementType.isF32())) + return op.emitOpError("tmrgsort format1 currently supports only f16/f32"); + + Value srcBuffer = materializeBufferPointer(src, elementType, getMemorySpace(src), + rewriter, loc); + Value dstBuffer = materializeBufferPointer(dst, elementType, getMemorySpace(dst), + rewriter, loc); + if (!srcBuffer || !dstBuffer) + return op.emitOpError("tmrgsort format1 requires pointer-backed tile buffers"); + + Value blockLen = op.getBlockLen(); + if (!blockLen) + return op.emitOpError("tmrgsort format1 requires blockLen"); + Value blockLenI64; + if (blockLen.getType().isIndex()) + blockLenI64 = + rewriter.create(loc, rewriter.getI64Type(), blockLen); + else + blockLenI64 = + rewriter.create(loc, rewriter.getI64Type(), blockLen); + Value blockLenIndex = + rewriter.create(loc, rewriter.getIndexType(), blockLenI64); + + Value validRowsValue; + Value validColsValue; + int64_t validRows = ShapedType::kDynamic; + int64_t validCols = ShapedType::kDynamic; + deriveValidShapeValues(src, validRowsValue, validColsValue); + deriveValidShape(src, validRows, validCols); + Value validColsI64 = materializeI64Value(validColsValue, validCols, rewriter, loc); + + int64_t elemBytes = getElementByteSize(elementType); + Value numStructures = rewriter.create( + loc, rewriter.getI64Type(), + rewriter.create( + loc, blockLenI64, rewriter.create(loc, elemBytes, 64)), + rewriter.create(loc, 3, 64)); + 
Value count = buildPackedCountI64(rewriter, loc, + {numStructures, numStructures, numStructures, numStructures}); + Value repeatTimes = rewriter.create( + loc, validColsI64, + rewriter.create( + loc, blockLenI64, rewriter.create(loc, 4, 64))); + Value config = rewriter.create( + loc, repeatTimes, rewriter.create(loc, 0b1111 << 8, 64)); + + Value src0 = srcBuffer; + Value src1 = offsetBufferPointer(srcBuffer, elementType, blockLenIndex, rewriter, loc); + Value src2 = offsetBufferPointer( + srcBuffer, elementType, + rewriter.create(loc, blockLenIndex, + rewriter.create(loc, 2)), + rewriter, loc); + Value src3 = offsetBufferPointer( + srcBuffer, elementType, + rewriter.create(loc, blockLenIndex, + rewriter.create(loc, 3)), + rewriter, loc); + rewriter.create(loc, dstBuffer, src0, src1, src2, src3, count, + config); + return success(); + } + + if (!op.isFormat2()) + return op.emitOpError("unsupported tmrgsort format for current vpto backend"); + if (op.getExhausted()) + return op.emitOpError("tmrgsort format2 exhausted=true is not yet supported"); + if (op.getSrcs().size() != 4 || op.getDsts().size() != 2) + return op.emitOpError("tmrgsort format2 currently requires exactly 4 srcs and 2 dsts"); + + Type elementType = getElementType(op.getSrcs().front()); + if (!(elementType.isF16() || elementType.isF32())) + return op.emitOpError("tmrgsort format2 currently supports only f16/f32"); + + SmallVector srcBuffers; + SmallVector srcCounts; + srcBuffers.reserve(4); + srcCounts.reserve(4); + for (Value src : op.getSrcs()) { + if (failed(requireVecRowMajor(src, "src")) || failed(requireOneRow(src, "src"))) + return failure(); + if (getElementType(src) != elementType) + return op.emitOpError("tmrgsort format2 requires matching source element types"); + + Value srcBuffer = + materializeBufferPointer(src, elementType, getMemorySpace(src), rewriter, loc); + if (!srcBuffer) + return op.emitOpError("tmrgsort format2 requires pointer-backed source tiles"); + 
srcBuffers.push_back(srcBuffer); + + Value rowsValue; + Value colsValue; + int64_t rows = ShapedType::kDynamic; + int64_t cols = ShapedType::kDynamic; + deriveValidShapeValues(src, rowsValue, colsValue); + deriveValidShape(src, rows, cols); + Value colsI64 = materializeI64Value(colsValue, cols, rewriter, loc); + srcCounts.push_back(rewriter.create( + loc, rewriter.getI64Type(), colsI64, + rewriter.create(loc, elementType.isF32() ? 1 : 2, 64))); + } + + Value dst = op.getDsts()[0]; + Value tmp = op.getDsts()[1]; + if (failed(requireVecRowMajor(dst, "dst")) || failed(requireVecRowMajor(tmp, "tmp")) || + failed(requireOneRow(dst, "dst")) || failed(requireOneRow(tmp, "tmp"))) + return failure(); + if (getElementType(dst) != elementType || getElementType(tmp) != elementType) + return op.emitOpError("tmrgsort format2 requires matching dst/tmp element types"); + + Value dstBuffer = + materializeBufferPointer(dst, elementType, getMemorySpace(dst), rewriter, loc); + Value tmpBuffer = + materializeBufferPointer(tmp, elementType, getMemorySpace(tmp), rewriter, loc); + if (!dstBuffer || !tmpBuffer) + return op.emitOpError("tmrgsort format2 requires pointer-backed dst/tmp tiles"); + + Value count = buildPackedCountI64(rewriter, loc, srcCounts); + Value config = + rewriter.create(loc, 1 | (0b1111 << 8), 64); + rewriter.create(loc, tmpBuffer, srcBuffers[0], srcBuffers[1], + srcBuffers[2], srcBuffers[3], count, config); + + Value dstRowsValue; + Value dstColsValue; + int64_t dstRows = ShapedType::kDynamic; + int64_t dstCols = ShapedType::kDynamic; + deriveValidShapeValues(dst, dstRowsValue, dstColsValue); + deriveValidShape(dst, dstRows, dstCols); + Value dstColsI64 = materializeI64Value(dstColsValue, dstCols, rewriter, loc); + int64_t elemBytes = getElementByteSize(elementType); + Value lenBurst = buildCeilDivPositiveI64( + rewriter, loc, + rewriter.create( + loc, dstColsI64, rewriter.create(loc, elemBytes, 64)), + 32); + Value zeroI64 = rewriter.create(loc, 0, 64); + Value 
oneI64 = rewriter.create(loc, 1, 64);
  // Copy the merged result from the tmp tile back into dst in one burst
  // (zero src/dst gap strides).
  // NOTE(review): the op-class template arguments on rewriter.create(...)
  // calls throughout this file appear to have been stripped during
  // extraction -- confirm against the original source.
  rewriter.create(loc, tmpBuffer, dstBuffer, zeroI64, oneI64,
                  lenBurst, zeroI64, zeroI64);
  return success();
}

// Lowers tsort32: sorts each row of a vec-domain, row-major tile in
// 32-element blocks, reading data from src (f16/f32), indices from idx (u32),
// and writing to dst. Emits one hardware sort per row inside a row loop.
LogicalResult lowerTSort32(TSort32Op op, PatternRewriter &rewriter) {
  // All three tiles must live in the vec tile domain with row-major layout.
  auto requireVecRowMajor = [&](Value value, StringRef role) -> LogicalResult {
    if (deriveTileDomain(getMemorySpace(value)) != VPTOTileDomain::Vec)
      return op.emitOpError() << "tsort32 lowering requires vec tile domain for "
                              << role;
    if (deriveTileLayout(value) != "row_major")
      return op.emitOpError() << "tsort32 lowering requires row-major layout for "
                              << role;
    return success();
  };

  if (failed(requireVecRowMajor(op.getSrc(), "src")) ||
      failed(requireVecRowMajor(op.getDst(), "dst")) ||
      failed(requireVecRowMajor(op.getIdx(), "idx")))
    return failure();

  // Data tiles must share an f16/f32 element type; the index tile must be u32.
  Type dataType = getElementType(op.getSrc());
  if (dataType != getElementType(op.getDst()))
    return op.emitOpError("tsort32 lowering requires matching src/dst element type");
  if (!(dataType.isF16() || dataType.isF32()))
    return op.emitOpError("tsort32 lowering currently supports only f16/f32 data");
  auto idxType = dyn_cast(getElementType(op.getIdx()));
  if (!idxType || idxType.getWidth() != 32 || !idxType.isUnsigned())
    return op.emitOpError("tsort32 lowering currently requires u32 index tile");

  // Resolve raw base pointers for all three tiles; bail if any tile is not
  // backed by an addressable buffer.
  Value srcBuffer =
      materializeBufferPointer(op.getSrc(), dataType, getMemorySpace(op.getSrc()),
                               rewriter, op.getLoc());
  Value dstBuffer =
      materializeBufferPointer(op.getDst(), dataType, getMemorySpace(op.getDst()),
                               rewriter, op.getLoc());
  Value idxBuffer = materializeBufferPointer(op.getIdx(), getElementType(op.getIdx()),
                                             getMemorySpace(op.getIdx()), rewriter,
                                             op.getLoc());
  if (!srcBuffer || !dstBuffer || !idxBuffer)
    return op.emitOpError("tsort32 lowering requires pointer-backed tiles");

  // Row strides must be static so per-row offsets can be computed as
  // row * stride with constant stride operands.
  int64_t srcStride = deriveStaticRowStride(op.getSrc());
  int64_t dstStride = deriveStaticRowStride(op.getDst());
  int64_t idxStride = deriveStaticRowStride(op.getIdx());
  if (srcStride == ShapedType::kDynamic || dstStride == ShapedType::kDynamic ||
      idxStride == ShapedType::kDynamic)
    return op.emitOpError("tsort32 lowering requires static row stride");

  Value validRowsValue;
  Value validColsValue;
  int64_t validRows = ShapedType::kDynamic;
  int64_t validCols = ShapedType::kDynamic;
  deriveValidShapeValues(op.getSrc(), validRowsValue, validColsValue);
  deriveValidShape(op.getSrc(), validRows, validCols);
  // The hardware sort works on 32-element groups, so the valid column count
  // must be a static multiple of 32.
  if (validCols == ShapedType::kDynamic || (validCols % 32) != 0)
    return op.emitOpError("tsort32 lowering currently requires static validCol divisible by 32");

  // A single-row index tile is broadcast to every data row (handled below by
  // forcing its row stride to zero).
  int64_t idxValidRows = ShapedType::kDynamic;
  int64_t idxValidCols = ShapedType::kDynamic;
  deriveValidShape(op.getIdx(), idxValidRows, idxValidCols);
  bool idxBroadcast = idxValidRows == 1;

  Value c0 = rewriter.create(op.getLoc(), 0);
  Value c1 = rewriter.create(op.getLoc(), 1);
  // Number of 32-element sort repeats issued per row.
  Value repeatNumPerRow =
      rewriter.create(op.getLoc(), validCols / 32);
  Value srcStrideValue = rewriter.create(op.getLoc(), srcStride);
  Value dstStrideValue = rewriter.create(op.getLoc(), dstStride);
  Value idxStrideValue =
      rewriter.create(op.getLoc(), idxBroadcast ? 0 : idxStride);

  // One sort instruction per row: advance each pointer by its row stride and
  // sort repeatNumPerRow blocks of 32 elements.
  auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1);
  rewriter.setInsertionPointToStart(rowLoop.getBody());
  Value row = rowLoop.getInductionVar();
  Value srcOffset = rewriter.create(op.getLoc(), row, srcStrideValue);
  Value dstOffset = rewriter.create(op.getLoc(), row, dstStrideValue);
  Value idxOffset = rewriter.create(op.getLoc(), row, idxStrideValue);
  Value rowSrcPtr =
      offsetBufferPointer(srcBuffer, dataType, srcOffset, rewriter, op.getLoc());
  Value rowDstPtr =
      offsetBufferPointer(dstBuffer, dataType, dstOffset, rewriter, op.getLoc());
  Value rowIdxPtr = offsetBufferPointer(idxBuffer, getElementType(op.getIdx()), idxOffset,
                                        rewriter, op.getLoc());
  rewriter.create(op.getLoc(), rowDstPtr, rowSrcPtr, rowIdxPtr,
                  repeatNumPerRow);
  return success();
}

// Lowers tmuls (tile * scalar multiply). Accepts f16/f32 and 16/32-bit
// integer elements; the scalar operand must match the tile element type.
LogicalResult lowerTMulS(TMulSOp op, PatternRewriter &rewriter,
                         VPTOLoweringStrategy strategy) {
  VPTOUnaryContract contract = buildUnaryContract("muls", op.getSrc0());
  if (failed(checkGenericUnaryContract(
          op, contract, op.getDst(),
          [](Type type) {
            if (type.isF16() || type.isF32())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 16 || intType.getWidth() == 32;
            return false;
          },
          "f16, f32, and 16/32-bit integer element types")))
    return failure();
  if (op.getScalar().getType() != contract.elementType)
    return op.emitOpError("tmuls lowering requires scalar type to match source element type");
  // Delegate the actual vector-scope loop emission to the shared
  // scalar-unary builder.
  return buildScalarUnaryVecScope("muls", contract, strategy, op.getSrc0(), op.getScalar(),
                                  op.getDst(), rewriter, op.getLoc());
}

// Lowers tsels: elementwise select between src and tmp, steered by a scalar
// selectMode. Note the valid shape is taken from dst, not from src.
LogicalResult lowerTSelS(TSelSOp op, PatternRewriter &rewriter) {
  VPTOBinaryContract contract = buildBinaryContract("sels", op.getSrc());
  deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue);
  deriveValidShape(op.getDst(), contract.validRows, contract.validCols);
  if (failed(checkGenericBinaryContract(
          op, contract, op.getTmp(), op.getDst(),
[](Type type) {
            if (type.isF16() || type.isF32() || type.isBF16())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 8 || intType.getWidth() == 16 ||
                     intType.getWidth() == 32;
            return false;
          },
          "f16, f32, bf16, and 8/16/32-bit integer element types")))
    return failure();

  // selectMode is a runtime scalar; it must be an integer so it can be
  // compared against 1 below.
  auto selectModeType = dyn_cast(op.getScalar().getType());
  if (!selectModeType)
    return op.emitOpError("tsels lowering requires integer selectMode");

  auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType);
  if (!vecType)
    return op.emitOpError("tsels lowering requires a supported VPTO vector element type");

  // Raw base pointers for the two sources and the destination.
  Value src0Buffer = materializeBufferPointer(op.getSrc(), contract.elementType,
                                              getMemorySpace(op.getSrc()), rewriter,
                                              op.getLoc());
  Value src1Buffer = materializeBufferPointer(op.getTmp(), contract.elementType,
                                              getMemorySpace(op.getTmp()), rewriter,
                                              op.getLoc());
  Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType,
                                             getMemorySpace(op.getDst()), rewriter,
                                             op.getLoc());
  if (!src0Buffer || !src1Buffer || !dstBuffer)
    return op.emitOpError("tsels lowering requires pointer-backed tile buffers");

  Value validRowsValue = materializeIndexValue(contract.validRowsValue,
                                               contract.validRows, rewriter, op.getLoc());
  Value validColsValue = materializeIndexValue(contract.validColsValue,
                                               contract.validCols, rewriter, op.getLoc());
  if (!validRowsValue || !validColsValue)
    return op.emitOpError("tsels lowering requires valid rows and cols");

  // The flat rows*cols element count must divide evenly into vector-width
  // chunks (checked statically when both dims are known; the loop below has
  // no tail handling).
  int64_t vectorWidth = vecType.getElementCount();
  if (contract.validRows != ShapedType::kDynamic &&
      contract.validCols != ShapedType::kDynamic) {
    int64_t totalElements = contract.validRows * contract.validCols;
    if (totalElements % vectorWidth != 0)
      return op.emitOpError(
          "tsels lowering currently requires total valid elements divisible by vector width");
  }

  Value c0 = rewriter.create(op.getLoc(), 0);
  Value c1 = rewriter.create(op.getLoc(), 1);
  Value totalElementsValue =
      rewriter.create(op.getLoc(), validRowsValue, validColsValue);
  Value vectorStepValue =
      rewriter.create(op.getLoc(), vectorWidth);

  FailureOr vecScope =
      createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter);
  if (failed(vecScope))
    return op.emitOpError("failed to create AIV vector scope region");

  OpBuilder::InsertionGuard aivGuard(rewriter);
  rewriter.setInsertionPointToStart(&(*vecScope).getBody().front());

  // Turn the scalar selectMode into a whole-vector predicate once, outside
  // the element loop: selectMode == 1 selects every lane from src
  // ("PAT_ALL"), otherwise every lane from tmp ("PAT_ALLF").
  Value selectOne = rewriter.create(
      op.getLoc(), IntegerAttr::get(selectModeType, 1));
  Value isAll = rewriter.create(op.getLoc(), arith::CmpIPredicate::eq,
                                op.getScalar(), selectOne);
  auto ifOp = rewriter.create(
      op.getLoc(), TypeRange{getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType)}, isAll,
      /*withElseRegion=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  Value allMask = rewriter
                      .create(op.getLoc(),
                              getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType),
                              rewriter.getStringAttr("PAT_ALL"))
                      .getResult();
  rewriter.create(op.getLoc(), allMask);
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  Value allfMask = rewriter
                       .create(op.getLoc(),
                               getVPTOMaskTypeForElementType(rewriter.getContext(), contract.elementType),
                               rewriter.getStringAttr("PAT_ALLF"))
                       .getResult();
  rewriter.create(op.getLoc(), allfMask);

  // Flat loop over all valid elements in vector-width steps: load both
  // sources, select with the precomputed mask, store to dst.
  rewriter.setInsertionPointAfter(ifOp);
  auto chunkLoop =
      rewriter.create(op.getLoc(), c0, totalElementsValue, vectorStepValue);
  rewriter.setInsertionPointToStart(chunkLoop.getBody());
  Value offset = chunkLoop.getInductionVar();
  Value mask = ifOp.getResult(0);
  auto src0Vec = rewriter.create(op.getLoc(), vecType, src0Buffer,
                                 offset, StringAttr());
  auto src1Vec = rewriter.create(op.getLoc(), vecType, src1Buffer,
                                 offset, StringAttr());
  Value selected = rewriter
                       .create(op.getLoc(), vecType, src0Vec.getResult(),
                               src1Vec.getResult(), mask)
                       .getResult();
  rewriter.create(
      op.getLoc(), selected, dstBuffer, offset, StringAttr(),
      buildAllPredicateMask(rewriter, op.getLoc(), contract.elementType));
  return success();
}

// Lowers tsel: elementwise select between src0 and src1 driven by a per-
// element i8 mask tile. Currently restricted to f32 data, vec domain,
// row-major layout, and fully static matching valid shapes.
LogicalResult lowerTSel(TSelOp op, PatternRewriter &rewriter) {
  VPTOBinaryContract contract = buildBinaryContract("tsel", op.getSrc0());
  deriveValidShapeValues(op.getDst(), contract.validRowsValue, contract.validColsValue);
  deriveValidShape(op.getDst(), contract.validRows, contract.validCols);

  // Gather valid shapes of all four tiles so they can be cross-checked.
  int64_t src1Rows = ShapedType::kDynamic;
  int64_t src1Cols = ShapedType::kDynamic;
  int64_t dstRows = ShapedType::kDynamic;
  int64_t dstCols = ShapedType::kDynamic;
  int64_t maskRows = ShapedType::kDynamic;
  int64_t maskCols = ShapedType::kDynamic;
  deriveValidShape(op.getSrc1(), src1Rows, src1Cols);
  deriveValidShape(op.getDst(), dstRows, dstCols);
  deriveValidShape(op.getMask(), maskRows, maskCols);

  if (contract.tileDomain != VPTOTileDomain::Vec ||
      deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec ||
      deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec ||
      deriveTileDomain(getMemorySpace(op.getMask())) != VPTOTileDomain::Vec)
    return op.emitOpError("tsel lowering requires tile domain vec");
  if (contract.tileLayout != "row_major" || deriveTileLayout(op.getSrc1()) != "row_major" ||
      deriveTileLayout(op.getDst()) != "row_major" || deriveTileLayout(op.getMask()) != "row_major")
    return op.emitOpError("tsel lowering requires row-major tile layout");
  if (contract.validRows == ShapedType::kDynamic ||
      contract.validCols == ShapedType::kDynamic)
    return op.emitOpError("tsel lowering requires static valid shape");
  if (contract.validRows != src1Rows || contract.validCols != src1Cols ||
      contract.validRows != dstRows || contract.validCols != dstCols ||
      contract.validRows != maskRows || contract.validCols != maskCols)
    return op.emitOpError("tsel lowering requires matching source, mask, and 
destination valid region");
  if (!contract.elementType || !contract.elementType.isF32())
    return op.emitOpError("tsel lowering currently supports only f32 data tiles");
  auto maskElemType = dyn_cast_or_null(getElementType(op.getMask()));
  if (!maskElemType || maskElemType.getWidth() != 8)
    return op.emitOpError("tsel lowering currently requires i8 mask tiles");

  // Allocated (not just valid) tile extents feed the row-stride constants,
  // so they must be static too.
  auto [tileRows, tileCols] = getStaticTileRowsCols(op.getDst());
  auto [maskTileRows, maskTileCols] = getStaticTileRowsCols(op.getMask());
  if (tileRows == ShapedType::kDynamic || tileCols == ShapedType::kDynamic ||
      maskTileRows == ShapedType::kDynamic || maskTileCols == ShapedType::kDynamic)
    return op.emitOpError("tsel lowering requires static tile rows and cols");
  Value maskBuffer = materializeBufferPointer(op.getMask(), getElementType(op.getMask()),
                                              getMemorySpace(op.getMask()), rewriter,
                                              op.getLoc());
  Value src0Buffer = materializeBufferPointer(op.getSrc0(), contract.elementType,
                                              getMemorySpace(op.getSrc0()), rewriter,
                                              op.getLoc());
  Value src1Buffer = materializeBufferPointer(op.getSrc1(), contract.elementType,
                                              getMemorySpace(op.getSrc1()), rewriter,
                                              op.getLoc());
  Value dstBuffer = materializeBufferPointer(op.getDst(), contract.elementType,
                                             getMemorySpace(op.getDst()), rewriter,
                                             op.getLoc());
  if (!maskBuffer || !src0Buffer || !src1Buffer || !dstBuffer)
    return op.emitOpError("tsel lowering requires pointer-backed tile buffers");

  auto vecType = getVPTOVRegType(rewriter.getContext(), contract.elementType);
  if (!vecType)
    return op.emitOpError("tsel lowering requires a supported VPTO vector element type");

  Value c0 = rewriter.create(op.getLoc(), 0);
  Value c1 = rewriter.create(op.getLoc(), 1);
  Value validRowsValue = materializeIndexValue(contract.validRowsValue, contract.validRows,
                                               rewriter, op.getLoc());
  if (!validRowsValue)
    return op.emitOpError("tsel lowering requires valid rows");
  // Row strides in elements: data rows advance by tileCols, mask rows by
  // maskTileCols (the mask is i8, one byte per 8 data lanes -- see the *8
  // mask offsets below).
  Value rowStride = rewriter.create(op.getLoc(), tileCols);
  Value maskStride = rewriter.create(op.getLoc(), maskTileCols);
  // Each hardware repeat covers 64 data elements; repeats are issued in
  // unrolled pairs (low/high halves of one split mask), with a scalar
  // remainder loop for an odd final repeat.
  constexpr int64_t elementsPerRepeat = 64;
  constexpr int64_t unrollConstant = 2;
  int64_t repeatTimes = (contract.validCols + elementsPerRepeat - 1) / elementsPerRepeat;
  int64_t pairedRepeatTimes = repeatTimes / unrollConstant;
  int64_t remainRepeat = repeatTimes % unrollConstant;
  int64_t repeatIdxBase = pairedRepeatTimes * unrollConstant;

  FailureOr vecScope =
      createLoopScopeRegion(op.getLoc(), contract.loopScope, rewriter);
  if (failed(vecScope))
    return op.emitOpError("failed to create AIV vector scope region");

  OpBuilder::InsertionGuard aivGuard(rewriter);
  rewriter.setInsertionPointToStart(&(*vecScope).getBody().front());
  // 16-bit-granule mask type used to split one loaded mask word into the
  // low/high per-repeat predicates.
  auto splitMaskType = getVPTOMaskType(rewriter.getContext(), "b16");
  Value fullMask = rewriter
                       .create(op.getLoc(), splitMaskType,
                               rewriter.getStringAttr("PAT_ALL"))
                       .getResult();
  auto rowLoop = rewriter.create(op.getLoc(), c0, validRowsValue, c1);
  rewriter.setInsertionPointToStart(rowLoop.getBody());
  Value rowBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), rowStride);
  Value maskBase = rewriter.create(op.getLoc(), rowLoop.getInductionVar(), maskStride);

  // Unrolled pairs: one mask load feeds two 64-element selects via its
  // low/high halves.
  for (int64_t j = 0; j < pairedRepeatTimes; ++j) {
    int64_t repeatIdx = j * unrollConstant;
    int64_t colOffset0 = repeatIdx * elementsPerRepeat;
    int64_t colOffset1 = colOffset0 + elementsPerRepeat;
    // 8 mask bytes per 64-element repeat (one bit per data lane).
    int64_t maskOffsetImm = repeatIdx * 8;
    // Lane counts for the tail-clamped store predicates of each half.
    int64_t count0 = std::min(elementsPerRepeat, contract.validCols - colOffset0);
    int64_t count1 = std::min(elementsPerRepeat, contract.validCols - colOffset1);

    Value maskOffset = rewriter.create(
        op.getLoc(), maskBase,
        rewriter.create(op.getLoc(), maskOffsetImm));
    Value rawMask = rewriter
                        .create(op.getLoc(),
                                splitMaskType,
                                maskBuffer, maskOffset,
                                rewriter.getStringAttr("US"))
                        .getResult();
    auto splitMask = rewriter.create(
        op.getLoc(), splitMaskType, splitMaskType, rawMask, fullMask);

    // First half of the pair: select with the low half of the split mask.
    Value dataOffset0 = rewriter.create(
        op.getLoc(), rowBase,
        rewriter.create(op.getLoc(), colOffset0));
    auto lhs0 = rewriter.create(op.getLoc(), vecType, src0Buffer,
                                dataOffset0, StringAttr());
    auto rhs0 = rewriter.create(op.getLoc(), vecType, src1Buffer,
                                dataOffset0, StringAttr());
    Value selected0 = rewriter
                          .create(op.getLoc(), vecType, lhs0.getResult(),
                                  rhs0.getResult(), splitMask.getLow())
                          .getResult();
    Value storeMask0 = buildPredicateMaskForLaneCount(
        rewriter, op.getLoc(), contract.elementType,
        rewriter.create(op.getLoc(), count0));
    rewriter.create(op.getLoc(), selected0, dstBuffer, dataOffset0,
                    StringAttr(), storeMask0);

    // Second half of the pair: same, with the high half of the split mask.
    Value dataOffset1 = rewriter.create(
        op.getLoc(), rowBase,
        rewriter.create(op.getLoc(), colOffset1));
    auto lhs1 = rewriter.create(op.getLoc(), vecType, src0Buffer,
                                dataOffset1, StringAttr());
    auto rhs1 = rewriter.create(op.getLoc(), vecType, src1Buffer,
                                dataOffset1, StringAttr());
    Value selected1 = rewriter
                          .create(op.getLoc(), vecType, lhs1.getResult(),
                                  rhs1.getResult(), splitMask.getHigh())
                          .getResult();
    Value storeMask1 = buildPredicateMaskForLaneCount(
        rewriter, op.getLoc(), contract.elementType,
        rewriter.create(op.getLoc(), count1));
    rewriter.create(op.getLoc(), selected1, dstBuffer, dataOffset1,
                    StringAttr(), storeMask1);
  }

  // Remainder: at most one extra repeat; the mask word is unpacked from its
  // LOWER half instead of being split into a pair.
  for (int64_t j = 0; j < remainRepeat; ++j) {
    int64_t repeatIdx = repeatIdxBase + j;
    int64_t colOffset = repeatIdx * elementsPerRepeat;
    // NOTE(review): std::max(0, contract.validCols - colOffset) mixes int
    // and int64_t arguments -- this likely needs std::max<int64_t>; confirm
    // against the original source.
    int64_t count = std::max(0, contract.validCols - colOffset);
    int64_t maskOffsetImm = repeatIdx * 8;

    Value maskOffset = rewriter.create(
        op.getLoc(), maskBase,
        rewriter.create(op.getLoc(), maskOffsetImm));
    Value rawMask = rewriter
                        .create(op.getLoc(),
                                splitMaskType,
                                maskBuffer, maskOffset,
                                rewriter.getStringAttr("US"))
                        .getResult();
    Value unpackedMask = rewriter
                             .create(
                                 op.getLoc(), splitMaskType,
                                 rawMask, rewriter.getStringAttr("LOWER"))
                             .getResult();
    Value dataOffset = rewriter.create(
op.getLoc(), rowBase,
        rewriter.create(op.getLoc(), colOffset));
    auto lhs = rewriter.create(op.getLoc(), vecType, src0Buffer,
                               dataOffset, StringAttr());
    auto rhs = rewriter.create(op.getLoc(), vecType, src1Buffer,
                               dataOffset, StringAttr());
    Value selected = rewriter
                         .create(op.getLoc(), vecType, lhs.getResult(),
                                 rhs.getResult(), unpackedMask)
                         .getResult();
    // Clamp the store to the remaining valid lane count.
    Value storeMask = buildPredicateMaskForLaneCount(
        rewriter, op.getLoc(), contract.elementType,
        rewriter.create(op.getLoc(), count));
    rewriter.create(op.getLoc(), selected, dstBuffer, dataOffset,
                    StringAttr(), storeMask);
  }
  return success();
}

// Lowers tdivs (tile / scalar or scalar / tile division, f16/f32 only).
// Exactly one operand must be tile-shaped; scalarFirst records whether the
// scalar is the dividend so the shared builder can emit the correct operand
// order.
LogicalResult lowerTDivS(TDivSOp op, PatternRewriter &rewriter,
                         VPTOLoweringStrategy strategy) {
  Value tileOperand;
  Value scalarOperand;
  bool scalarFirst = false;
  if (isVPTOShapedLikeValue(op.getSrc()) && !isVPTOShapedLikeValue(op.getScalar())) {
    // tile / scalar
    tileOperand = op.getSrc();
    scalarOperand = op.getScalar();
  } else if (!isVPTOShapedLikeValue(op.getSrc()) &&
             isVPTOShapedLikeValue(op.getScalar())) {
    // scalar / tile
    tileOperand = op.getScalar();
    scalarOperand = op.getSrc();
    scalarFirst = true;
  } else {
    return op.emitOpError(
        "divs lowering requires exactly one shaped operand and one scalar operand");
  }

  VPTOUnaryContract contract = buildUnaryContract("divs", tileOperand);
  if (failed(checkGenericUnaryContract(
          op, contract, op.getDst(),
          [](Type type) { return type.isF16() || type.isF32(); },
          "f16 and f32 element types")))
    return failure();
  if (scalarOperand.getType() != contract.elementType)
    return op.emitOpError(
        "divs lowering requires scalar type to match source element type");
  return buildScalarDivVecScope(contract, strategy, tileOperand, scalarOperand, op.getDst(),
                                scalarFirst, rewriter, op.getLoc());
}

// Lowers tadds (tile + scalar). Accepts f16/f32/bf16 and 16/32-bit integer
// elements; the scalar must match the tile element type.
LogicalResult lowerTAddS(TAddSOp op, PatternRewriter &rewriter,
                         VPTOLoweringStrategy strategy) {
  VPTOUnaryContract contract = buildUnaryContract("adds", op.getSrc());
  if (failed(checkGenericUnaryContract(
          op, contract, op.getDst(),
          [](Type type) {
            if (type.isF16() || type.isF32() || type.isBF16())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 16 || intType.getWidth() == 32;
            return false;
          },
          "f16, f32, bf16, and 16/32-bit integer element types")))
    return failure();
  if (op.getScalar().getType() != contract.elementType)
    return op.emitOpError("tadds lowering requires scalar type to match source element type");
  return buildScalarUnaryVecScope("adds", contract, strategy, op.getSrc(), op.getScalar(),
                                  op.getDst(), rewriter, op.getLoc());
}

// Lowers taddc as two chained elementwise adds through dst:
//   dst = src0 + src1, then dst = dst + src2.
// Both stages run the same element-type contract check.
LogicalResult lowerTAddC(TAddCOp op, PatternRewriter &rewriter) {
  // Stage 1: dst = src0 + src1.
  VPTOBinaryContract first = buildBinaryContract("add", op.getSrc0());
  deriveValidShapeValues(op.getDst(), first.validRowsValue, first.validColsValue);
  deriveValidShape(op.getDst(), first.validRows, first.validCols);
  if (failed(checkGenericBinaryContract(
          op, first, op.getSrc1(), op.getDst(),
          [](Type type) {
            if (type.isF16() || type.isF32() || type.isBF16())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 8 || intType.getWidth() == 16 ||
                     intType.getWidth() == 32;
            return false;
          },
          "f16, f32, bf16, and 8/16/32-bit integer element types")))
    return failure();
  if (failed(buildBinaryVecScope("add", first, VPTOLoweringStrategy::PostUpdate,
                                 op.getSrc0(), op.getSrc1(), op.getDst(),
                                 rewriter, op.getLoc())))
    return failure();

  // Stage 2: dst = dst + src2 (reads the stage-1 result back from dst).
  VPTOBinaryContract second = buildBinaryContract("add", op.getDst());
  deriveValidShapeValues(op.getDst(), second.validRowsValue, second.validColsValue);
  deriveValidShape(op.getDst(), second.validRows, second.validCols);
  if (failed(checkGenericBinaryContract(
          op, second, op.getSrc2(), op.getDst(),
          [](Type type) {
            if (type.isF16() || type.isF32() || type.isBF16())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 8 || intType.getWidth() == 16 ||
                     intType.getWidth() == 32;
            return false;
          },
          "f16, f32, bf16, and 8/16/32-bit integer element types")))
    return failure();
  return buildBinaryVecScope("add", second, VPTOLoweringStrategy::PostUpdate,
                             op.getDst(), op.getSrc2(), op.getDst(), rewriter,
                             op.getLoc());
}

// taddsc is not lowered yet; report the unresolved-baseline diagnostic.
LogicalResult lowerTAddSC(TAddSCOp op, PatternRewriter &rewriter) {
  return emitUnresolvedInstalledA5BaselineError(op, "taddsc");
}

// Lowers tsubc as a subtract followed by an add through dst:
//   dst = src0 - src1, then dst = dst + src2.
LogicalResult lowerTSubC(TSubCOp op, PatternRewriter &rewriter) {
  // Stage 1: dst = src0 - src1.
  VPTOBinaryContract first = buildBinaryContract("sub", op.getSrc0());
  deriveValidShapeValues(op.getDst(), first.validRowsValue, first.validColsValue);
  deriveValidShape(op.getDst(), first.validRows, first.validCols);
  if (failed(checkGenericBinaryContract(
          op, first, op.getSrc1(), op.getDst(),
          [](Type type) {
            if (type.isF16() || type.isF32() || type.isBF16())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 8 || intType.getWidth() == 16 ||
                     intType.getWidth() == 32;
            return false;
          },
          "f16, f32, bf16, and 8/16/32-bit integer element types")))
    return failure();
  if (failed(buildBinaryVecScope("sub", first, VPTOLoweringStrategy::PostUpdate,
                                 op.getSrc0(), op.getSrc1(), op.getDst(),
                                 rewriter, op.getLoc())))
    return failure();

  // Stage 2: dst = dst + src2 (note: "add", not "sub" -- (a-b)+c).
  VPTOBinaryContract second = buildBinaryContract("add", op.getDst());
  deriveValidShapeValues(op.getDst(), second.validRowsValue, second.validColsValue);
  deriveValidShape(op.getDst(), second.validRows, second.validCols);
  if (failed(checkGenericBinaryContract(
          op, second, op.getSrc2(), op.getDst(),
          [](Type type) {
            if (type.isF16() || type.isF32() || type.isBF16())
              return true;
            if (auto intType = dyn_cast(type))
              return intType.getWidth() == 8 || intType.getWidth() == 16 ||
                     intType.getWidth() == 32;
            return false;
          },
          "f16, f32, bf16, and 8/16/32-bit integer element types")))
    return failure();
  return buildBinaryVecScope("add", second, VPTOLoweringStrategy::PostUpdate,
                             op.getDst(),
op.getSrc2(), op.getDst(), rewriter,
                             op.getLoc());
}

// tsubs is not lowered yet; report the unresolved-baseline diagnostic.
LogicalResult lowerTSubS(TSubSOp op, PatternRewriter &rewriter,
                         VPTOLoweringStrategy strategy) {
  (void)rewriter;
  (void)strategy;
  return emitUnresolvedInstalledA5BaselineError(op, "tsubs");
}

// tsubsc is not lowered yet; report the unresolved-baseline diagnostic.
LogicalResult lowerTSubSC(TSubSCOp op, PatternRewriter &rewriter) {
  return emitUnresolvedInstalledA5BaselineError(op, "tsubsc");
}

// Lowers tmaxs (elementwise max with a scalar). f32 only; the scalar must
// match the tile element type.
LogicalResult lowerTMaxS(TMaxSOp op, PatternRewriter &rewriter,
                         VPTOLoweringStrategy strategy) {
  VPTOUnaryContract contract = buildUnaryContract("maxs", op.getSrc());
  if (failed(checkGenericUnaryContract(
          op, contract, op.getDst(),
          [](Type type) { return type.isF32(); }, "f32 element type")))
    return failure();
  if (op.getScalar().getType() != contract.elementType)
    return op.emitOpError("tmaxs lowering requires scalar type to match source element type");
  return buildScalarUnaryVecScope("maxs", contract, strategy, op.getSrc(), op.getScalar(),
                                  op.getDst(), rewriter, op.getLoc());
}

// Lowers tmins (elementwise min with a scalar). f32 only; the scalar must
// match the tile element type.
LogicalResult lowerTMinS(TMinSOp op, PatternRewriter &rewriter,
                         VPTOLoweringStrategy strategy) {
  VPTOUnaryContract contract = buildUnaryContract("mins", op.getSrc());
  if (failed(checkGenericUnaryContract(
          op, contract, op.getDst(),
          [](Type type) { return type.isF32(); }, "f32 element type")))
    return failure();
  if (op.getScalar().getType() != contract.elementType)
    return op.emitOpError("tmins lowering requires scalar type to match source element type");
  return buildScalarUnaryVecScope("mins", contract, strategy, op.getSrc(), op.getScalar(),
                                  op.getDst(), rewriter, op.getLoc());
}

// Row-reduction wrappers: extract the per-op contract, validate it, then
// delegate loop emission to buildRowReduceVecScope.
LogicalResult lowerTRowMax(TRowMaxOp op, PatternRewriter &rewriter,
                           VPTOLoweringStrategy strategy) {
  VPTORowReduceContract contract = extractTRowMaxContract(op);
  if (failed(checkRowReduceContract(op, contract, op.getDst())))
    return failure();
  return buildRowReduceVecScope("rowmax", contract, strategy, op.getSrc(), op.getDst(),
                                rewriter, op.getLoc());
}

LogicalResult lowerTRowMin(TRowMinOp op, PatternRewriter &rewriter,
                           VPTOLoweringStrategy strategy) {
  VPTORowReduceContract contract = extractTRowMinContract(op);
  if (failed(checkRowReduceContract(op, contract, op.getDst())))
    return failure();
  return buildRowReduceVecScope("rowmin", contract, strategy, op.getSrc(), op.getDst(),
                                rewriter, op.getLoc());
}

LogicalResult lowerTRowSum(TRowSumOp op, PatternRewriter &rewriter,
                           VPTOLoweringStrategy strategy) {
  VPTORowReduceContract contract = extractTRowSumContract(op);
  if (failed(checkRowReduceContract(op, contract, op.getDst())))
    return failure();
  return buildRowReduceVecScope("rowsum", contract, strategy, op.getSrc(), op.getDst(),
                                rewriter, op.getLoc());
}

// Column-reduction wrappers. colmax/colmin pass an empty Value() for the
// scratch operand; colsum threads op.getTmp() through instead.
LogicalResult lowerTColMax(TColMaxOp op, PatternRewriter &rewriter) {
  VPTOColReduceContract contract = extractTColMaxContract(op);
  if (failed(checkColReduceContract(op, contract, op.getDst())))
    return failure();
  return buildColReduceVecScope("colmax", contract, op.getSrc(), op.getDst(),
                                Value(), rewriter, op.getLoc());
}

LogicalResult lowerTColMin(TColMinOp op, PatternRewriter &rewriter) {
  VPTOColReduceContract contract = extractTColMinContract(op);
  if (failed(checkColReduceContract(op, contract, op.getDst())))
    return failure();
  return buildColReduceVecScope("colmin", contract, op.getSrc(), op.getDst(),
                                Value(), rewriter, op.getLoc());
}

LogicalResult lowerTColSum(TColSumOp op, PatternRewriter &rewriter) {
  VPTOColReduceContract contract = extractTColSumContract(op);
  if (failed(checkColReduceContract(op, contract, op.getDst())))
    return failure();
  return buildColReduceVecScope("colsum", contract, op.getSrc(), op.getDst(),
                                op.getTmp(), rewriter, op.getLoc());
}

// Lowers trowexpand: broadcast-expands src along rows into dst; src and dst
// must agree on valid row count.
LogicalResult lowerTRowExpand(TRowExpandOp op, PatternRewriter &rewriter,
                              VPTOLoweringStrategy strategy) {
  VPTOExpandContract contract = extractTRowExpandContract(op);
  if (failed(checkExpandContract(op, contract)))
    return failure();
  if (contract.srcValidRows != contract.dstValidRows)
    return op.emitOpError()
           << "rowexpand lowering requires source and destination valid rows to match";
  return buildRowExpandVecScope(contract, strategy, op.getSrc(), op.getDst(), rewriter,
                                op.getLoc());
}

// Lowers tcolexpand: broadcast-expands src along columns into dst; src and
// dst must agree on valid column count.
LogicalResult lowerTColExpand(TColExpandOp op, PatternRewriter &rewriter) {
  VPTOExpandContract contract = extractTColExpandContract(op);
  if (failed(checkExpandContract(op, contract)))
    return failure();
  if (contract.srcValidCols != contract.dstValidCols)
    return op.emitOpError()
           << "colexpand lowering requires source and destination valid cols to match";
  return buildColExpandVecScope(contract, op.getSrc(), op.getDst(), rewriter,
                                op.getLoc());
}

// Shared lowering for the row-expand fused binary ops (trowexpandmul,
// trowexpanddiv, trowexpandsub): one operand matches dst, the other is a
// narrow "expand" operand broadcast across each row.
// NOTE(review): the template parameter list was lost in extraction -- this
// was presumably `template <typename OpTy>`; confirm against the original.
template
LogicalResult lowerTRowExpandBinaryLike(OpTy op, PatternRewriter &rewriter,
                                        StringRef family,
                                        VPTOLoweringStrategy strategy) {
  // Element type: f16/f32 only, taken from dst.
  Type elementType = getElementType(op.getDst());
  if (!elementType || (!elementType.isF16() && !elementType.isF32()))
    return op.emitOpError() << family
                            << " lowering currently supports only f16 and f32 element types";

  if (deriveTileDomain(getMemorySpace(op.getDst())) != VPTOTileDomain::Vec ||
      deriveTileDomain(getMemorySpace(op.getSrc0())) != VPTOTileDomain::Vec ||
      deriveTileDomain(getMemorySpace(op.getSrc1())) != VPTOTileDomain::Vec)
    return op.emitOpError() << family << " lowering requires vec tile domain";
  if (deriveTileLayout(op.getDst()) != "row_major")
    return op.emitOpError() << family << " lowering requires row-major dst layout";

  // All three valid shapes must be static.
  int64_t dstValidRows = ShapedType::kDynamic;
  int64_t dstValidCols = ShapedType::kDynamic;
  int64_t src0ValidRows = ShapedType::kDynamic;
  int64_t src0ValidCols = ShapedType::kDynamic;
  int64_t src1ValidRows = ShapedType::kDynamic;
  int64_t src1ValidCols = ShapedType::kDynamic;
  deriveValidShape(op.getDst(), dstValidRows, dstValidCols);
  deriveValidShape(op.getSrc0(), src0ValidRows, src0ValidCols);
  deriveValidShape(op.getSrc1(),
src1ValidRows, src1ValidCols); + if (dstValidRows == ShapedType::kDynamic || dstValidCols == ShapedType::kDynamic || + src0ValidRows == ShapedType::kDynamic || src0ValidCols == ShapedType::kDynamic || + src1ValidRows == ShapedType::kDynamic || src1ValidCols == ShapedType::kDynamic) + return op.emitOpError() << family + << " lowering currently requires static valid shapes"; + + bool src0EqDst = op.getSrc0().getType() == op.getDst().getType(); + bool src1EqDst = op.getSrc1().getType() == op.getDst().getType(); + if (!src0EqDst && !src1EqDst) + return op.emitOpError() << family + << " lowering requires src0 or src1 to match dst tile type"; + + Value baseSrc = src0EqDst ? op.getSrc0() : op.getSrc1(); + Value expandSrc = src0EqDst ? op.getSrc1() : op.getSrc0(); + StringRef expandLayout = deriveTileLayout(expandSrc); + int64_t expandValidRows = src0EqDst ? src1ValidRows : src0ValidRows; + int64_t expandValidCols = src0EqDst ? src1ValidCols : src0ValidCols; + if (expandValidRows != dstValidRows) + return op.emitOpError() << family + << " lowering requires expand operand valid rows to match dst"; + + int64_t elemBytes = getElementByteSize(elementType); + bool expandIsRowMajor = expandLayout == "row_major" && expandValidCols == 32 / elemBytes; + bool expandIsColMajor = expandLayout == "col_major" && expandValidCols == 1; + if (!expandIsRowMajor && !expandIsColMajor) + return op.emitOpError() << family + << " lowering requires PTO A5-compatible expand operand shape"; + + auto vecType = getVPTOVRegType(rewriter.getContext(), elementType); + if (!vecType) + return op.emitOpError() << family + << " lowering requires a legal VPTO vector type"; + + Value baseBuffer = materializeBufferPointer(baseSrc, elementType, + getMemorySpace(baseSrc), rewriter, + op.getLoc()); + Value expandBuffer = materializeBufferPointer(expandSrc, elementType, + getMemorySpace(expandSrc), rewriter, + op.getLoc()); + Value dstBuffer = materializeBufferPointer(op.getDst(), elementType, + 
getMemorySpace(op.getDst()), rewriter, + op.getLoc()); + if (!baseBuffer || !expandBuffer || !dstBuffer) + return op.emitOpError() << family + << " lowering requires pointer-backed tile buffers"; + + int64_t dstRowStride = deriveStaticRowStride(op.getDst()); + int64_t baseRowStride = deriveStaticRowStride(baseSrc); + int64_t expandRowStride = deriveStaticRowStride(expandSrc); + if (dstRowStride == ShapedType::kDynamic || baseRowStride == ShapedType::kDynamic || + expandRowStride == ShapedType::kDynamic) + return op.emitOpError() << family << " lowering requires static row strides"; + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value rowsUpper = rewriter.create(op.getLoc(), dstValidRows); + Value colsUpper = rewriter.create(op.getLoc(), dstValidCols); + Value vectorStep = + rewriter.create(op.getLoc(), vecType.getElementCount()); + Value baseStrideValue = + rewriter.create(op.getLoc(), baseRowStride); + Value expandStrideValue = + rewriter.create(op.getLoc(), expandRowStride); + Value dstStrideValue = + rewriter.create(op.getLoc(), dstRowStride); + Value blockSizeValue = + rewriter.create(op.getLoc(), 32 / elemBytes); + + VPTOLoopScopeContract loopScope; + loopScope.kind = VPTOLoopScopeKind::AIVVectorScope; + loopScope.loweredAttr = kLoweredLoopScopeAttrName; + loopScope.loopDepth = 0; + + auto buildRowExpandValue = [&](Value baseVec, Value expandedVec, + Value predicate) -> FailureOr { + if (family == "trowexpandmul") + return rewriter.create(op.getLoc(), vecType, baseVec, + expandedVec, predicate) + .getResult(); + if (family == "trowexpanddiv") { + if (src0EqDst) + return rewriter.create(op.getLoc(), vecType, baseVec, + expandedVec, predicate) + .getResult(); + return rewriter.create(op.getLoc(), vecType, expandedVec, + baseVec, predicate) + .getResult(); + } + if (family == "trowexpandsub") { + if (src0EqDst) + return rewriter.create(op.getLoc(), vecType, baseVec, + expandedVec, predicate) + .getResult(); + 
return rewriter.create(op.getLoc(), vecType, expandedVec, + baseVec, predicate) + .getResult(); + } + return failure(); + }; + + FailureOr vecScope = + createLoopScopeRegion(op.getLoc(), loopScope, rewriter); + if (failed(vecScope)) + return op.emitOpError("failed to create AIV vector scope region"); + + OpBuilder::InsertionGuard aivGuard(rewriter); + rewriter.setInsertionPointToStart(&(*vecScope).getBody().front()); + auto rowLoop = rewriter.create(op.getLoc(), c0, rowsUpper, c1); + rewriter.setInsertionPointToStart(rowLoop.getBody()); + Value row = rowLoop.getInductionVar(); + Value baseRowOffset = rewriter.create(op.getLoc(), row, baseStrideValue); + Value dstRowOffset = rewriter.create(op.getLoc(), row, dstStrideValue); + Value expandRowOffset = expandIsRowMajor + ? rewriter.create(op.getLoc(), row, blockSizeValue) + : rewriter.create(op.getLoc(), row, expandStrideValue); + + Value expandVec; + if (expandIsColMajor) { + Value fullMask = buildAllPredicateMask(rewriter, op.getLoc(), elementType); + Value expandScalar = + rewriter.create(op.getLoc(), vecType, expandBuffer, + expandRowOffset); + expandVec = rewriter + .create(op.getLoc(), vecType, expandScalar, fullMask, + StringAttr()) + .getResult(); + } else { + expandVec = rewriter + .create(op.getLoc(), vecType, expandBuffer, expandRowOffset, + rewriter.getStringAttr("BLK")) + .getResult(); + } + + auto colLoop = rewriter.create(op.getLoc(), c0, colsUpper, vectorStep); + rewriter.setInsertionPointToStart(colLoop.getBody()); + Value col = colLoop.getInductionVar(); + Value remainingCols = rewriter.create(op.getLoc(), colsUpper, col); + Value needsTailMask = rewriter.create( + op.getLoc(), arith::CmpIPredicate::slt, remainingCols, vectorStep); + Value activeLanes = rewriter.create(op.getLoc(), needsTailMask, + remainingCols, vectorStep); + Value baseOffset = rewriter.create(op.getLoc(), baseRowOffset, col); + Value dstOffset = rewriter.create(op.getLoc(), dstRowOffset, col); + Value storeMask = + 
buildPredicateMaskForLaneCount(rewriter, op.getLoc(), elementType, activeLanes); + Value baseVec = + rewriter.create(op.getLoc(), vecType, baseBuffer, baseOffset, StringAttr()); + FailureOr computed = + buildRowExpandValue(baseVec, expandVec, storeMask); + if (failed(computed)) + return op.emitOpError() << "unsupported rowexpand binary family"; + rewriter.create(op.getLoc(), *computed, dstBuffer, dstOffset, + StringAttr(), storeMask); + rewriter.create(op.getLoc()); + return success(); +} + +LogicalResult lowerTRowExpandMul(TRowExpandMulOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandBinaryLike(op, rewriter, "trowexpandmul", strategy); +} + +LogicalResult lowerTRowExpandDiv(TRowExpandDivOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandBinaryLike(op, rewriter, "trowexpanddiv", strategy); +} + +LogicalResult lowerTRowExpandSub(TRowExpandSubOp op, PatternRewriter &rewriter, + VPTOLoweringStrategy strategy) { + return lowerTRowExpandBinaryLike(op, rewriter, "trowexpandsub", strategy); +} + +LogicalResult lowerTPartAdd(TPartAddOp op, PatternRewriter &rewriter) { + VPTOPartContract contract = extractTPartAddContract(op); + if (failed(checkPartContract(op, contract))) + return failure(); + return buildPartVecScope("partadd", contract, op.getSrc0(), op.getSrc1(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTPartMax(TPartMaxOp op, PatternRewriter &rewriter) { + VPTOPartContract contract = extractTPartMaxContract(op); + if (failed(checkPartContract(op, contract))) + return failure(); + return buildPartVecScope("partmax", contract, op.getSrc0(), op.getSrc1(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTPartMin(TPartMinOp op, PatternRewriter &rewriter) { + VPTOPartContract contract = extractTPartMinContract(op); + if (failed(checkPartContract(op, contract))) + return failure(); + return buildPartVecScope("partmin", contract, op.getSrc0(), 
op.getSrc1(), + op.getDst(), rewriter, op.getLoc()); +} + +LogicalResult lowerTSTORE(TStoreOp op, PatternRewriter &rewriter) { + VPTOStoreContract contract = extractTStoreContract(op); + + switch (contract.srcDomain) { + case VPTOTileDomain::Acc: + return lowerUnsupportedAccStore(op.getLoc()); + case VPTOTileDomain::Mat: + return lowerUnsupportedMatStore(op.getLoc()); + case VPTOTileDomain::Vec: + break; + } + + ResolvedTensorView destinationView; + if (!resolveTensorView(op.getDst(), destinationView, rewriter, op.getLoc())) + return op.emitOpError("requires a recoverable destination tensor view for VPTO lowering"); + + StringRef sourceTileLayout = deriveTileLayout(op.getSrc()); + StringRef destinationLayout = + inferVecTransferLayoutFromTile(stringifyLayoutAttr(destinationView.layoutAttr), + sourceTileLayout); + bool isNdStore = sourceTileLayout == "row_major" && destinationLayout == "nd"; + bool isDnStore = sourceTileLayout == "col_major" && destinationLayout == "dn"; + if (!isNdStore && !isDnStore) + return op.emitOpError("currently supports only ND row_major or DN col_major vec TSTORE lowering"); + + Value sourceBuffer = + materializeBufferPointer(op.getSrc(), contract.elementType, + getMemorySpace(op.getSrc()), rewriter, op.getLoc()); + Value destinationBuffer = + materializeBufferPointer(destinationView.root, getElementType(destinationView.root), + getGmMemorySpace(rewriter.getContext()), rewriter, + op.getLoc()); + if (!sourceBuffer || !destinationBuffer) + return op.emitOpError("requires A5-compatible source and destination buffers"); + + auto [tileRows, tileCols] = getStaticTileRowsCols(op.getSrc()); + Value validRowsValue = + materializeI64Value(contract.validRowsValue, contract.validRows, rewriter, + op.getLoc()); + Value validColsValue = + materializeI64Value(contract.validColsValue, contract.validCols, rewriter, + op.getLoc()); + Value sidValue = rewriter.create(op.getLoc(), 0, 64); + int64_t elemBytes = getElementByteSize(contract.elementType); + if 
((isNdStore && tileCols == ShapedType::kDynamic) || + (isDnStore && tileRows == ShapedType::kDynamic) || elemBytes <= 0) + return op.emitOpError("requires static tile shape for A5-compatible transfer arguments"); + VecNdTransferPlan plan; + LogicalResult planResult = + isNdStore ? buildVecNdStorePlan(destinationView.shape, destinationView.strides, + tileCols, contract.validColsValue, + contract.validCols, contract.elementType, + rewriter, op.getLoc(), plan) + : buildVecDnStorePlan(destinationView.shape, destinationView.strides, + tileRows, contract.validRowsValue, + contract.validRows, contract.elementType, + rewriter, op.getLoc(), plan); + if (failed(planResult)) + return op.emitOpError("requires PTO-compatible vec copy_ubuf_to_gm arguments"); + Value reservedValue = rewriter.create(op.getLoc(), 0, 64); + if (!validRowsValue || !validColsValue) + return op.emitOpError("requires valid rows and cols for A5-compatible transfer arguments"); + Value destinationOffset = + materializeI64Ofr(destinationView.offsetElems, rewriter, op.getLoc()); + if (!destinationOffset) + return op.emitOpError("requires a materializable destination offset for VPTO lowering"); + Value destinationBase = + adjustPointerByElemOffset(destinationBuffer, destinationOffset, elemBytes, rewriter, + op.getLoc()); + if (!destinationBase) + return op.emitOpError("failed to materialize destination base pointer"); + + rewriter.create(op.getLoc(), plan.loop2Size, + plan.loop1Size); + rewriter.create( + op.getLoc(), plan.loop1FirstStrideBytes, plan.loop1SecondStrideBytes); + rewriter.create( + op.getLoc(), plan.loop2FirstStrideBytes, plan.loop2SecondStrideBytes); + + auto emitCopy = [&](Value srcPtr, Value dstPtr) { + Type transferElementType = + getCopyTransferElementType(contract.elementType, rewriter); + Value typedSrcPtr = + castPtrToElementType(srcPtr, transferElementType, rewriter, op.getLoc()); + Value typedDstPtr = + castPtrToElementType(dstPtr, transferElementType, rewriter, op.getLoc()); + if 
(!typedSrcPtr || !typedDstPtr) + return failure(); + rewriter.create( + op.getLoc(), typedSrcPtr, typedDstPtr, sidValue, plan.nBurst, + plan.lenBurst, reservedValue, plan.firstStrideBytes, + plan.secondStrideBytes); + return success(); + }; + + if (std::optional outerConst = getConstInt(plan.outerCount); outerConst && *outerConst == 1) { + return emitCopy(sourceBuffer, destinationBase); + } + + Value c0 = rewriter.create(op.getLoc(), 0); + Value c1 = rewriter.create(op.getLoc(), 1); + Value outerUpper = + rewriter.create(op.getLoc(), rewriter.getIndexType(), + plan.outerCount); + auto outerLoop = rewriter.create(op.getLoc(), c0, outerUpper, c1); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + Value ivI64 = rewriter.create(op.getLoc(), rewriter.getI64Type(), + outerLoop.getInductionVar()); + Value srcStep = createI64Mul(ivI64, plan.outerSrcStrideElems, rewriter, op.getLoc()); + Value dstStep = createI64Mul(ivI64, plan.outerDstStrideElems, rewriter, op.getLoc()); + Value iterSrc = adjustPointerByElemOffset(sourceBuffer, srcStep, elemBytes, rewriter, + op.getLoc()); + Value iterDst = adjustPointerByElemOffset(destinationBase, dstStep, elemBytes, rewriter, + op.getLoc()); + return emitCopy(iterSrc, iterDst); +} + +LogicalResult lowerSetFlag(SetFlagOp op, PatternRewriter &rewriter) { + rewriter.create(op.getLoc(), + stringifyPipeAttr(op.getSrcPipe(), rewriter), + stringifyPipeAttr(op.getDstPipe(), rewriter), + stringifyEventAttr(op.getEventId(), rewriter)); + return success(); +} + +LogicalResult lowerWaitFlag(WaitFlagOp op, PatternRewriter &rewriter) { + rewriter.create(op.getLoc(), + stringifyPipeAttr(op.getSrcPipe(), rewriter), + stringifyPipeAttr(op.getDstPipe(), rewriter), + stringifyEventAttr(op.getEventId(), rewriter)); + return success(); +} + +LogicalResult lowerBarrier(BarrierOp op, PatternRewriter &rewriter) { + rewriter.create(op.getLoc(), + stringifyPipeAttr(op.getPipe(), rewriter)); + return success(); +} + +static FailureOr 
stringifyConcreteSyncPipeAttr(Attribute opTypeAttr, + PatternRewriter &rewriter) { + if (auto pipeAttr = dyn_cast(opTypeAttr)) + return PipeAttr::get(rewriter.getContext(), pipeAttr.getPipe()); + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) + return failure(); + PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return failure(); + return PipeAttr::get(rewriter.getContext(), pipe); +} + +LogicalResult lowerGetBuf(GetBufOp op, PatternRewriter &rewriter) { + FailureOr pipeAttr = + stringifyConcreteSyncPipeAttr(op.getOpTypeAttr(), rewriter); + if (failed(pipeAttr)) + return op.emitOpError("get_buf expects SyncOpType/PipeEventType that maps to a concrete pipe"); + + rewriter.create(op.getLoc(), Attribute(*pipeAttr), + static_cast(op.getBufId()), + static_cast(op.getMode())); + return success(); +} + +LogicalResult lowerRlsBuf(RlsBufOp op, PatternRewriter &rewriter) { + FailureOr pipeAttr = + stringifyConcreteSyncPipeAttr(op.getOpTypeAttr(), rewriter); + if (failed(pipeAttr)) + return op.emitOpError("rls_buf expects SyncOpType/PipeEventType that maps to a concrete pipe"); + + rewriter.create(op.getLoc(), Attribute(*pipeAttr), + static_cast(op.getBufId()), + static_cast(op.getMode())); + return success(); +} + +namespace { + +static Type convertVPTOBoundaryMemRefType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + auto memorySpace = dyn_cast_or_null(memrefType.getMemorySpace()); + if (!memorySpace) + return {}; + return PtrType::get(type.getContext(), memrefType.getElementType(), memorySpace); +} + +static LogicalResult eraseDeadVPTOMemRefScaffold(ModuleOp module) { + bool erasedAny = true; + while (erasedAny) { + erasedAny = false; + SmallVector deadOps; + module.walk([&](Operation *op) { + if (!op->use_empty()) + return; + if (isa(op)) + deadOps.push_back(op); + }); + for (Operation *op : deadOps) { + op->erase(); + erasedAny = true; + } + } + return success(); +} + +static 
LogicalResult verifyNoResidualVPTOMemRefs(ModuleOp module, + llvm::raw_ostream *diagOS) { + for (func::FuncOp func : module.getOps()) { + for (Type input : func.getFunctionType().getInputs()) { + if (!isa(input)) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: residual memref argument in " + << func.getName() << ": " << input << "\n"; + return failure(); + } + for (Type result : func.getFunctionType().getResults()) { + if (!isa(result)) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: residual memref result in " + << func.getName() << ": " << result << "\n"; + return failure(); + } + } + + WalkResult walk = module.walk([&](Operation *op) { + auto hasResidualMemRef = [](TypeRange types) { + return llvm::any_of(types, [](Type type) { + return isa(type); + }); + }; + if (hasResidualMemRef(op->getOperandTypes()) || + hasResidualMemRef(op->getResultTypes())) { + if (diagOS) { + *diagOS << "VPTO ptr-only boundary failed: residual memref-typed op " + << op->getName() << "\n"; + op->print(*diagOS); + *diagOS << "\n"; + } + return WalkResult::interrupt(); + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (!isa(arg.getType())) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: residual memref block " + << "argument in op " << op->getName() << ": " + << arg.getType() << "\n"; + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + return walk.wasInterrupted() ? failure() : success(); +} + +} // namespace + +LogicalResult convertVPTOFunctionBoundariesToPtr(ModuleOp module, + llvm::raw_ostream *diagOS) { + // VPTO kernels use ptr-only entry semantics: the function ABI keeps only the + // same-space base pointer, while shape/stride/offset stay in live SSA and + // address calculations inside the body. 
+ if (failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + + bool sawFailure = false; + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + + FunctionType functionType = func.getFunctionType(); + SmallVector newInputs(functionType.getInputs().begin(), + functionType.getInputs().end()); + bool changed = false; + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + auto memrefType = dyn_cast(inputType); + if (!memrefType) + continue; + + Type newType = convertVPTOBoundaryMemRefType(inputType); + if (!newType) { + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: unsupported memref " + << "argument type in " << func.getName() << ": " + << inputType << "\n"; + sawFailure = true; + continue; + } + + BlockArgument arg = func.getArgument(idx); + SmallVector users(arg.getUsers().begin(), arg.getUsers().end()); + arg.setType(newType); + newInputs[idx] = newType; + changed = true; + + for (Operation *user : users) { + if (auto cast = dyn_cast(user)) { + if (cast.getInput() != arg) + continue; + if (cast.getResult().getType() == newType) { + cast.getResult().replaceAllUsesWith(arg); + cast.erase(); + } + continue; + } + + if (isa(user) && + user->use_empty()) { + user->erase(); + continue; + } + + if (diagOS) { + *diagOS << "VPTO ptr-only boundary failed: argument " << idx + << " of " << func.getName() + << " still feeds a memref-dependent user after ptr rewrite:\n"; + user->print(*diagOS); + *diagOS << "\n"; + } + sawFailure = true; + } + } + + for (Type resultType : functionType.getResults()) { + if (!isa(resultType)) + continue; + if (diagOS) + *diagOS << "VPTO ptr-only boundary failed: memref result is unsupported " + << "for " << func.getName() << ": " << resultType << "\n"; + sawFailure = true; + } + + if (changed) { + func.setFunctionType( + FunctionType::get(module.getContext(), newInputs, functionType.getResults())); + } + } + + if (sawFailure) + return failure(); + + if 
(failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + return verifyNoResidualVPTOMemRefs(module, diagOS); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp new file mode 100644 index 000000000..7d2cd0d4d --- /dev/null +++ b/lib/PTO/Transforms/PTOVPTOExpandBridgeOps.cpp @@ -0,0 +1,114 @@ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVPTOEXPANDBRIDGEOPS +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static pto::AddressSpaceAttr getPointerMemorySpace(Attribute memorySpace, + MLIRContext *ctx) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return pto::AddressSpaceAttr::get( + ctx, static_cast(intAttr.getInt())); + return pto::AddressSpaceAttr::get(ctx, pto::AddressSpace::GM); +} + +static Value materializeBufferPointer(Value value, PatternRewriter &rewriter, + Location loc) { + if (!value) + return {}; + + if (isa(value.getType())) + return value; + + auto memrefType = dyn_cast(value.getType()); + if (!memrefType) + return {}; + + auto ptrType = + pto::PtrType::get(rewriter.getContext(), memrefType.getElementType(), + getPointerMemorySpace(memrefType.getMemorySpace(), + rewriter.getContext())); + return rewriter.create(loc, ptrType, value).getResult(); +} + +static Value offsetBufferPointer(Value basePtr, Type elementType, + Value elementOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + Value offsetIndex = elementOffset; + if 
(!offsetIndex.getType().isIndex()) + offsetIndex = rewriter.create(loc, + rewriter.getIndexType(), + elementOffset); + return rewriter.create(loc, basePtr.getType(), basePtr, + offsetIndex); +} + +struct ExpandUvldPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::UvldOp op, + PatternRewriter &rewriter) const override { + auto vecType = dyn_cast(op.getResult().getType()); + if (!vecType) + return failure(); + + Value basePtr = materializeBufferPointer(op.getSource(), rewriter, op.getLoc()); + if (!basePtr) + return op.emitOpError( + "requires a recoverable pointer base for uvld expansion"); + + Value loadPtr = offsetBufferPointer(basePtr, vecType.getElementType(), + op.getOffset(), rewriter, op.getLoc()); + auto alignType = pto::AlignType::get(rewriter.getContext()); + Value align = + rewriter.create(op.getLoc(), alignType, loadPtr); + auto load = rewriter.create( + op.getLoc(), TypeRange{vecType, alignType, loadPtr.getType()}, + ValueRange{loadPtr, align}); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +struct PTOVPTOExpandBridgeOpsPass + : public pto::impl::PTOVPTOExpandBridgeOpsBase { + using pto::impl::PTOVPTOExpandBridgeOpsBase< + PTOVPTOExpandBridgeOpsPass>::PTOVPTOExpandBridgeOpsBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOVPTOExpandBridgeOpsPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp new file mode 100644 index 000000000..6aa62259f --- /dev/null +++ b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp @@ -0,0 +1,337 @@ +#include "PTO/IR/PTO.h" +#include 
"PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVPTOPTRBOUNDARY +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static Type convertVPTOBoundaryMemRefType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + auto memorySpace = + dyn_cast_or_null(memrefType.getMemorySpace()); + if (!memorySpace) + return {}; + return pto::PtrType::get(type.getContext(), memrefType.getElementType(), + memorySpace); +} + +static bool isTrivialVPTOBoundaryCastPtr(pto::CastPtrOp castOp) { + return castOp.getInput().getType() == castOp.getResult().getType(); +} + +static LogicalResult eraseDeadVPTOMemRefScaffold(ModuleOp module) { + bool erasedAny = true; + while (erasedAny) { + erasedAny = false; + SmallVector trivialCasts; + SmallVector deadOps; + module.walk([&](Operation *op) { + if (auto castOp = dyn_cast(op)) { + if (isTrivialVPTOBoundaryCastPtr(castOp)) { + trivialCasts.push_back(castOp); + return; + } + if (castOp->use_empty()) + deadOps.push_back(op); + return; + } + + if (!op->use_empty()) + return; + if (isa(op)) + deadOps.push_back(op); + }); + + for (pto::CastPtrOp castOp : trivialCasts) { + if (!castOp->getBlock()) + continue; + castOp.getResult().replaceAllUsesWith(castOp.getInput()); + castOp.erase(); + erasedAny = true; + } + + for (Operation *op : deadOps) { + if (!op->getBlock()) + continue; + op->erase(); + erasedAny = true; + } + } + return success(); +} + +static Type getVPTOBufferElementType(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getElementType(); + if (auto memrefType = dyn_cast(type)) + 
return memrefType.getElementType(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + return {}; +} + +static Attribute getVPTOBufferMemorySpace(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getMemorySpace(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getMemorySpace(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); + return {}; +} + +static bool needsPtrCanonicalization(Value value) { + return isa(value.getType()); +} + +static bool isSupportedVPTOBufferLikeBoundaryOp(Operation *op) { + return isa(op); +} + +static LogicalResult canonicalizeBoundaryCastPtrOps(ModuleOp module, + llvm::raw_ostream *diagOS) { + SmallVector castsToRewrite; + module.walk([&](pto::CastPtrOp castOp) { + if (!isa(castOp.getInput().getType())) + return; + if (!isa(castOp.getResult().getType())) + return; + castsToRewrite.push_back(castOp); + }); + + PatternRewriter rewriter(module.getContext()); + for (pto::CastPtrOp castOp : castsToRewrite) { + if (!castOp->getBlock()) + continue; + + auto resultType = dyn_cast(castOp.getResult().getType()); + if (!resultType) + continue; + + rewriter.setInsertionPoint(castOp); + Value ptrValue = pto::materializeBufferPointer( + castOp.getInput(), resultType.getElementType(), + resultType.getMemorySpace(), rewriter, castOp.getLoc()); + if (!ptrValue) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "canonicalize pto.castptr input for "; + castOp->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + castOp.getResult().replaceAllUsesWith(ptrValue); + rewriter.eraseOp(castOp); + } + + return success(); +} + +static LogicalResult canonicalizeSupportedVPTOBufferLikeOps( + ModuleOp module, llvm::raw_ostream *diagOS) { + SmallVector opsToRewrite; + module.walk([&](Operation *op) { + if (isSupportedVPTOBufferLikeBoundaryOp(op)) + opsToRewrite.push_back(op); + }); + + PatternRewriter 
rewriter(module.getContext()); + for (Operation *op : opsToRewrite) { + rewriter.setInsertionPoint(op); + + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + bool changed = false; + + for (Value operand : op->getOperands()) { + if (!needsPtrCanonicalization(operand)) { + newOperands.push_back(operand); + continue; + } + + Type elementType = getVPTOBufferElementType(operand); + Attribute memorySpace = getVPTOBufferMemorySpace(operand); + if (!elementType || !memorySpace) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "derive element type or memory space for operand of "; + op->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + Value ptrValue = pto::materializeBufferPointer(operand, elementType, + memorySpace, rewriter, + op->getLoc()); + if (!ptrValue) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "materialize pointer operand for "; + op->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + changed = changed || (ptrValue != operand); + newOperands.push_back(ptrValue); + } + + if (!changed) + continue; + + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(newOperands); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + } + + return success(); +} + +struct PTOVPTOPtrBoundaryPass + : public pto::impl::PTOVPTOPtrBoundaryBase { + using pto::impl::PTOVPTOPtrBoundaryBase< + PTOVPTOPtrBoundaryPass>::PTOVPTOPtrBoundaryBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(pto::convertVPTOEmissionBoundaryToPtr(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult mlir::pto::convertVPTOEmissionBoundaryToPtr( + ModuleOp module, llvm::raw_ostream *diagOS) { + // VPTO kernels use ptr-only entry semantics at the 
emission boundary: the + // function ABI keeps only the same-space base pointer, while shape/stride + // state remains in SSA. Body-level op canonicalization is added on top of + // this entry rewrite in follow-up tasks. + if (failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + + bool sawFailure = false; + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + + FunctionType functionType = func.getFunctionType(); + SmallVector newInputs(functionType.getInputs().begin(), + functionType.getInputs().end()); + bool changed = false; + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + auto memrefType = dyn_cast(inputType); + if (!memrefType) + continue; + + Type newType = convertVPTOBoundaryMemRefType(inputType); + if (!newType) { + if (diagOS) + *diagOS << "VPTO emission-boundary ptr rewrite failed: unsupported " + "memref argument type in " + << func.getName() << ": " << inputType << "\n"; + sawFailure = true; + continue; + } + + BlockArgument arg = func.getArgument(idx); + SmallVector users(arg.getUsers().begin(), arg.getUsers().end()); + arg.setType(newType); + newInputs[idx] = newType; + changed = true; + + for (Operation *user : users) { + if (auto cast = dyn_cast(user)) { + if (cast.getInput() != arg) + continue; + if (cast.getResult().getType() == newType) { + cast.getResult().replaceAllUsesWith(arg); + cast.erase(); + } + continue; + } + + if (isa(user) && + user->use_empty()) { + user->erase(); + continue; + } + + if (isSupportedVPTOBufferLikeBoundaryOp(user)) + continue; + + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: argument " + << idx << " of " << func.getName() + << " still feeds a memref-dependent user after ptr rewrite:\n"; + user->print(*diagOS); + *diagOS << "\n"; + } + sawFailure = true; + } + } + + for (Type resultType : functionType.getResults()) { + if (!isa(resultType)) + continue; + if (diagOS) + *diagOS << "VPTO emission-boundary ptr rewrite failed: 
memref result " + "is unsupported for " + << func.getName() << ": " << resultType << "\n"; + sawFailure = true; + } + + if (changed) { + func.setFunctionType( + FunctionType::get(module.getContext(), newInputs, functionType.getResults())); + } + } + + if (sawFailure) + return failure(); + + if (failed(canonicalizeBoundaryCastPtrOps(module, diagOS))) + return failure(); + + if (failed(canonicalizeSupportedVPTOBufferLikeOps(module, diagOS))) + return failure(); + + return eraseDeadVPTOMemRefScaffold(module); +} + +std::unique_ptr mlir::pto::createPTOVPTOPtrBoundaryPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp new file mode 100644 index 000000000..a81b4e81f --- /dev/null +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -0,0 +1,756 @@ +//===- PTOValidateVPTOIR.cpp - Shared VPTO legality helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file owns the shared helper layer for the dual-stage VPTO legality +// verifier. Follow-up tasks add the public validation entrypoints and pass +// wrappers on top of this utility layer. 
+// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace mlir { +namespace pto { + +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); + +namespace detail { + +constexpr llvm::StringLiteral kAIVectorScopeAttrName = + "llvm.loop.aivector_scope"; + +enum class VPTOMaskGranularity { + B8, + B16, + B32, +}; + +enum class VPTOBufferAddressFamily { + None, + Copy, + BufferLike, + PtrOnly, +}; + +enum class VPTOLegalityStage { + Authoring, + Emission, +}; + +class VPTOLegalityHelper { +public: + explicit VPTOLegalityHelper(ModuleOp module) : module(module) {} + + ModuleOp getModule() const { return module; } + + SmallVector getFunctions() { + SmallVector funcs; + for (func::FuncOp func : module.getOps()) + funcs.push_back(func); + return funcs; + } + + static bool isLegalityTypedValue(Type type) { + return isa(type); + } + + static bool isBufferLikeValue(Type type) { + return isa(type); + } + + static bool requiresVecScope(Operation *op) { + if (!isPTOp(op)) + return false; + + return llvm::any_of(op->getOperandTypes(), isLegalityTypedValue) || + llvm::any_of(op->getResultTypes(), isLegalityTypedValue); + } + + static bool isAIVectorScopeCarrier(scf::ForOp loop) { + return loop && loop->hasAttr(kAIVectorScopeAttrName); + } + + static bool isDedicatedVecScopeCarrier(Operation *op) { + return isa_and_nonnull(op); + } + + static bool isAnyVectorScopeCarrier(Operation *op) { + 
if (auto loop = dyn_cast_or_null(op)) + return isAIVectorScopeCarrier(loop); + return isDedicatedVecScopeCarrier(op); + } + + static Operation *getEnclosingVectorScopeCarrier(Operation *op) { + for (Operation *parent = op ? op->getParentOp() : nullptr; parent; + parent = parent->getParentOp()) { + if (isAnyVectorScopeCarrier(parent)) + return parent; + } + return nullptr; + } + + static std::optional getMaskGranularity(Type type) { + auto maskType = dyn_cast(type); + if (!maskType) + return std::nullopt; + return getMaskGranularity(maskType); + } + + static std::optional getMaskGranularity(MaskType type) { + if (type.isB8()) + return VPTOMaskGranularity::B8; + if (type.isB16()) + return VPTOMaskGranularity::B16; + if (type.isB32()) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + + static StringRef stringifyMaskGranularity(VPTOMaskGranularity granularity) { + switch (granularity) { + case VPTOMaskGranularity::B8: + return "b8"; + case VPTOMaskGranularity::B16: + return "b16"; + case VPTOMaskGranularity::B32: + return "b32"; + } + llvm_unreachable("unsupported VPTO mask granularity"); + } + + static std::optional + inferMaskGranularityFromType(Type type) { + if (auto vregType = dyn_cast(type)) + type = vregType.getElementType(); + + if (type.isF32()) + return VPTOMaskGranularity::B32; + if (type.isF16() || type.isBF16()) + return VPTOMaskGranularity::B16; + + auto intType = dyn_cast(type); + if (!intType) + return std::nullopt; + + switch (intType.getWidth()) { + case 8: + return VPTOMaskGranularity::B8; + case 16: + return VPTOMaskGranularity::B16; + case 32: + return VPTOMaskGranularity::B32; + default: + return std::nullopt; + } + } + + static std::optional + inferMaskGranularityFromFamily(Operation *op) { + StringRef mnemonic = getPTOpMnemonic(op); + if (mnemonic.empty()) + return std::nullopt; + + if (mnemonic.ends_with("_b8")) + return VPTOMaskGranularity::B8; + if (mnemonic.ends_with("_b16")) + return VPTOMaskGranularity::B16; + if 
(mnemonic.ends_with("_b32")) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + + static VPTOBufferAddressFamily classifyBufferAddressFamily(Operation *op) { + if (!op) + return VPTOBufferAddressFamily::None; + + if (isa(op)) + return VPTOBufferAddressFamily::Copy; + + if (isa(op)) + return VPTOBufferAddressFamily::PtrOnly; + + if (isa(op)) + return VPTOBufferAddressFamily::BufferLike; + + return VPTOBufferAddressFamily::None; + } + + static bool isSupportedEmissionBufferLikeOp(Operation *op) { + return classifyBufferAddressFamily(op) == + VPTOBufferAddressFamily::BufferLike; + } + + static bool isResidualEmissionScaffold(Operation *op) { + return isa(op) || + isTrivialEmissionCastPtr(op); + } + + static SmallVector collectBufferOperands(Operation *op) { + SmallVector bufferOperands; + for (OpOperand &operand : op->getOpOperands()) { + if (isBufferLikeValue(operand.get().getType())) + bufferOperands.push_back(&operand); + } + return bufferOperands; + } + +private: + static bool isPTOp(Operation *op) { + return op && op->getName().getStringRef().starts_with("pto."); + } + + static StringRef getPTOpMnemonic(Operation *op) { + if (!isPTOp(op)) + return {}; + StringRef mnemonic = op->getName().getStringRef(); + (void)mnemonic.consume_front("pto."); + return mnemonic; + } + + static bool isTrivialEmissionCastPtr(Operation *op) { + auto castOp = dyn_cast_or_null(op); + return castOp && + castOp.getInput().getType() == castOp.getResult().getType(); + } + + ModuleOp module; +}; + +class VPTOLegalityValidator { +public: + VPTOLegalityValidator(ModuleOp module, VPTOLegalityStage stage, + llvm::raw_ostream *diagOS) + : helper(module), stage(stage), diagOS(diagOS) {} + + LogicalResult validate() { + if (!helper.getModule()) { + writeDiagnostic("VPTO legality validation requires a valid module\n"); + return failure(); + } + + if (failed(validateAuthoringRules())) + return failure(); + + if (stage == VPTOLegalityStage::Emission && + failed(validateEmissionRules())) 
+ return failure(); + + return success(); + } + +private: + LogicalResult validateAuthoringRules() { + if (failed(validateAuthoringFunctionSurface())) + return failure(); + if (failed(validateAuthoringOperationSurface())) + return failure(); + return success(); + } + + LogicalResult validateEmissionRules() { + if (failed(validateEmissionFunctionSurface())) + return failure(); + if (failed(validateEmissionOperationSurface())) + return failure(); + return success(); + } + + static std::string formatExpectedMaskType(VPTOMaskGranularity granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.mask<" + << VPTOLegalityHelper::stringifyMaskGranularity(granularity) << ">"; + return storage; + } + + static LogicalResult validateMaskMatchesVectorFamily(Operation *op, + Type maskType, + StringRef maskRole, + Type vectorType, + StringRef vectorRole) { + auto actual = VPTOLegalityHelper::getMaskGranularity(maskType); + auto expected = VPTOLegalityHelper::inferMaskGranularityFromType(vectorType); + if (!actual || !expected || *actual == *expected) + return success(); + + return op->emitOpError() + << maskRole << " " << maskType << " does not match " << vectorRole + << " " << vectorType << "; expected " + << formatExpectedMaskType(*expected); + } + + static LogicalResult validateSameMaskGranularity(Operation *op, Type lhsType, + StringRef lhsRole, + Type rhsType, + StringRef rhsRole) { + auto lhs = VPTOLegalityHelper::getMaskGranularity(lhsType); + auto rhs = VPTOLegalityHelper::getMaskGranularity(rhsType); + if (!lhs || !rhs || *lhs == *rhs) + return success(); + + return op->emitOpError() << lhsRole << " " << lhsType << " does not match " + << rhsRole << " " << rhsType; + } + + template + static LogicalResult validateInputMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getInput().getType(), + "input vector type"); + } + + template + static LogicalResult 
validateBinaryMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", op.getLhs().getType(), + "lhs vector type"); + } + + template + static LogicalResult validateValueMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", op.getValue().getType(), + "value vector type"); + } + + template + static LogicalResult validateResultMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getResult().getType(), + "result vector type"); + } + + template + static LogicalResult validateCarryFamilyContract(CarryOp op) { + if (failed(validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getLhs().getType(), + "lhs vector type")) || + failed(validateSameMaskGranularity(op, op.getMask().getType(), + "mask type", + op.getCarry().getType(), + "carry type"))) + return failure(); + + if constexpr (std::is_same_v || + std::is_same_v) { + if (failed(validateSameMaskGranularity(op, op.getCarryIn().getType(), + "carry_in type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getCarryIn().getType(), + "carry_in type", + op.getCarry().getType(), + "carry type"))) + return failure(); + } + + return success(); + } + + template + static LogicalResult validateCompareFamilyContract(CompareOp op, Type vecType) { + if (failed(validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "seed mask type", vecType, + "input vector type")) || + failed(validateMaskMatchesVectorFamily(op, op.getResult().getType(), + "result mask type", vecType, + "input vector type")) || + failed(validateSameMaskGranularity(op, op.getMask().getType(), + "seed mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + template + static LogicalResult validateMaskOnlyUnaryContract(MaskUnaryOp op) { + return 
validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getResult().getType(), + "result mask type"); + } + + static LogicalResult validateMaskOnlyPnotContract(PnotOp op) { + if (failed(validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + static LogicalResult validateMaskOnlyPselContract(PselOp op) { + if (failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getSrc1().getType(), + "src1 mask type")) || + failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + template + static LogicalResult validatePredicateMovementContract( + PredicateMovementOp op) { + auto expected = VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + if (!expected) + return success(); + + if (failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getRhs().getType(), + "rhs mask type")) || + failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getLow().getType(), + "low mask type")) || + failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getHigh().getType(), + "high mask type"))) + return failure(); + + auto lhs = VPTOLegalityHelper::getMaskGranularity(op.getLhs().getType()); + if (!lhs || *lhs == *expected) + return success(); + + return op.emitOpError() + << "predicate movement family requires " + << formatExpectedMaskType(*expected) + << " but got lhs mask type " << op.getLhs().getType(); + } + + static LogicalResult 
validateFamilySuffixMaskResult(Operation *op, + Type resultType, + StringRef resultRole) { + auto expected = VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + auto actual = VPTOLegalityHelper::getMaskGranularity(resultType); + if (!expected || !actual || *expected == *actual) + return success(); + + return op->emitOpError() + << "family suffix requires " << resultRole << " to be " + << formatExpectedMaskType(*expected) << ", but got " << resultType; + } + + static LogicalResult validateFamilySuffixMaskContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](auto concreteOp) { + return validateFamilySuffixMaskResult( + concreteOp, concreteOp.getResult().getType(), "result type"); + }) + .Case([](auto concreteOp) { + return validateFamilySuffixMaskResult(concreteOp, + concreteOp.getMask().getType(), + "mask result type"); + }) + .Default([](Operation *) { return success(); }); + } + + static LogicalResult validateMaskGranularityContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](auto concreteOp) { + return validateInputMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateBinaryMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateCarryFamilyContract(concreteOp); + }) + .Case([](VcmpOp concreteOp) { + return validateCompareFamilyContract(concreteOp, + concreteOp.getSrc0().getType()); + }) + .Case([](VcmpsOp concreteOp) { + return validateCompareFamilyContract(concreteOp, + concreteOp.getSrc().getType()); + }) + .Case([](auto concreteOp) { + return validateMaskOnlyUnaryContract(concreteOp); + }) + .Case( + [](PnotOp concreteOp) { return validateMaskOnlyPnotContract(concreteOp); }) + .Case( + [](PselOp concreteOp) { return validateMaskOnlyPselContract(concreteOp); }) + .Case([](auto concreteOp) { + return validatePredicateMovementContract(concreteOp); + }) + .Case([](VselOp concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + 
concreteOp.getMask().getType(), + "mask type", + concreteOp.getSrc0().getType(), + "src0 vector type"); + }) + .Case([](auto concreteOp) { + return validateResultMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateValueMaskVectorConsumer(concreteOp); + }) + .Case([](Vstsx2Op concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getLow().getType(), + "low vector type"); + }) + .Case([](auto concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getLhs().getType(), + "lhs vector type"); + }) + .Default([](Operation *) { return success(); }); + } + + LogicalResult validateAuthoringFunctionSurface() { + for (func::FuncOp func : helper.getFunctions()) { + (void)func; + } + return success(); + } + + LogicalResult validateAuthoringOperationSurface() { + WalkResult loopWalkResult = helper.getModule().walk([&](scf::ForOp loop) { + if (!VPTOLegalityHelper::isAIVectorScopeCarrier(loop)) + return WalkResult::advance(); + + Operation *parentScope = + VPTOLegalityHelper::getEnclosingVectorScopeCarrier(loop); + if (!parentScope) + return WalkResult::advance(); + + if (isa(parentScope)) { + loop.emitOpError() << "does not allow nested scf.for with '" + << kAIVectorScopeAttrName << "'"; + return WalkResult::interrupt(); + } + + loop.emitOpError() + << "does not allow legacy scf.for carrier nested inside dedicated " + "pto.vecscope/pto.strict_vecscope"; + return WalkResult::interrupt(); + }); + if (loopWalkResult.wasInterrupted()) + return failure(); + + WalkResult vecScopeWalkResult = helper.getModule().walk([&](Operation *op) { + if (!VPTOLegalityHelper::isDedicatedVecScopeCarrier(op)) + return WalkResult::advance(); + + if (!VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) + return WalkResult::advance(); + + op->emitOpError() + << "does not allow nested dedicated 
pto.vecscope/pto.strict_vecscope"; + return WalkResult::interrupt(); + }); + if (vecScopeWalkResult.wasInterrupted()) + return failure(); + + WalkResult opWalkResult = helper.getModule().walk([&](Operation *op) { + (void)VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + (void)VPTOLegalityHelper::classifyBufferAddressFamily(op); + + if (!VPTOLegalityHelper::requiresVecScope(op)) + return WalkResult::advance(); + + if (VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) + return (failed(validateFamilySuffixMaskContracts(op)) || + failed(validateMaskGranularityContracts(op))) + ? WalkResult::interrupt() + : WalkResult::advance(); + + op->emitOpError() + << "requires enclosing scf.for with '" + << kAIVectorScopeAttrName + << "' or dedicated pto.vecscope/pto.strict_vecscope" + << " because it consumes or produces !pto.vreg/!pto.mask/!pto.align"; + return WalkResult::interrupt(); + }); + return opWalkResult.wasInterrupted() ? failure() : success(); + } + + LogicalResult validateEmissionFunctionSurface() { + for (func::FuncOp func : helper.getFunctions()) { + FunctionType functionType = func.getFunctionType(); + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + if (!isa(inputType)) + continue; + return func.emitError() + << "emission-stage VPTO legality rejects memref argument #" + << idx << ": " << inputType; + } + + for (auto [idx, resultType] : llvm::enumerate(functionType.getResults())) { + if (!isa(resultType)) + continue; + return func.emitError() + << "emission-stage VPTO legality rejects memref result #" + << idx << ": " << resultType; + } + } + return success(); + } + + LogicalResult validateEmissionOperationSurface() { + WalkResult walkResult = helper.getModule().walk([&](Operation *op) { + VPTOBufferAddressFamily family = + VPTOLegalityHelper::classifyBufferAddressFamily(op); + + if (family == VPTOBufferAddressFamily::BufferLike) { + for (OpOperand *operand : VPTOLegalityHelper::collectBufferOperands(op)) { + Type
operandType = operand->get().getType(); + if (!isa(operandType)) + continue; + + op->emitOpError() + << "emission-stage VPTO legality rejects memref-form buffer " + "operand #" + << operand->getOperandNumber() << " of type " << operandType + << " for buffer-like family op"; + return WalkResult::interrupt(); + } + } + + if (VPTOLegalityHelper::isResidualEmissionScaffold(op)) { + op->emitOpError() + << "must be eliminated before emission-stage VPTO validation"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return walkResult.wasInterrupted() ? failure() : success(); + } + + void writeDiagnostic(StringRef message) const { + if (diagOS) + *diagOS << message; + } + + VPTOLegalityHelper helper; + VPTOLegalityStage stage; + llvm::raw_ostream *diagOS; +}; + +} // namespace detail + +namespace { + +struct PTOValidateVPTOIRPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVPTOIRPass) + + StringRef getArgument() const final { return "pto-validate-vpto-ir"; } + + StringRef getDescription() const final { + return "Validate authoring-stage VPTO legality before emission-boundary canonicalization"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(validateVPTOAuthoringIR(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +struct PTOValidateVPTOEmissionIRPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVPTOEmissionIRPass) + + StringRef getArgument() const final { + return "pto-validate-vpto-emission-ir"; + } + + StringRef getDescription() const final { + return "Validate emission-stage VPTO legality after ptr-boundary canonicalization"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(validateVPTOEmissionIR(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + return 
detail::VPTOLegalityValidator( + module, detail::VPTOLegalityStage::Authoring, diagOS) + .validate(); +} + +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + return detail::VPTOLegalityValidator( + module, detail::VPTOLegalityStage::Emission, diagOS) + .validate(); +} + +std::unique_ptr createPTOValidateVPTOIRPass() { + return std::make_unique(); +} + +std::unique_ptr createPTOValidateVPTOEmissionIRPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 9ff679d52..1d1950dd7 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -398,7 +398,7 @@ static Type convertPTOTypeToMemRef(Type t) { // 1. 处理 !pto.ptr if (auto pty = dyn_cast(t)) { return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); + MemRefLayoutAttrInterface(), pty.getMemorySpace()); } // 2. 
处理 !pto.tile_buf<...> diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp new file mode 100644 index 000000000..16c6a8bd7 --- /dev/null +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -0,0 +1,4930 @@ +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/LLVMContext.h" + +namespace mlir::pto { + +void materializeVecScopeCarrierLoops(ModuleOp module); +LogicalResult normalizePtoMemRefSpaces(ModuleOp module, + llvm::raw_ostream &diagOS); +LogicalResult applyQueriedTargetAttrs(ModuleOp module, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); +LogicalResult attachAIVectorScopeMetadata(llvm::Module &llvmModule, + llvm::raw_ostream &diagOS); +void attachHIVMKernelAnnotations(llvm::Module &llvmModule); + +namespace { + +static std::string getElementTypeFragment(Type type); +static Type getElementTypeFromVectorLike(Type type); +static std::optional getElementCountFromVectorLike(Type type); + +static Type convertVPTOType(Type type, 
Builder &builder) { + if (auto vecType = dyn_cast(type)) + return VectorType::get({vecType.getElementCount()}, vecType.getElementType()); + if (isa(type)) + return VectorType::get({256}, builder.getI1Type()); + if (isa(type)) + return VectorType::get({32}, builder.getI8Type()); + if (auto ptrType = dyn_cast(type)) { + return LLVM::LLVMPointerType::get( + builder.getContext(), + static_cast(ptrType.getMemorySpace().getAddressSpace())); + } + return type; +} + +static bool hasVPTOConvertibleType(Type type) { + return isa(type); +} + +static bool hasVPTOConvertibleType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return hasVPTOConvertibleType(type); }); +} + +static Value materializeVPTOCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); +} + +class VPTOTypeConverter final : public TypeConverter { +public: + explicit VPTOTypeConverter(MLIRContext *context) { + addConversion([](Type type) { return type; }); + addConversion([](Type type) -> Type { + // The conversion callback outlives this constructor, so build on demand + // from the current type context instead of capturing a local Builder. 
+ Builder builder(type.getContext()); + return convertVPTOType(type, builder); + }); + addSourceMaterialization(materializeVPTOCast); + addTargetMaterialization(materializeVPTOCast); + addArgumentMaterialization(materializeVPTOCast); + } +}; + +struct PlannedDecl { + std::string name; + FunctionType type; +}; + +struct LoweringState { + SmallVector plannedDecls; +}; + +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, +}; + +struct VcvtContract { + const char *intrinsic; + bool requiresRnd; + bool requiresSat; + bool requiresPart; + unsigned maskBitWidth; +}; + +static Value getI64Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI64IntegerAttr(value)) + .getResult(); +} + +static Value getI32Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI32IntegerAttr(value)) + .getResult(); +} + +static FailureOr buildLaneTypedCallee(MLIRContext *context, + Type resultType, + StringRef stem, + StringRef suffix) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec + + suffix.str()) + .getValue(); +} + +static std::string getElementTypeFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return {}; +} + +static std::string getVbrScalarFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? 
"u" : "s") + std::to_string(intType.getWidth()); + return {}; +} + +static Type getElementTypeFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + return {}; +} + +static std::optional getElementCountFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementCount(); + if (auto vecType = dyn_cast(type)) { + if (vecType.getRank() != 1) + return std::nullopt; + return vecType.getShape().front(); + } + return std::nullopt; +} + +static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + if (value.getType() == targetType) + return value; + + auto targetInt = dyn_cast(targetType); + if (value.getType().isIndex() && targetInt) + return builder.create(anchor->getLoc(), targetType, value); + if (auto sourceInt = dyn_cast(value.getType())) { + if (targetInt) { + if (sourceInt.getWidth() < targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + if (sourceInt.getWidth() > targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + return value; + } + if (targetType.isIndex()) + return builder.create(anchor->getLoc(), targetType, value); + } + + return {}; +} + +static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location loc, + pto::VdupOp op) { + Value input = op.getInput(); + auto intType = dyn_cast(input.getType()); + if (!intType || intType.getWidth() != 8) + return input; + + Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); + std::string resultElemFragment = getElementTypeFragment(resultElemType); + if (resultElemFragment != "s8" && resultElemFragment != "u8") + return input; + + Type i16Type = builder.getIntegerType(16); + if (resultElemFragment == "u8") + return builder.create(loc, i16Type, input).getResult(); + return 
builder.create(loc, i16Type, input).getResult(); +} + +static std::string getCopyElementFragment(Type elementType) { + if (!elementType) + return {}; + if (elementType.isF16()) + return "f16"; + if (elementType.isBF16()) + return "bf16"; + if (elementType.isF32()) + return "f32"; + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? "u8" : "s8"; + case 16: + return intType.isUnsigned() ? "u16" : "s16"; + case 32: + return intType.isUnsigned() ? "u32" : "s32"; + default: + return {}; + } + } + return {}; +} + +static std::optional parsePredicatePatternImmediate(StringRef pattern) { + if (pattern == "PAT_ALL") + return 0; + if (pattern == "PAT_VL1") + return 1; + if (pattern == "PAT_VL2") + return 2; + if (pattern == "PAT_VL3") + return 3; + if (pattern == "PAT_VL4") + return 4; + if (pattern == "PAT_VL8") + return 5; + if (pattern == "PAT_VL16") + return 6; + if (pattern == "PAT_VL32") + return 7; + if (pattern == "PAT_VL64") + return 8; + if (pattern == "PAT_VL128") + return 9; + if (pattern == "PAT_M3") + return 10; + if (pattern == "PAT_M4") + return 11; + if (pattern == "PAT_H") + return 12; + if (pattern == "PAT_Q") + return 13; + if (pattern == "PAT_ALLF") + return 15; + return std::nullopt; +} + +static std::optional parseHiLoPartImmediate(StringRef part) { + if (part == "LOWER") + return 0; + if (part == "HIGHER") + return 1; + return std::nullopt; +} + +static std::optional parseRoundModeImmediate(StringRef roundMode) { + if (roundMode == "R" || roundMode == "ROUND_R") + return 0; + if (roundMode == "A" || roundMode == "ROUND_A") + return 1; + if (roundMode == "F" || roundMode == "ROUND_F") + return 2; + if (roundMode == "C" || roundMode == "ROUND_C") + return 3; + if (roundMode == "Z" || roundMode == "ROUND_Z") + return 4; + if (roundMode == "O" || roundMode == "ROUND_O") + return 5; + return std::nullopt; +} + +static std::optional parseSaturationImmediate(StringRef sat) { + if (sat == 
"SAT" || sat == "RS_ENABLE") + return 0; + if (sat == "NOSAT" || sat == "RS_DISABLE") + return 1; + return std::nullopt; +} + +static std::optional parsePartImmediate(StringRef part) { + if (part == "EVEN" || part == "PART_EVEN") + return 0; + if (part == "ODD" || part == "PART_ODD") + return 1; + return std::nullopt; +} + +static std::optional parsePredicateStoreDistImmediate(StringRef dist) { + if (dist == "NORM") + return 0; + if (dist == "PK") + return 1; + return std::nullopt; +} + +static std::optional parsePredicateLoadDistImmediate(StringRef dist) { + if (dist.empty() || dist == "NORM") + return 0; + if (dist == "US") + return 1; + if (dist == "DS") + return 2; + return std::nullopt; +} + +static std::optional parsePostModeImmediate(StringRef mode) { + if (mode == "NO_POST_UPDATE") + return 0; + if (mode == "POST_UPDATE") + return 1; + return std::nullopt; +} + +static std::optional parsePipeImmediate(StringRef pipe) { + if (pipe == "PIPE_S") + return 0; + if (pipe == "PIPE_V") + return 1; + if (pipe == "PIPE_M") + return 2; + if (pipe == "PIPE_MTE1") + return 3; + if (pipe == "PIPE_MTE2") + return 4; + if (pipe == "PIPE_MTE3") + return 5; + if (pipe == "PIPE_ALL") + return 6; + if (pipe == "PIPE_MTE4") + return 7; + if (pipe == "PIPE_MTE5") + return 8; + if (pipe == "PIPE_V2") + return 9; + if (pipe == "PIPE_FIX") + return 10; + if (pipe == "VIRTUAL_PIPE_MTE2_L1A") + return 11; + if (pipe == "VIRTUAL_PIPE_MTE2_L1B") + return 12; + return std::nullopt; +} + +static std::optional parseEventImmediate(StringRef event) { + if (!event.consume_front("EVENT_ID")) + return std::nullopt; + uint64_t value = 0; + if (event.getAsInteger(10, value)) + return std::nullopt; + return value; +} + +static std::optional parseSprImmediate(StringRef spr) { + if (spr == "AR") + return 74; + return std::nullopt; +} + +static std::optional getDistElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (type.isF16() || type.isBF16()) + 
return 16; + if (type.isF32()) + return 32; + if (type.isF64()) + return 64; + return std::nullopt; +} + +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; +} + +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtff.f322f16.x", true, true, true, 32}; + case VcvtElemKind::BF16: + return VcvtContract{"llvm.hivm.vcvtff.f322bf16.x", true, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f322s16.x", true, true, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f322s32.x", true, true, false, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtfi.f322s64.x", true, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.f162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f162s32.x", true, false, true, 16}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f162s16.x", true, true, false, 16}; + case VcvtElemKind::S8: + return VcvtContract{"llvm.hivm.vcvtfi.f162s8.x", true, true, true, 16}; + case VcvtElemKind::U8: + return 
VcvtContract{"llvm.hivm.vcvtfi.f162u8.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.bf162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.bf162s32.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.u82f16.x", false, false, true, 8}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u82u16.x", false, false, true, 8}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u82u32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s82f16.x", false, false, true, 8}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s82s16.x", false, false, true, 8}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s82s32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u162u32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s162f16.x", true, false, false, 16}; + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s162f32.x", false, false, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.s162u32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s162s32.x", false, false, true, 16}; + 
default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.u322s16.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s322f32.x", true, false, false, 32}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.s322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s322s16.x", false, true, true, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtii.s322s64.x", false, false, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s642f32.x", true, false, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s642s32.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +// VSQZ #st hint must only be set when the compacted vector feeds VSTUR. +// Emitting #st=1 without a matching VSTUR consumer can deadlock hardware queues. 
+static uint64_t determineVsqzStoreHint(pto::VsqzOp vsqz) { + Value result = vsqz.getResult(); + for (Operation *user : result.getUsers()) { + auto vstur = dyn_cast(user); + if (!vstur) + continue; + if (vstur.getValue() == result) + return 1; + } + return 0; +} + +static std::optional parseLoadDistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist.empty() || dist == "NORM") + return 0; + if (!width) + return std::nullopt; + if (dist == "BRC") + return *width == 8 ? std::optional(1) + : *width == 16 ? std::optional(2) + : *width == 32 ? std::optional(3) + : std::nullopt; + if (dist == "US") + return *width == 8 ? std::optional(6) + : *width == 16 ? std::optional(7) + : std::nullopt; + if (dist == "DS") + return *width == 8 ? std::optional(8) + : *width == 16 ? std::optional(9) + : std::nullopt; + if (dist == "UNPK") + return *width == 8 ? std::optional(13) + : *width == 16 ? std::optional(14) + : *width == 32 ? std::optional(18) + : std::nullopt; + if (dist == "BRC_BLK") + return 15; + if (dist == "E2B") + return *width == 16 ? std::optional(16) + : *width == 32 ? std::optional(17) + : std::nullopt; + if (dist == "UNPK4") + return *width == 8 ? std::optional(20) : std::nullopt; + if (dist == "SPLT4CHN") + return *width == 8 ? std::optional(21) : std::nullopt; + if (dist == "SPLT2CHN") + return *width == 8 ? std::optional(22) + : *width == 16 ? std::optional(23) + : std::nullopt; + return std::nullopt; +} + +static std::optional parseLoadX2DistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist == "BDINTLV") + return 10; + if (!width) + return std::nullopt; + if (dist == "DINTLV") + return *width == 8 ? std::optional(11) + : *width == 16 ? std::optional(12) + : *width == 32 ? 
std::optional(19) + : std::nullopt; + return std::nullopt; +} + +static std::optional parseStoreDistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist.empty() || dist == "NORM") { + if (!width) + return std::nullopt; + if (*width == 8) + return 0; + if (*width == 16) + return 1; + if (*width == 32) + return 2; + return std::nullopt; + } + if (!width) + return std::nullopt; + if (dist == "1PT") + return *width == 8 ? std::optional(3) + : *width == 16 ? std::optional(4) + : *width == 32 ? std::optional(5) + : std::nullopt; + if (dist == "PK") + return *width == 16 ? std::optional(6) + : *width == 32 ? std::optional(7) + : *width == 64 ? std::optional(10) + : std::nullopt; + if (dist == "PK4") + return *width == 32 ? std::optional(12) : std::nullopt; + if (dist == "MRG4CHN") + return *width == 8 ? std::optional(13) : std::nullopt; + if (dist == "MRG2CHN") + return *width == 8 ? std::optional(14) + : *width == 16 ? std::optional(15) + : std::nullopt; + return std::nullopt; +} + +static std::optional parseStoreX2DistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return std::nullopt; + if (dist == "INTLV") + return *width == 8 ? std::optional(8) + : *width == 16 ? std::optional(9) + : *width == 32 ? 
std::optional(11) + : std::nullopt; + return std::nullopt; +} + +static Value packBlockRepeatStride(Operation *anchor, Value blockStride, + Value repeatStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value blockI32 = castIntegerLikeTo(anchor, blockStride, builder.getI32Type()); + Value repeatI32 = + castIntegerLikeTo(anchor, repeatStride, builder.getI32Type()); + if (!blockI32 || !repeatI32) + return {}; + + auto c16 = builder.create(anchor->getLoc(), 16, 32); + auto blockShifted = + builder.create(anchor->getLoc(), blockI32, c16); + return builder + .create(anchor->getLoc(), blockShifted, repeatI32) + .getResult(); +} + +static std::optional parseOrderImmediate(StringRef order) { + if (order.empty() || order == "ASC") + return 0; + if (order == "DESC") + return 1; + return std::nullopt; +} + +static FailureOr packLoopPair(Operation *anchor, Value low, Value high) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value lowI64 = castIntegerLikeTo(anchor, low, builder.getI64Type()); + Value highI64 = castIntegerLikeTo(anchor, high, builder.getI64Type()); + if (!lowI64 || !highI64) + return failure(); + + Value shift = getI64Constant(builder, anchor->getLoc(), 40); + Value highShifted = + builder.create(anchor->getLoc(), highI64, shift).getResult(); + return builder.create(anchor->getLoc(), highShifted, lowI64) + .getResult(); +} + +static FailureOr packLoopSize(Operation *anchor, Value loop2, Value loop1) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value loop2I64 = castIntegerLikeTo(anchor, loop2, builder.getI64Type()); + Value loop1I64 = castIntegerLikeTo(anchor, loop1, builder.getI64Type()); + if (!loop2I64 || !loop1I64) + return failure(); + + Value shift = getI64Constant(builder, anchor->getLoc(), 21); + Value loop2Shifted = + builder.create(anchor->getLoc(), loop2I64, shift).getResult(); + return builder.create(anchor->getLoc(), loop2Shifted, loop1I64) + .getResult(); +} + +static 
FailureOr<Value>
packCopyGmToUbConfig0(Operation *anchor, ValueRange operands) {
  // Packs the GM->UB copy config word 0 from operands[2..8]:
  //   sid[3:0] | nBurst<<4 | lenBurst<<25 | leftPad<<46 | rightPad<<52 |
  //   dataSelect<<58 | cacheCtl<<60.
  // NOTE(review): the builder.create<...> op kinds were lost in extraction
  // and are reconstructed as arith dialect ops — confirm against the original.
  if (operands.size() != 11)
    return failure();

  OpBuilder builder(anchor);
  builder.setInsertionPoint(anchor);
  Location loc = anchor->getLoc();

  auto getI64Operand = [&](unsigned idx) -> Value {
    return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type());
  };

  Value sid = getI64Operand(2);
  Value nBurst = getI64Operand(3);
  Value lenBurst = getI64Operand(4);
  Value leftPadding = getI64Operand(5);
  Value rightPadding = getI64Operand(6);
  Value dataSelect =
      castIntegerLikeTo(anchor, operands[7], builder.getI64Type());
  Value cacheCtl = getI64Operand(8);
  if (!sid || !nBurst || !lenBurst || !leftPadding || !rightPadding ||
      !dataSelect || !cacheCtl)
    return failure();

  auto shl = [&](Value value, uint64_t amount) -> Value {
    return builder.create<arith::ShLIOp>(
        loc, value, getI64Constant(builder, loc, amount));
  };
  auto bitOr = [&](Value lhs, Value rhs) -> Value {
    return builder.create<arith::OrIOp>(loc, lhs, rhs);
  };

  Value config = sid;
  config = bitOr(config, shl(nBurst, 4));
  config = bitOr(config, shl(lenBurst, 25));
  config = bitOr(config, shl(leftPadding, 46));
  config = bitOr(config, shl(rightPadding, 52));
  config = bitOr(config, shl(dataSelect, 58));
  config = bitOr(config, shl(cacheCtl, 60));
  return config;
}

// GM->UB copy config word 1: the loop pair packed from operands[9..10].
static FailureOr<Value>
packCopyGmToUbConfig1(Operation *anchor, ValueRange operands) {
  if (operands.size() != 11)
    return failure();
  return packLoopPair(anchor, operands[9], operands[10]);
}

// Packs the UB->GM copy config word 0 from operands[2..5]:
//   sid[3:0] | nBurst<<4 | lenBurst<<25 | reserved<<60.
// NOTE(review): op kinds reconstructed as arith ops — confirm.
static FailureOr<Value>
packCopyUbToGmConfig0(Operation *anchor, ValueRange operands) {
  if (operands.size() != 8)
    return failure();

  OpBuilder builder(anchor);
  builder.setInsertionPoint(anchor);
  Location loc = anchor->getLoc();

  auto getI64Operand = [&](unsigned idx) -> Value {
    return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type());
  };

  Value sid = getI64Operand(2);
  Value nBurst = getI64Operand(3);
  Value lenBurst = getI64Operand(4);
  Value reserved = getI64Operand(5);
  if (!sid || !nBurst || !lenBurst || !reserved)
    return failure();

  auto shl = [&](Value value, uint64_t amount) -> Value {
    return builder.create<arith::ShLIOp>(
        loc, value, getI64Constant(builder, loc, amount));
  };
  auto bitOr = [&](Value lhs, Value rhs) -> Value {
    return builder.create<arith::OrIOp>(loc, lhs, rhs);
  };

  Value config = sid;
  config = bitOr(config, shl(nBurst, 4));
  config = bitOr(config, shl(lenBurst, 25));
  config = bitOr(config, shl(reserved, 60));
  return config;
}

// UB->GM copy config word 1: the loop pair packed from operands[6..7].
static FailureOr<Value>
packCopyUbToGmConfig1(Operation *anchor, ValueRange operands) {
  if (operands.size() != 8)
    return failure();
  return packLoopPair(anchor, operands[6], operands[7]);
}

// VBITSORT config word: repeatTimes placed in the top byte (<< 56).
// NOTE(review): op kind reconstructed as arith::ShLIOp — confirm.
static FailureOr<Value> packVbitsortConfig(Operation *anchor,
                                           Value repeatTimes) {
  OpBuilder builder(anchor);
  builder.setInsertionPoint(anchor);
  Location loc = anchor->getLoc();

  Value repeatI64 =
      castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type());
  if (!repeatI64)
    return failure();
  return builder
      .create<arith::ShLIOp>(loc, repeatI64, getI64Constant(builder, loc, 56))
      .getResult();
}

// Scales an element-count offset to a byte offset (offset * elemBytes).
// Fails for element widths that are zero or not byte-multiples.
// NOTE(review): op kinds reconstructed as arith ops — confirm.
static FailureOr<Value> convertElementOffsetToBytes(Operation *anchor,
                                                    Value offset,
                                                    Type elementType) {
  OpBuilder builder(anchor);
  builder.setInsertionPoint(anchor);

  Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type());
  if (!offsetI32)
    return failure();

  unsigned bitWidth = 0;
  if (auto intType = dyn_cast<IntegerType>(elementType))
    bitWidth = intType.getWidth();
  else if (auto floatType = dyn_cast<FloatType>(elementType))
    bitWidth = floatType.getWidth();
  if (bitWidth == 0 || bitWidth % 8 != 0)
    return failure();

  Value scale = builder.create<arith::ConstantOp>(
      anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8));
  return builder.create<arith::MulIOp>(anchor->getLoc(), offsetI32, scale)
      .getResult();
}

static FailureOr<Value>
materializeDynamicPltMask(ConversionPatternRewriter &rewriter,
                          LoweringState &state, Location loc, Value laneCount,
                          Type
vectorElemType) { + Type i32Type = rewriter.getI32Type(); + Value laneCountI32 = laneCount; + if (laneCountI32.getType() != i32Type) { + laneCountI32 = castIntegerLikeTo(rewriter.getInsertionBlock()->getParentOp(), + laneCountI32, i32Type); + if (!laneCountI32) + return failure(); + } + + StringRef calleeName; + if (vectorElemType.isF32()) { + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + } else if (auto intType = dyn_cast(vectorElemType)) { + if (intType.getWidth() == 32) + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + else if (intType.getWidth() == 16) + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + else if (intType.getWidth() == 8) + calleeName = StringRef("llvm.hivm.plt.b8.v300"); + } + if (calleeName.empty()) + return failure(); + + Type maskType = VectorType::get({256}, rewriter.getI1Type()); + auto funcType = + rewriter.getFunctionType(TypeRange{i32Type}, TypeRange{maskType, i32Type}); + auto call = rewriter.create(loc, calleeName, funcType.getResults(), + ValueRange{laneCountI32}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + return call.getResult(0); +} + +static FailureOr buildCarryBinaryCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." 
+ stem.str() + ".v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +template +static StringRef getUnaryMaskedStem() { + if constexpr (std::is_same_v) + return "vabs"; + if constexpr (std::is_same_v) + return "vexp"; + if constexpr (std::is_same_v) + return "vln"; + if constexpr (std::is_same_v) + return "vneg"; + if constexpr (std::is_same_v) + return "vsqrt"; + if constexpr (std::is_same_v) + return "vrelu"; + if constexpr (std::is_same_v) + return "vnot"; + return {}; +} + +template +static StringRef getBinaryMaskedStem() { + if constexpr (std::is_same_v) + return "vadd"; + if constexpr (std::is_same_v) + return "vsub"; + if constexpr (std::is_same_v) + return "vmul"; + if constexpr (std::is_same_v) + return "vdiv"; + if constexpr (std::is_same_v) + return "vmax"; + if constexpr (std::is_same_v) + return "vmin"; + if constexpr (std::is_same_v) + return "vand"; + if constexpr (std::is_same_v) + return "vor"; + if constexpr (std::is_same_v) + return "vxor"; + if constexpr (std::is_same_v) + return "vshl"; + if constexpr (std::is_same_v) + return "vshr"; + return {}; +} + +template +static StringRef getCarryBinaryStem() { + if constexpr (std::is_same_v) + return "vaddc"; + if constexpr (std::is_same_v) + return "vsubc"; + if constexpr (std::is_same_v) + return "vaddcs"; + if constexpr (std::is_same_v) + return "vsubcs"; + return {}; +} + +template +static constexpr bool hasCarryInput() { + return std::is_same_v || + std::is_same_v; +} + +static FailureOr buildVselCallee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vsel.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVselrCallee(MLIRContext *context, + Type resultType) { + Type elemType = getElementTypeFromVectorLike(resultType); + auto lanes = 
getElementCountFromVectorLike(resultType); + if (!elemType || !lanes) + return failure(); + + std::string vec = getElementTypeFragment(elemType); + if (auto floatType = dyn_cast(elemType); + floatType && floatType.isF32()) + vec = "u32"; + if (vec.empty()) + return failure(); + + return StringAttr::get(context, "llvm.hivm.vselr.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVdupCallee(MLIRContext *context, pto::VdupOp op) { + Type inputType = op.getInput().getType(); + Type resultType = op.getResult().getType(); + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + + if (isa(inputType)) { + StringRef position = op.getPosition().value_or("LOWEST"); + StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; + return StringAttr::get(context, "llvm.hivm." + family.str() + ".v" + + std::to_string(*lanes) + vec + ".z") + .getValue(); + } + + return StringAttr::get(context, "llvm.hivm.vdups.v" + std::to_string(*lanes) + + vec + ".z") + .getValue(); +} + +static FailureOr buildVbrCallee(MLIRContext *context, Type scalarType) { + std::string scalar = getVbrScalarFragment(scalarType); + if (scalar.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.vbr." 
+ scalar + ".v300").getValue(); +} + +static FailureOr buildPstuCallee(MLIRContext *context, pto::PstuOp op) { + if (auto maskType = dyn_cast(op.getValue().getType())) { + if (maskType.isB16()) + return StringAttr::get(context, "llvm.hivm.pstu.b16").getValue(); + if (maskType.isB32()) + return StringAttr::get(context, "llvm.hivm.pstu.b32").getValue(); + } + return failure(); +} + +static StringRef buildVstusCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstus").getValue(); +} + +static StringRef buildVsturCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstur").getValue(); +} + +static StringRef buildInitAlignCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.init.vector.align.data").getValue(); +} + +static StringRef buildSprclrCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.sprclr").getValue(); +} + +static StringRef buildVstarCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstar").getValue(); +} + +static StringRef buildVstasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstas").getValue(); +} + +static FailureOr buildVldsPostCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.post.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVstsPostCallee(MLIRContext *context, + Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.post.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static StringRef buildVldasCallee(MLIRContext 
*context) { + return StringAttr::get(context, "llvm.hivm.vldas").getValue(); +} + +static FailureOr buildVldusCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldus.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVcmpCallee(MLIRContext *context, Type inputType, + StringRef cmpMode, + bool isScalarCompare) { + std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + if (elem.empty()) + return failure(); + StringRef stem = isScalarCompare ? "vcmps" : "vcmp"; + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + cmpMode.str() + "." + elem + ".z") + .getValue(); +} + +template +static StringRef getVecScalarMaskedStem() { + if constexpr (std::is_same_v) + return "vmuls"; + if constexpr (std::is_same_v) + return "vadds"; + if constexpr (std::is_same_v) + return "vmaxs"; + if constexpr (std::is_same_v) + return "vmins"; + if constexpr (std::is_same_v) + return "vlrelu"; + if constexpr (std::is_same_v) + return "vshls"; + if constexpr (std::is_same_v) + return "vshrs"; + return {}; +} + +template +static StringRef getReductionUnaryStem() { + if constexpr (std::is_same_v) + return "vcadd"; + if constexpr (std::is_same_v) + return "vcmax"; + if constexpr (std::is_same_v) + return "vcmin"; + if constexpr (std::is_same_v) + return "vcgadd"; + if constexpr (std::is_same_v) + return "vcgmax"; + if constexpr (std::is_same_v) + return "vcgmin"; + if constexpr (std::is_same_v) + return "vcpadd"; + return {}; +} + +static FailureOr buildCopyGmToUbCallee(MLIRContext *context, + pto::CopyGmToUbufOp op) { + Type elementType = cast(op.getSource().getType()).getElementType(); + std::string elem = getCopyElementFragment(elementType); + if (elem.empty()) + return failure(); + 
return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + + ".DV") + .getValue(); +} + +static StringRef buildCopyUbToGmCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV") + .getValue(); +} + +static StringRef buildPstiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psti.b8").getValue(); +} + +static StringRef buildPstsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psts.b8").getValue(); +} + +static StringRef buildPldiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pldi.b8").getValue(); +} + +static StringRef buildPldsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plds.b8").getValue(); +} + +static StringRef buildPnotCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pnot.z").getValue(); +} + +static StringRef buildPselCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psel").getValue(); +} + +static StringRef buildPandCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pand.z").getValue(); +} + +static StringRef buildPorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.por.z").getValue(); +} + +static StringRef buildPxorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pxor.z").getValue(); +} + +static StringRef buildPpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.ppack.z").getValue(); +} + +static StringRef buildPunpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.punpack").getValue(); +} + +template +static StringRef buildPredicatePairReorderCallee(MLIRContext *context); + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b8").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext 
*context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b16").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b32").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b8").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b16").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b32").getValue(); +} + +static FailureOr buildInterleaveCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + return buildLaneTypedCallee(context, resultType, stem, ""); +} + +static FailureOr buildUnpackCallee(MLIRContext *context, + Type inputType, + Type resultType, + StringRef stem) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + return StringAttr::get(context, + "llvm.hivm." + stem.str() + "." + input + "2" + result) + .getValue(); +} + +static FailureOr buildVpackCallee(MLIRContext *context, Type inputType, + Type resultType) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + + return StringAttr::get(context, "llvm.hivm.vpack." 
+ input + "2" + result + ".x") + .getValue(); +} + +static FailureOr buildVsqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vsqz", ".x.v300"); +} + +static FailureOr buildVusqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vusqz", ".m"); +} + +static FailureOr buildVmulaCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmula", ".m"); +} + +static FailureOr buildVmullCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmull", ""); +} + +template +static StringRef getPredicateStoreCallee(MLIRContext *context); + +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstiCallee(context); +} + +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstsCallee(context); +} + +template +static StringRef getPredicateLoadCallee(MLIRContext *context); + +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldiCallee(context); +} + +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldsCallee(context); +} + +template +static StringRef getPredicateMaskCallee(MLIRContext *context); + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPnotCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPselCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPandCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPorCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPxorCallee(context); +} + +template +static StringRef getPredicatePackCallee(MLIRContext *context); + +template <> +StringRef 
getPredicatePackCallee(MLIRContext *context) { + return buildPpackCallee(context); +} + +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPunpackCallee(context); +} + +template +static StringRef buildPltCallee(MLIRContext *context); + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b8.v300").getValue(); +} + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b16.v300").getValue(); +} + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b32.v300").getValue(); +} + +template +static StringRef buildPsetCallee(MLIRContext *context); + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b8").getValue(); +} + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b16").getValue(); +} + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b32").getValue(); +} + +template +static StringRef buildPgeCallee(MLIRContext *context); + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b8").getValue(); +} + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b16").getValue(); +} + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b32").getValue(); +} + +static FailureOr buildVldsCallee(MLIRContext *context, Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.v" + std::to_string(*lanes) + + vec) + 
.getValue(); +} + +static FailureOr buildVldsx2Callee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vldsx2", ""); +} + +static StringRef buildVsldbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsldb").getValue(); +} + +static FailureOr buildVstsCallee(MLIRContext *context, Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVstsx2Callee(MLIRContext *context, Type valueType) { + return buildLaneTypedCallee(context, valueType, "vstsx2", ""); +} + +static StringRef buildVsstbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsstb").getValue(); +} + +static FailureOr buildVgather2Callee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vgather2.v300.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVgather2BcCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); +} + +static FailureOr buildVgatherbCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgatherb.v310", ""); +} + +static FailureOr buildVscatterCallee(MLIRContext *context, + Type valueType) { + return buildLaneTypedCallee(context, valueType, "vscatter", ".v300"); +} + +static FailureOr buildVpreluCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vprelu", ".x"); +} + +static FailureOr 
buildVaxpyCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vaxpy", ".m"); +} + +static FailureOr buildVciCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + if (vec == "f16" || vec == "f32") + return StringAttr::get(context, "llvm.hivm.vci.v" + std::to_string(*lanes) + + vec + "." + vec) + .getValue(); + return StringAttr::get(context, + "llvm.hivm.vci.v" + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVtrcCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vtrc." + vec + ".x").getValue(); +} + +static FailureOr buildVexpdiffCallee(MLIRContext *context, + Type inputType, + Type resultType) { + std::string srcVec = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + auto srcLanes = getElementCountFromVectorLike(inputType); + std::string dstElem = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (srcVec.empty() || dstElem.empty() || !srcLanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vexpdif.v" + + std::to_string(*srcLanes) + srcVec + + dstElem) + .getValue(); +} + +static FailureOr buildVbitsortCallee(MLIRContext *context, + pto::VbitsortOp op) { + Type sourceElemType = cast(op.getSource().getType()).getElementType(); + if (sourceElemType.isF16()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f16").getValue(); + if (sourceElemType.isF32()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f32").getValue(); + return failure(); +} + +static FailureOr buildVcvtContract(pto::VcvtOp op) { + 
Type inputElemType = getElementTypeFromVectorLike(op.getInput().getType()); + Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!inputElemType || !resultElemType) + return failure(); + auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), + classifyVcvtElemType(resultElemType)); + if (!contract) + return failure(); + return *contract; +} + +template +static StringRef buildSetLoopCallee(MLIRContext *context); + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT") + .getValue(); +} + +template +static StringRef buildSyncCallee(MLIRContext *context); + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.WAIT.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.BARRIER").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, 
"llvm.hivm.GET.BUFI.mode").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.RLS.BUFI.mode").getValue(); +} + +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.IDX").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKID").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.NUM").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKDIM").getValue(); +} + +static LogicalResult +materializeDecls(ModuleOp module, ArrayRef plannedDecls, + llvm::raw_ostream &diagOS) { + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(&module.getBodyRegion().front()); + for (const PlannedDecl &decl : plannedDecls) { + if (func::FuncOp existing = module.lookupSymbol(decl.name)) { + if (existing.getFunctionType() != decl.type) { + diagOS << "VPTO LLVM emission failed: conflicting declaration for " + << decl.name << "\n"; + return failure(); + } + continue; + } + auto func = + builder.create(module.getLoc(), decl.name, decl.type); + func.setPrivate(); + } + return success(); +} + +template +class LowerUnaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(UnaryOp op, typename UnaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getUnaryMaskedStem(); + FailureOr calleeName = + 
buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported unary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unary result type"); + + Value input = adaptor.getOperands()[0]; + Value mask = adaptor.getOperands()[1]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(1).getType()); + if (!input || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unary VPTO operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVsqzOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsqzOp op, pto::VsqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVsqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsqz types"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || 
!mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsqz operand types"); + } + + Value storeHint = + getI32Constant(rewriter, op.getLoc(), determineVsqzStoreHint(op)); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, maskType, storeHint.getType()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, mask, storeHint}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVusqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVusqzOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VusqzOp op, pto::VusqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVusqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vusqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vusqz types"); + + Value src = adaptor.getSrc(); + Value mask = adaptor.getMask(); + if (!src || !mask || src.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vusqz operand types"); + } + + auto funcType = + rewriter.getFunctionType(TypeRange{resultType, maskType}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, mask}); + 
state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmulaOpPattern final : public OpConversionPattern { +public: + explicit LowerVmulaOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmulaOp op, pto::VmulaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmulaCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmula VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vmula types"); + + Value acc = adaptor.getAcc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!acc || !lhs || !rhs || !mask || acc.getType() != resultType || + lhs.getType() != resultType || rhs.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmula operand types"); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, resultType, resultType, maskType}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{acc, lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmullOpPattern final : public OpConversionPattern { +public: + explicit LowerVmullOpPattern(TypeConverter &typeConverter, 
MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmullOp op, pto::VmullOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmullCallee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmull VPTO signature"); + + Type inputType = this->getTypeConverter()->convertType(op.getLhs().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + SmallVector resultTypes; + if (!inputType || !maskType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) { + return rewriter.notifyMatchFailure(op, "failed to convert vmull types"); + } + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure(op, "unexpected converted vmull results"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!lhs || !rhs || !mask || lhs.getType() != inputType || + rhs.getType() != inputType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmull operand types"); + } + + auto funcType = rewriter.getFunctionType(TypeRange{inputType, inputType, maskType}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBinaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerBinaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BinaryOp 
op, typename BinaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getBinaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert binary result type"); + + Value lhs = adaptor.getOperands()[0]; + Value rhs = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!lhs || !rhs || !mask || lhs.getType() != resultType || + rhs.getType() != resultType || mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted binary VPTO operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCarryBinaryOpPattern final : public OpConversionPattern { +public: + explicit LowerCarryBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CarryOp op, typename CarryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getCarryBinaryStem(); + FailureOr calleeName = + buildCarryBinaryCallee(op.getContext(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported carry VPTO signature"); + + Type resultType = + 
this->getTypeConverter()->convertType(op.getResult().getType()); + Type carryType = + this->getTypeConverter()->convertType(op->getResult(1).getType()); + if (!resultType || !carryType) + return rewriter.notifyMatchFailure(op, + "failed to convert carry result types"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + const size_t expectedArgCount = hasCarryInput() ? 4 : 3; + if (callArgs.size() != expectedArgCount || callArgs[0].getType() != resultType || + callArgs[1].getType() != resultType || callArgs.back().getType() != carryType) + return rewriter.notifyMatchFailure(op, + "unexpected converted carry operand types"); + if constexpr (hasCarryInput()) { + if (callArgs[2].getType() != carryType) + return rewriter.notifyMatchFailure( + op, "unexpected converted carry input operand type"); + } + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, carryType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCopyOpPattern final : public OpConversionPattern { +public: + explicit LowerCopyOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = failure(); + if constexpr (std::is_same_v) + calleeName = buildCopyGmToUbCallee(op.getContext(), op); + else + calleeName = buildCopyUbToGmCallee(op.getContext()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported copy VPTO signature"); + + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if 
(!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config0 = failure(); + FailureOr config1 = failure(); + if constexpr (std::is_same_v) { + config0 = packCopyGmToUbConfig0(op, adaptor.getOperands()); + config1 = packCopyGmToUbConfig1(op, adaptor.getOperands()); + } else { + config0 = packCopyUbToGmConfig0(op, adaptor.getOperands()); + config1 = packCopyUbToGmConfig1(op, adaptor.getOperands()); + } + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config0, *config1}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerVecScalarMaskedOpPattern final + : public OpConversionPattern { +public: + explicit LowerVecScalarMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(VecScalarOp op, typename VecScalarOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getVecScalarMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported vec-scalar VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert vec-scalar result type"); + + Value 
input = adaptor.getOperands()[0]; + Value scalar = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!input || !scalar || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted vec-scalar VPTO operand types"); + } + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, scalar, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerReductionUnaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerReductionUnaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReductionOp op, typename ReductionOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getReductionUnaryStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported reduction VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) { + return rewriter.notifyMatchFailure( + op, "failed to convert reduction result type"); + } + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted reduction operand types"); + } + + auto call = 
rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVselOpPattern final : public OpConversionPattern { +public: + explicit LowerVselOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VselOp op, pto::VselOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVselCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsel VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsel result type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + Value mask = adaptor.getMask(); + if (!src0 || !src1 || !mask || src0.getType() != resultType || + src1.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsel operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{src0, src1, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVdupOpPattern final : public OpConversionPattern { +public: + explicit LowerVdupOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : 
OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VdupOp op, pto::VdupOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildVdupCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vdup VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vdup result type"); + + Value mask = adaptor.getMask(); + if (!mask || mask.getType() != maskType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vdup mask type"); + + SmallVector callArgs; + bool vectorInput = isa(op.getInput().getType()); + if (vectorInput) { + Value input = adaptor.getInput(); + if (!input || input.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "vector-input vdup requires matching result type"); + } + callArgs.push_back(input); + } else { + Type scalarType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!scalarType || op.getInput().getType() != scalarType) { + return rewriter.notifyMatchFailure(op, + "unexpected scalar-input vdup type"); + } + FailureOr normalizedScalar = + normalizeVdupScalarOperand(rewriter, op.getLoc(), op); + if (failed(normalizedScalar)) + return rewriter.notifyMatchFailure(op, + "failed to normalize scalar vdup input"); + callArgs.push_back(*normalizedScalar); + } + + callArgs.push_back(mask); + callArgs.push_back(getI32Constant(rewriter, op.getLoc(), 1)); + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class 
LowerVbrOpPattern final : public OpConversionPattern { +public: + explicit LowerVbrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VbrOp op, pto::VbrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVbrCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vbr VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vbr result type"); + + Value scalar = adaptor.getValue(); + if (!scalar || scalar.getType() != op.getValue().getType()) + return rewriter.notifyMatchFailure(op, + "unexpected converted vbr operand type"); + + auto funcType = rewriter.getFunctionType(TypeRange{scalar.getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{scalar}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVselrOpPattern final : public OpConversionPattern { +public: + explicit LowerVselrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VselrOp op, pto::VselrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVselrCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vselr VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + auto 
resultVectorType = dyn_cast(resultType); + if (!resultVectorType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr result type"); + + Type intrinsicResultType = resultType; + if (auto floatType = dyn_cast(resultVectorType.getElementType()); + floatType && floatType.isF32()) { + intrinsicResultType = VectorType::get( + resultVectorType.getShape(), rewriter.getI32Type(), + resultVectorType.getScalableDims()); + } + + Type indexType = this->getTypeConverter()->convertType(op.getSrc1().getType()); + if (!indexType) + return rewriter.notifyMatchFailure(op, + "failed to convert vselr index type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + if (!src0 || !src1 || src1.getType() != indexType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr operand types"); + + if (src0.getType() != intrinsicResultType) { + if (src0.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr source type"); + src0 = rewriter.create(op.getLoc(), intrinsicResultType, src0); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{intrinsicResultType, indexType}, TypeRange{intrinsicResultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{intrinsicResultType}, + ValueRange{src0, src1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + + Value result = call.getResult(0); + if (intrinsicResultType != resultType) + result = rewriter.create(op.getLoc(), resultType, result); + rewriter.replaceOp(op, result); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerPnotOpPattern final : public OpConversionPattern { +public: + explicit LowerPnotOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PnotOp op, pto::PnotOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const 
override { + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert pnot result type"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted pnot operand types"); + } + + StringRef calleeName = getPredicateMaskCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerInterleaveOpPattern final + : public OpConversionPattern { +public: + explicit LowerInterleaveOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(InterleaveOp op, typename InterleaveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = std::is_same_v ? 
"vintlv" : "vdintlv"; + FailureOr calleeName = + buildInterleaveCallee(op.getContext(), op.getLow().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported interleave VPTO signature"); + + Type lowType = this->getTypeConverter()->convertType(op.getLow().getType()); + Type highType = this->getTypeConverter()->convertType(op.getHigh().getType()); + if (!lowType || !highType || lowType != highType) { + return rewriter.notifyMatchFailure( + op, "failed to convert interleave result types"); + } + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (!lhs || !rhs || lhs.getType() != lowType || rhs.getType() != lowType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted interleave operand types"); + } + + auto funcType = rewriter.getFunctionType(TypeRange{lowType, lowType}, + TypeRange{lowType, highType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{lowType, highType}, ValueRange{lhs, rhs}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicatePackOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicatePackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PackOp op, typename PackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-pack result type"); + + auto part = parseHiLoPartImmediate(op.getPart()); + if (!part) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-pack part immediate"); + + Value input = adaptor.getInput(); + if 
(!input || input.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-pack operand type"); + + Value partValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*part)); + StringRef calleeName = getPredicatePackCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, rewriter.getI32Type()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{resultType}, ValueRange{input, partValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerUnpackOpPattern final : public OpConversionPattern { +public: + explicit LowerUnpackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(UnpackOp op, typename UnpackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = std::is_same_v ? 
"vsunpack" + : "vzunpack"; + FailureOr calleeName = buildUnpackCallee( + op.getContext(), op.getSrc().getType(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported unpack VPTO signature"); + + Type srcType = this->getTypeConverter()->convertType(op.getSrc().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!srcType || !resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unpack types"); + + Value src = adaptor.getSrc(); + if (!src || src.getType() != srcType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unpack source type"); + } + + Value part = castIntegerLikeTo(op, adaptor.getPart(), rewriter.getI32Type()); + if (!part) + return rewriter.notifyMatchFailure(op, "failed to materialize unpack part"); + + auto funcType = rewriter.getFunctionType(TypeRange{srcType, part.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, part}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVpackOpPattern final : public OpConversionPattern { +public: + explicit LowerVpackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VpackOp op, pto::VpackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVpackCallee(op.getContext(), op.getSrc().getType(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vpack VPTO signature"); + + Type srcType = this->getTypeConverter()->convertType(op.getSrc().getType()); + Type resultType = + 
this->getTypeConverter()->convertType(op.getResult().getType()); + if (!srcType || !resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vpack types"); + + auto partImm = parseHiLoPartImmediate(op.getPart()); + if (!partImm) + return rewriter.notifyMatchFailure(op, "unsupported vpack part immediate"); + + Value src = adaptor.getSrc(); + if (!src || src.getType() != srcType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted vpack source type"); + } + + Value part = getI32Constant(rewriter, op.getLoc(), *partImm); + auto funcType = rewriter.getFunctionType(TypeRange{srcType, part.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, part}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateMaskBinaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerPredicateMaskBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(PredicateMaskOp op, typename PredicateMaskOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-mask result type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + Value mask = adaptor.getMask(); + if (!src0 || !src1 || !mask || src0.getType() != resultType || + src1.getType() != resultType || mask.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-mask operand types"); + } + + StringRef calleeName = getPredicateMaskCallee(op.getContext()); 
+ auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{src0, src1, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicatePairReorderOpPattern final + : public OpConversionPattern { +public: + explicit LowerPredicatePairReorderOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReorderOp op, typename ReorderOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-pair-reorder result types"); + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure( + op, "unexpected predicate-pair-reorder converted result types"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (!lhs || !rhs || lhs.getType() != resultTypes[0] || + rhs.getType() != resultTypes[0]) { + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-pair-reorder operand types"); + } + + StringRef calleeName = + buildPredicatePairReorderCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, resultTypes, + ValueRange{lhs, rhs}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCmpOpPattern final : public OpConversionPattern { +public: + explicit LowerCmpOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, 
context), state(state) {} + + LogicalResult + matchAndRewrite(CmpOp op, typename CmpOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isScalarCompare = std::is_same_v; + Type inputType = Type(); + if constexpr (isScalarCompare) + inputType = op.getSrc().getType(); + else + inputType = op.getSrc0().getType(); + FailureOr calleeName = + buildVcmpCallee(op.getContext(), inputType, op.getCmpMode(), + isScalarCompare); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported compare VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = + this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, + "failed to convert compare result type"); + if (resultType != maskType) + return rewriter.notifyMatchFailure(op, + "unexpected compare mask conversion"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + if constexpr (isScalarCompare) { + if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || + callArgs[2].getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted scalar-compare operand types"); + } + } else { + if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || + callArgs[0].getType() != callArgs[1].getType() || + callArgs[2].getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted compare operand types"); + } + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPltOpPattern final : public OpConversionPattern { +public: + explicit 
LowerPltOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PltOp op, typename PltOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value laneCount = castIntegerLikeTo(op, adaptor.getScalar(), rewriter.getI32Type()); + if (!laneCount) + return rewriter.notifyMatchFailure(op, "failed to materialize plt lane count"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert plt result types"); + + StringRef calleeName = buildPltCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), calleeName, + resultTypes, ValueRange{laneCount}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPsetOpPattern final : public OpConversionPattern { +public: + explicit LowerPsetOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PsetOp op, typename PsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pattern = parsePredicatePatternImmediate(op.getPattern()); + if (!pattern) + return rewriter.notifyMatchFailure(op, "unsupported pset pattern"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pset result types"); + + StringRef calleeName = buildPsetCallee(op.getContext()); + Value patternValue = rewriter.create( + op.getLoc(), 
rewriter.getI32IntegerAttr(*pattern)); + auto funcType = rewriter.getFunctionType(TypeRange{rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), calleeName, + resultTypes, ValueRange{patternValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPgeOpPattern final : public OpConversionPattern { +public: + explicit LowerPgeOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PgeOp op, typename PgeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pattern = parsePredicatePatternImmediate(op.getPattern()); + if (!pattern) + return rewriter.notifyMatchFailure(op, "unsupported pge pattern"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pge result types"); + + StringRef calleeName = buildPgeCallee(op.getContext()); + Value patternValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*pattern)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI32Type(), rewriter.getI32Type()}, resultTypes); + auto call = + rewriter.create(op.getLoc(), calleeName, resultTypes, + ValueRange{patternValue, zero}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsOpPattern final : public OpConversionPattern { +public: + explicit LowerVldsOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : 
OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vlds element type"); + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = + parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vlds operands"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert vlds result types"); + + FailureOr calleeName = buildVldsCallee(op.getContext(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vlds signature"); + + Value distValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, zero}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, + resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsPostOpPattern final + : public OpConversionPattern { +public: + explicit LowerVldsPostOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : 
OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VldsPostOp op, pto::VldsPostOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vlds_post element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = + parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vlds_post operands"); + } + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type updatedSourceType = + this->getTypeConverter()->convertType(op.getUpdatedSource().getType()); + if (!resultType || !updatedSourceType || updatedSourceType != adaptor.getSource().getType()) { + return rewriter.notifyMatchFailure(op, + "failed to convert vlds_post result types"); + } + + FailureOr calleeName = + buildVldsPostCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vlds_post signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value postValue = getI32Constant(rewriter, op.getLoc(), 1); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, postValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), (*offsetBytes).getType(), + distValue.getType(), postValue.getType()}, + TypeRange{resultType, updatedSourceType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, updatedSourceType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + 
+private: + LoweringState &state; +}; + +class LowerVldsx2OpPattern final : public OpConversionPattern { +public: + explicit LowerVldsx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vldsx2Op op, pto::Vldsx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vldsx2 element type"); + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = parseLoadX2DistImmediate(op.getDist(), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vldsx2 operands"); + } + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 2) { + return rewriter.notifyMatchFailure(op, + "failed to convert vldsx2 result types"); + } + + FailureOr calleeName = + buildVldsx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vldsx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), (*offsetBytes).getType(), + distValue.getType(), zeroValue.getType()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, + resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + 
+private: + LoweringState &state; +}; + +class LowerVsldbOpPattern final : public OpConversionPattern { +public: + explicit LowerVsldbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsldbOp op, pto::VsldbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Value packedStride = + packBlockRepeatStride(op, adaptor.getBlockStride(), adaptor.getRepeatStride()); + if (!basePtr || !packedStride) + return rewriter.notifyMatchFailure(op, "failed to materialize vsldb operands"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vsldb result type"); + + StringRef calleeName = buildVsldbCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), packedStride, zeroValue, + adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), packedStride.getType(), + zeroValue.getType(), adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerInitAlignOpPattern final + : public OpConversionPattern { +public: + explicit LowerInitAlignOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::InitAlignOp op, pto::InitAlignOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type 
resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert init_align result type"); + + StringRef calleeName = buildInitAlignCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldasOpPattern final : public OpConversionPattern { +public: + explicit LowerVldasOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldasOp op, pto::VldasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!sourceType || !resultType) + return rewriter.notifyMatchFailure(op, + "expected converted vldas operand/result types"); + + StringRef calleeName = buildVldasCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getSource().getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{adaptor.getSource()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldusOpPattern final : public OpConversionPattern { +public: + explicit LowerVldusOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + 
matchAndRewrite(pto::VldusOp op, pto::VldusOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + SmallVector resultTypes; + if (!sourceType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes)) || + resultTypes.size() != 2 || adaptor.getAlign().getType() != resultTypes[1]) { + return rewriter.notifyMatchFailure(op, + "expected converted vldus operand/result types"); + } + + FailureOr calleeName = + buildVldusCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vldus signature"); + + SmallVector intrinsicResultTypes(resultTypes.begin(), resultTypes.end()); + // The installed no-post A5 vldus intrinsic returns an extra hidden base ptr. + intrinsicResultTypes.push_back(adaptor.getSource().getType()); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAlign().getType()}, + intrinsicResultTypes); + auto call = rewriter.create( + op.getLoc(), *calleeName, intrinsicResultTypes, + ValueRange{adaptor.getSource(), adaptor.getAlign()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults().take_front(resultTypes.size())); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerSprclrOpPattern final : public OpConversionPattern { +public: + explicit LowerSprclrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::SprclrOp op, pto::SprclrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto spr = parseSprImmediate(op.getSpr()); + if (!spr) + return rewriter.notifyMatchFailure(op, "unsupported sprclr target"); + + StringRef calleeName = buildSprclrCallee(op.getContext()); 
+ Value sprValue = rewriter.create( + op.getLoc(), rewriter.getI16IntegerAttr(*spr)); + auto funcType = rewriter.getFunctionType(TypeRange{sprValue.getType()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, ValueRange{sprValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsOpPattern final : public OpConversionPattern { +public: + explicit LowerVstsOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vsts element type"); + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getDestination().getType()); + auto dist = + parseStoreDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vsts operands"); + + FailureOr calleeName = + buildVstsCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsts signature"); + + Value distValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + *offsetBytes, distValue, zero, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type(), 
adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsstbOpPattern final : public OpConversionPattern { +public: + explicit LowerVsstbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsstbOp op, pto::VsstbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + Value packedStride = + packBlockRepeatStride(op, adaptor.getBlockStride(), adaptor.getRepeatStride()); + if (!basePtr || !packedStride) + return rewriter.notifyMatchFailure(op, "failed to materialize vsstb operands"); + + StringRef calleeName = buildVsstbCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + packedStride, zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + packedStride.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsPostOpPattern final + : public OpConversionPattern { +public: + explicit LowerVstsPostOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VstsPostOp op, pto::VstsPostOp::Adaptor adaptor, + ConversionPatternRewriter 
&rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vsts_post element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = + parseStoreDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vsts_post operands"); + } + + Type updatedDestinationType = + this->getTypeConverter()->convertType(op.getUpdatedDestination().getType()); + if (!updatedDestinationType || updatedDestinationType != adaptor.getDestination().getType()) { + return rewriter.notifyMatchFailure(op, + "failed to convert vsts_post result type"); + } + + FailureOr calleeName = + buildVstsPostCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsts_post signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value postValue = getI32Constant(rewriter, op.getLoc(), 1); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), *offsetBytes, + distValue, postValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + (*offsetBytes).getType(), distValue.getType(), postValue.getType(), + adaptor.getMask().getType()}, + TypeRange{updatedDestinationType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{updatedDestinationType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsx2OpPattern final : public OpConversionPattern { +public: + explicit 
LowerVstsx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vstsx2Op op, pto::Vstsx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vstsx2 element type"); + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = parseStoreX2DistImmediate(op.getDist(), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vstsx2 operands"); + } + + FailureOr calleeName = + buildVstsx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vstsx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getLow(), adaptor.getHigh(), + adaptor.getDestination(), *offsetBytes, distValue, + zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getLow().getType(), adaptor.getHigh().getType(), + adaptor.getDestination().getType(), (*offsetBytes).getType(), + distValue.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerPstuOpPattern final : public OpConversionPattern { +public: + explicit LowerPstuOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : 
OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PstuOp op, pto::PstuOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildPstuCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported pstu signature"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pstu result types"); + if (resultTypes.size() != 2) + return rewriter.notifyMatchFailure(op, "unexpected converted pstu result arity"); + + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!baseType || adaptor.getAlignIn().getType() != resultTypes[0] || + adaptor.getBase().getType() != resultTypes[1]) { + return rewriter.notifyMatchFailure(op, + "unexpected converted pstu operand/result types"); + } + + SmallVector args{adaptor.getValue(), adaptor.getBase(), adaptor.getAlignIn()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + adaptor.getAlignIn().getType()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstusOpPattern final : public OpConversionPattern { +public: + explicit LowerVstusOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstusOp op, pto::VstusOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported 
vstus element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + if (failed(offsetBytes)) + return rewriter.notifyMatchFailure(op, "failed to convert vstus offset"); + + Type resultType = this->getTypeConverter()->convertType(op.getAlignOut().getType()); + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!resultType || !baseType || adaptor.getAlignIn().getType() != resultType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstus operand/result types"); + } + + StringRef calleeName = buildVstusCallee(op.getContext()); + SmallVector args{adaptor.getValue(), adaptor.getBase(), *offsetBytes, + adaptor.getAlignIn()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + (*offsetBytes).getType(), adaptor.getAlignIn().getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsturOpPattern final : public OpConversionPattern { +public: + explicit LowerVsturOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsturOp op, pto::VsturOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto postMode = parsePostModeImmediate(op.getMode()); + if (!postMode) + return rewriter.notifyMatchFailure(op, "unsupported vstur mode immediate"); + + Type resultType = this->getTypeConverter()->convertType(op.getAlignOut().getType()); + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!resultType || !baseType || adaptor.getAlignIn().getType() != resultType) { + return rewriter.notifyMatchFailure(op, + "unexpected 
converted vstur operand/result types"); + } + + StringRef calleeName = buildVsturCallee(op.getContext()); + Value modeValue = getI32Constant(rewriter, op.getLoc(), *postMode); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getBase(), adaptor.getAlignIn(), + modeValue, zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + adaptor.getAlignIn().getType(), modeValue.getType(), + zeroValue.getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstarOpPattern final : public OpConversionPattern { +public: + explicit LowerVstarOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstarOp op, pto::VstarOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstar operand types"); + } + + StringRef calleeName = buildVstarCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + 
state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstasOpPattern final : public OpConversionPattern { +public: + explicit LowerVstasOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstasOp op, pto::VstasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + auto dstType = dyn_cast(op.getDestination().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType || !dstType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstas operand types"); + } + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), dstType.getElementType()); + if (failed(offsetBytes)) + return rewriter.notifyMatchFailure(op, "failed to convert vstas offset"); + + StringRef calleeName = buildVstasCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), *offsetBytes, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + (*offsetBytes).getType(), zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgather2OpPattern final + : public OpConversionPattern { +public: + explicit LowerVgather2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, 
context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vgather2Op op, pto::Vgather2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + if (!elemType || !basePtr) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgather2 operand types"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), adaptor.getActiveLanes(), elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "failed to materialize vgather2 mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vgather2 result type"); + + FailureOr calleeName = + buildVgather2Callee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + (*mask).getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgather2BcOpPattern final + : public OpConversionPattern { +public: + explicit LowerVgather2BcOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vgather2BcOp op, pto::Vgather2BcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); 
+ Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!basePtr || !resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgather2_bc operand/result types"); + + FailureOr calleeName = + buildVgather2BcCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2_bc signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgatherbOpPattern final + : public OpConversionPattern { +public: + explicit LowerVgatherbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VgatherbOp op, pto::VgatherbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!basePtr || !resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgatherb operand/result types"); + + FailureOr calleeName = + buildVgatherbCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgatherb signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call 
= rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVscatterOpPattern final + : public OpConversionPattern { +public: + explicit LowerVscatterOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VscatterOp op, pto::VscatterOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getValue().getType()); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + if (!elemType || !basePtr) + return rewriter.notifyMatchFailure(op, + "unexpected converted vscatter operand types"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), adaptor.getActiveLanes(), elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "failed to materialize vscatter mask"); + + FailureOr calleeName = + buildVscatterCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vscatter signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + adaptor.getOffsets().getType(), (*mask).getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{adaptor.getValue(), adaptor.getDestination(), + adaptor.getOffsets(), *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVpreluOpPattern final : public OpConversionPattern { +public: + explicit 
LowerVpreluOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VpreluOp op, pto::VpreluOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto laneCount = getElementCountFromVectorLike(op.getResult().getType()); + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!laneCount || !elemType) + return rewriter.notifyMatchFailure(op, "unsupported vprelu signature"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), + elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vprelu mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vprelu result type"); + + FailureOr calleeName = + buildVpreluCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vprelu callee"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getLhs().getType(), adaptor.getRhs().getType(), + (*mask).getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getLhs(), adaptor.getRhs(), *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVaxpyOpPattern final : public OpConversionPattern { +public: + explicit LowerVaxpyOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VaxpyOp op, pto::VaxpyOp::Adaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override { + auto laneCount = getElementCountFromVectorLike(op.getResult().getType()); + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!laneCount || !elemType) + return rewriter.notifyMatchFailure(op, "unsupported vaxpy signature"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), + elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vaxpy mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vaxpy result type"); + + FailureOr calleeName = + buildVaxpyCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vaxpy callee"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSrc1().getType(), adaptor.getSrc0().getType(), + adaptor.getAlpha().getType(), (*mask).getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSrc1(), adaptor.getSrc0(), adaptor.getAlpha(), + *mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVciOpPattern final : public OpConversionPattern { +public: + explicit LowerVciOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VciOp op, pto::VciOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto order = parseOrderImmediate(op.getOrder().value_or("ASC")); + if (!order) + return rewriter.notifyMatchFailure(op, "unsupported vci order"); + + Type 
resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vci result type"); + + FailureOr calleeName = + buildVciCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vci callee"); + + Value orderValue = getI32Constant(rewriter, op.getLoc(), *order); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getIndex().getType(), orderValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getIndex(), orderValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVexpdiffOpPattern final + : public OpConversionPattern { +public: + explicit LowerVexpdiffOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VexpdiffOp op, pto::VexpdiffOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto laneCount = getElementCountFromVectorLike(op.getInput().getType()); + Type elemType = getElementTypeFromVectorLike(op.getInput().getType()); + auto part = parsePartImmediate(op.getPart()); + if (!laneCount || !elemType || !part) + return rewriter.notifyMatchFailure(op, "unsupported vexpdiff signature"); + + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), getI32Constant(rewriter, op.getLoc(), *laneCount), + elemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vexpdiff mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to 
convert vexpdiff result type"); + + FailureOr calleeName = + buildVexpdiffCallee(op.getContext(), op.getInput().getType(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vexpdiff callee"); + + Value partValue = getI32Constant(rewriter, op.getLoc(), *part); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getInput().getType(), adaptor.getMax().getType(), + (*mask).getType(), partValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getInput(), adaptor.getMax(), *mask, partValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVbitsortOpPattern final + : public OpConversionPattern { +public: + explicit LowerVbitsortOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VbitsortOp op, pto::VbitsortOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = + dyn_cast(adaptor.getDestination().getType()); + auto srcType = dyn_cast(adaptor.getSource().getType()); + auto idxType = + dyn_cast(adaptor.getIndices().getType()); + if (!dstType || !srcType || !idxType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vbitsort operand types"); + + FailureOr config = packVbitsortConfig(op, adaptor.getRepeatTimes()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack vbitsort config"); + + FailureOr calleeName = buildVbitsortCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vbitsort signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getDestination().getType(), 
adaptor.getSource().getType(), + adaptor.getIndices().getType(), (*config).getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{adaptor.getDestination(), adaptor.getSource(), + adaptor.getIndices(), *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVcvtOpPattern final : public OpConversionPattern { +public: + explicit LowerVcvtOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VcvtOp op, pto::VcvtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputLanes = getElementCountFromVectorLike(op.getInput().getType()); + if (!inputLanes) + return rewriter.notifyMatchFailure(op, "unsupported vcvt input shape"); + + FailureOr contract = buildVcvtContract(op); + if (failed(contract)) + return rewriter.notifyMatchFailure(op, "unsupported vcvt type pair"); + + Type maskElemType = rewriter.getIntegerType((*contract).maskBitWidth); + FailureOr mask = materializeDynamicPltMask( + rewriter, state, op.getLoc(), + getI32Constant(rewriter, op.getLoc(), *inputLanes), maskElemType); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to materialize vcvt mask"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); + + SmallVector callArgs; + SmallVector argTypes; + callArgs.push_back(adaptor.getInput()); + argTypes.push_back(adaptor.getInput().getType()); + callArgs.push_back(*mask); + argTypes.push_back((*mask).getType()); + + if ((*contract).requiresRnd) { + auto roundMode = + op.getRndAttr() ? 
parseRoundModeImmediate(*op.getRnd()) : std::nullopt; + if (!roundMode) + return rewriter.notifyMatchFailure(op, "vcvt requires valid rnd attr"); + Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); + callArgs.push_back(roundValue); + argTypes.push_back(roundValue.getType()); + } + + if ((*contract).requiresSat) { + auto saturation = + op.getSatAttr() ? parseSaturationImmediate(*op.getSat()) : std::nullopt; + if (!saturation) + return rewriter.notifyMatchFailure(op, "vcvt requires valid sat attr"); + Value satValue = getI32Constant(rewriter, op.getLoc(), *saturation); + callArgs.push_back(satValue); + argTypes.push_back(satValue.getType()); + } + + if ((*contract).requiresPart) { + auto part = op.getPartAttr() ? parsePartImmediate(*op.getPart()) : std::nullopt; + if (!part) + return rewriter.notifyMatchFailure(op, "vcvt requires valid part attr"); + Value partValue = getI32Constant(rewriter, op.getLoc(), *part); + callArgs.push_back(partValue); + argTypes.push_back(partValue.getType()); + } + + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{std::string((*contract).intrinsic), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVtrcOpPattern final : public OpConversionPattern { +public: + explicit LowerVtrcOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VtrcOp op, pto::VtrcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto roundMode = parseRoundModeImmediate(op.getRoundMode()); + if (!roundMode) + return rewriter.notifyMatchFailure(op, "unsupported vtrc signature"); + + Type resultType = 
this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vtrc result type"); + + FailureOr calleeName = + buildVtrcCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vtrc callee"); + + Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getInput().getType(), roundValue.getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getInput(), roundValue, adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateStoreOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicateStoreOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(StoreOp op, typename StoreOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmDestType = + dyn_cast(adaptor.getDestination().getType()); + Type valueType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!llvmDestType || !valueType) + return rewriter.notifyMatchFailure( + op, "expected converted predicate-store operand types"); + + auto dist = parsePredicateStoreDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-store dist immediate"); + + Value offset = castIntegerLikeTo(op, adaptor.getOffset(), rewriter.getI32Type()); + if (!offset) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-store offset to i32"); + + StringRef 
calleeName = getPredicateStoreCallee(op.getContext()); + SmallVector args; + args.push_back(adaptor.getValue()); + args.push_back(adaptor.getDestination()); + args.push_back(offset); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist))); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(0))); + auto funcType = rewriter.getFunctionType( + TypeRange{valueType, llvmDestType, rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateLoadOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicateLoadOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoadOp op, typename LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getSource().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!llvmSourceType || !resultType) + return rewriter.notifyMatchFailure( + op, "expected converted predicate-load operand/result types"); + + auto dist = parsePredicateLoadDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-load dist immediate"); + + Value offset = castIntegerLikeTo(op, adaptor.getOffset(), rewriter.getI32Type()); + if (!offset) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-load offset to i32"); + + StringRef calleeName = getPredicateLoadCallee(op.getContext()); + SmallVector args; + args.push_back(adaptor.getSource()); + args.push_back(offset); + args.push_back(rewriter.create( 
+ op.getLoc(), rewriter.getI32IntegerAttr(*dist))); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(0))); + auto funcType = rewriter.getFunctionType( + TypeRange{llvmSourceType, rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, resultType, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerSetLoopConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerSetLoopConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoopOp op, typename LoopOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr packed = failure(); + if constexpr (std::is_same_v || + std::is_same_v) { + packed = packLoopSize(op, adaptor.getFirst(), adaptor.getSecond()); + } else { + packed = packLoopPair(op, adaptor.getFirst(), adaptor.getSecond()); + } + if (failed(packed)) + return rewriter.notifyMatchFailure(op, + "failed to pack loop configuration"); + + StringRef calleeName = buildSetLoopCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*packed}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPipeEventSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerPipeEventSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + 
LogicalResult + matchAndRewrite(SyncOp op, typename SyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto src = parsePipeImmediate(stringifyPIPE(op.getSrcPipe().getPipe())); + auto dst = parsePipeImmediate(stringifyPIPE(op.getDstPipe().getPipe())); + auto event = parseEventImmediate(stringifyEVENT(op.getEventId().getEvent())); + if (!src || !dst || !event) + return rewriter.notifyMatchFailure(op, "unsupported sync immediate"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value srcValue = getI64Constant(rewriter, op.getLoc(), *src); + Value dstValue = getI64Constant(rewriter, op.getLoc(), *dst); + Value eventValue = getI64Constant(rewriter, op.getLoc(), *event); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{srcValue, dstValue, eventValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerBarrierOpPattern final : public OpConversionPattern { +public: + explicit LowerBarrierOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pipe = parsePipeImmediate(stringifyPIPE(op.getPipe().getPipe())); + if (!pipe) + return rewriter.notifyMatchFailure(op, "unsupported barrier pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipe); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + 
ValueRange{pipeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBufSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerBufSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BufSyncOp op, typename BufSyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + PIPE pipe = PIPE::PIPE_UNASSIGNED; + if (auto pipeAttr = dyn_cast(op.getOpTypeAttr())) { + pipe = pipeAttr.getPipe(); + } else { + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure( + op, "buffer sync expects pipe/sync_op_type/pipe_event_type attr"); + pipe = mapSyncOpTypeToPipe(*opTypeOr); + } + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, + "buffer sync op_type cannot map to concrete pipe"); + + auto pipeImm = parsePipeImmediate(stringifyPIPE(pipe)); + if (!pipeImm) + return rewriter.notifyMatchFailure(op, "unsupported buffer sync pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipeImm); + Value bufIdValue = + getI64Constant(rewriter, op.getLoc(), op.getBufIdAttr().getInt()); + Value modeValue = + getI64Constant(rewriter, op.getLoc(), op.getModeAttr().getInt()); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{pipeValue, bufIdValue, modeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class 
LowerRuntimeQueryOpPattern final : public OpConversionPattern { +public: + explicit LowerRuntimeQueryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(QueryOp op, typename QueryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert runtime-query result type"); + + StringRef calleeName = buildRuntimeQueryCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class ConvertVPTOUnrealizedCastOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() != convertedResultType) + return failure(); + + rewriter.replaceOp(op, input); + return success(); + } +}; + +class ConvertPtoAddPtrOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + 
matchAndRewrite(pto::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getResult().getType()); + auto llvmPtrType = dyn_cast(convertedResultType); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer result type"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + auto gep = rewriter.create( + op.getLoc(), llvmPtrType, cast(op.getPtr().getType()).getElementType(), + adaptor.getPtr(), ValueRange{offset}); + rewriter.replaceOp(op, gep.getResult()); + return success(); + } +}; + +class ConvertPtoCastPtrOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, + "could not convert castptr result type"); + + Value input = adaptor.getInput(); + Type inputType = input.getType(); + if (inputType == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + if (auto llvmPtrType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); + } + auto sourcePtrType = dyn_cast(inputType); + if (!sourcePtrType) + return rewriter.notifyMatchFailure(op, + "expected integer or LLVM pointer input"); + if (sourcePtrType.getAddressSpace() == llvmPtrType.getAddressSpace()) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); + } + return rewriter.notifyMatchFailure( + op, "cross-address-space ptr casts are unsupported"); + } + + if (auto resultIntType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + 
rewriter.replaceOpWithNewOp(op, resultIntType, input); + return success(); + } + } + + return rewriter.notifyMatchFailure(op, "unsupported castptr conversion"); + } +}; + +class ConvertPtoLoadScalarOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + op.getValue().getType(), adaptor.getPtr(), + ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, op.getValue().getType(), elemPtr, + getNaturalAlignment(op.getValue().getType())); + return success(); + } +}; + +class ConvertPtoStoreScalarOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + 
rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + adaptor.getValue().getType(), + adaptor.getPtr(), ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, + getNaturalAlignment(adaptor.getValue().getType())); + rewriter.eraseOp(op); + return success(); + } +}; + +class ConvertVPTOTypedCarrierOp final : public ConversionPattern { +public: + ConvertVPTOTypedCarrierOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (isa(op)) + return failure(); + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with VPTO types are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); + } +}; + +static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, + RewritePatternSet &patterns, + LoweringState &state) { + patterns.add, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerVsqzOpPattern, LowerVusqzOpPattern, + LowerVmulaOpPattern, LowerVmullOpPattern, 
+ LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerVdupOpPattern, + LowerVbrOpPattern, + LowerPredicatePackOpPattern, + LowerPredicatePackOpPattern, + LowerVselOpPattern, LowerVselrOpPattern, LowerPnotOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerUnpackOpPattern, + LowerUnpackOpPattern, + LowerVpackOpPattern, + LowerInterleaveOpPattern, + LowerInterleaveOpPattern, + LowerCmpOpPattern, + LowerCmpOpPattern, + LowerPltOpPattern, + LowerPltOpPattern, + LowerPltOpPattern, + LowerPsetOpPattern, + LowerPsetOpPattern, + LowerPsetOpPattern, + LowerPgeOpPattern, + LowerPgeOpPattern, + LowerPgeOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + 
LowerPipeEventSyncOpPattern, + LowerPipeEventSyncOpPattern, + LowerBarrierOpPattern, + LowerBufSyncOpPattern, + LowerBufSyncOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerVldsOpPattern, LowerVldsPostOpPattern, + LowerVldsx2OpPattern, LowerVsldbOpPattern, + LowerVldasOpPattern, LowerInitAlignOpPattern, + LowerVldusOpPattern, LowerSprclrOpPattern, + LowerVstsOpPattern, LowerVsstbOpPattern, + LowerVstsPostOpPattern, LowerVstsx2OpPattern, + LowerVstarOpPattern, LowerVstasOpPattern, + LowerVgather2OpPattern, LowerVgather2BcOpPattern, + LowerVgatherbOpPattern, LowerVscatterOpPattern, + LowerVpreluOpPattern, LowerVaxpyOpPattern, + LowerVciOpPattern, LowerVexpdiffOpPattern, + LowerVbitsortOpPattern, LowerVtrcOpPattern, LowerVcvtOpPattern, + LowerPredicateLoadOpPattern, + LowerPredicateLoadOpPattern, + LowerPredicateStoreOpPattern, + LowerPredicateStoreOpPattern, + LowerPstuOpPattern, LowerVstusOpPattern, LowerVsturOpPattern, + LowerCopyOpPattern, + LowerCopyOpPattern>( + typeConverter, patterns.getContext(), state); +} + +static void configureVPTOOpLoweringTarget(ConversionTarget &target, + VPTOTypeConverter &typeConverter) { + (void)typeConverter; + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); +} + +static void populateVPTOStructuralTypePatterns( + VPTOTypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, 
typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); +} + +static void foldVPTOTypeCasts(ModuleOp module, TypeConverter &typeConverter) { + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasVPTOConvertibleType(castOp->getOperandTypes()) && + !hasVPTOConvertibleType(castOp->getResultTypes())) + return; + Type convertedResultType = + typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && + convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } +} + +static LogicalResult lowerVPTOOps(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + VPTOTypeConverter typeConverter(context); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + LoweringState state; + + configureVPTOOpLoweringTarget(target, typeConverter); + populateVPTOOpLoweringPatterns(typeConverter, patterns, state); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: VPTO op lowering failed\n"; + return failure(); + } + if (failed(materializeDecls(module, state.plannedDecls, diagOS))) + return failure(); + return success(); +} + +static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + VPTOTypeConverter typeConverter(context); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return 
typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.addIllegalOp(); + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + return !hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes()); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + populateVPTOStructuralTypePatterns(typeConverter, patterns, target); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: VPTO type lowering failed\n"; + return failure(); + } + foldVPTOTypeCasts(module, typeConverter); + return success(); +} + +static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { + type = convertVPTOType(type, builder); + return type; +} + +static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { + Builder builder(module.getContext()); + + for (func::FuncOp funcOp : module.getOps()) { + FunctionType oldType = funcOp.getFunctionType(); + SmallVector newInputs; + SmallVector newResults; + bool changed = false; + + for (Type input : oldType.getInputs()) { + Type normalized = normalizeTypeForOfficialLLVMLowering(input, builder); + changed |= (normalized != input); + newInputs.push_back(normalized); + } + for (Type result : oldType.getResults()) { + Type normalized = normalizeTypeForOfficialLLVMLowering(result, builder); + changed |= (normalized != result); + newResults.push_back(normalized); + } + + if (!changed) + continue; + + auto newType = builder.getFunctionType(newInputs, 
newResults); + funcOp.setFunctionTypeAttr(TypeAttr::get(newType)); + + if (funcOp.isExternal()) + continue; + Block &entry = funcOp.getBody().front(); + for (auto [arg, newType] : llvm::zip(entry.getArguments(), newInputs)) + if (arg.getType() != newType) + arg.setType(newType); + } +} + +template +static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, + const VPTOEmissionOptions &options, + EmitFn &&emit) { + OwningOpRef clonedOp(module->clone()); + ModuleOp clonedModule = cast(*clonedOp); + + materializeVecScopeCarrierLoops(clonedModule); + + if (failed(normalizePtoMemRefSpaces(clonedModule, diagOS))) { + diagOS << "VPTO LLVM emission failed: normalizePtoMemRefSpaces failed\n"; + return failure(); + } + if (failed(lowerVPTOOps(clonedModule, diagOS))) { + diagOS << "VPTO LLVM emission failed: lowerVPTOOps failed\n"; + return failure(); + } + if (failed(lowerVPTOTypes(clonedModule, diagOS))) { + diagOS << "VPTO LLVM emission failed: lowerVPTOTypes failed\n"; + return failure(); + } + + normalizeFuncSignaturesForOfficialLLVMLowering(clonedModule); + + PassManager pm(clonedModule.getContext()); + pm.enableVerifier(); + pm.addPass(createConvertSCFToCFPass()); + pm.addPass(createArithToLLVMConversionPass()); + pm.addPass(createConvertIndexToLLVMPass()); + pm.addPass(createFinalizeMemRefToLLVMConversionPass()); + pm.addPass(createConvertFuncToLLVMPass()); + pm.addPass(createConvertControlFlowToLLVMPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + if (failed(pm.run(clonedModule))) { + diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; + return failure(); + } + + if (failed(applyQueriedTargetAttrs(clonedModule, options, diagOS))) + return failure(); + + llvm::LLVMContext llvmContext; + registerBuiltinDialectTranslation(*clonedModule.getContext()); + registerLLVMDialectTranslation(*clonedModule.getContext()); + std::unique_ptr llvmModule = + translateModuleToLLVMIR(clonedModule.getOperation(), llvmContext); + if 
(!llvmModule) { + diagOS << "VPTO LLVM emission failed: LLVM IR export failed\n"; + return failure(); + } + + if (failed(attachAIVectorScopeMetadata(*llvmModule, diagOS))) + return failure(); + attachHIVMKernelAnnotations(*llvmModule); + llvmModule->setModuleIdentifier("ptoas.hivm.official"); + llvmModule->setSourceFileName("ptoas.hivm.official"); + return emit(*llvmModule); +} + +} // namespace + +LogicalResult +translateVPTOModuleToLLVMText(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + return runPipeline(module, diagOS, options, [&](llvm::Module &llvmModule) { + llvmModule.print(os, nullptr); + return success(); + }); +} + +LogicalResult +translateVPTOModuleToLLVMBitcode(ModuleOp module, llvm::raw_ostream &os, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + return runPipeline(module, diagOS, options, [&](llvm::Module &llvmModule) { + llvm::WriteBitcodeToFile(llvmModule, os); + return success(); + }); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp new file mode 100644 index 000000000..931f4966f --- /dev/null +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -0,0 +1,684 @@ +//===- VPTOLLVMEmitterHelper.cpp - VPTO LLVM emission helpers ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VPTOLLVMEmitterHelper.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace mlir; + +namespace mlir::pto { +namespace { + +constexpr StringLiteral kAIVScopeDummyCallee = "aivscope_dummy"; + +struct QueriedTargetAttrs { + std::string targetCPU; + std::string targetFeatures; +}; + +static bool hasPtoMemRefMemorySpace(Type type) { + if (auto memRefType = dyn_cast(type)) + return isa(memRefType.getMemorySpace()); + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), hasPtoMemRefMemorySpace) || + 
llvm::any_of(functionType.getResults(), hasPtoMemRefMemorySpace); + return false; +} + +static bool hasPtoMemRefMemorySpace(TypeRange types) { + return llvm::any_of(types, [](Type type) { + return hasPtoMemRefMemorySpace(type); + }); +} + +struct ConvertPtoMemRefSpaceCarrierOp final : ConversionPattern { + ConvertPtoMemRefSpaceCarrierOp(TypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with PTO memref spaces are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); + } +}; + +struct ConvertMemRefReinterpretCastSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), adaptor.getOffsets(), + adaptor.getSizes(), adaptor.getStrides(), op.getStaticOffsets(), + op.getStaticSizes(), op.getStaticStrides()); + return success(); + } +}; + +struct ConvertMemRefSubViewSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter 
&rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + return success(); + } +}; + +struct ConvertMemRefSpaceUnrealizedCastOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } +}; + +static void ensureAIVScopeDummyDecl(ModuleOp module) { + SymbolTable symbolTable(module); + if (symbolTable.lookup(kAIVScopeDummyCallee)) + return; + + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(module.getBody()); + auto funcType = builder.getFunctionType(TypeRange{}, TypeRange{}); + auto dummy = builder.create(module.getLoc(), + kAIVScopeDummyCallee, funcType); + dummy.setPrivate(); +} + +static bool satisfiesAIVectorScopeLatchPostcondition(llvm::Loop *loop) { + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) + return false; + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.size() != 1) + return false; + + auto *predTerm = preds.front()->getTerminator(); + return predTerm && 
predTerm->getNumSuccessors() == 1 && + predTerm->getSuccessor(0) == latch; +} + +static LogicalResult ensureDummyPredForAIVectorScopeLatch( + llvm::Loop *loop, llvm::raw_ostream &diagOS) { + if (satisfiesAIVectorScopeLatchPostcondition(loop)) + return success(); + + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) { + diagOS << "VPTO LLVM emission failed: aivscope loop is missing a latch\n"; + return failure(); + } + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.empty()) { + diagOS << "VPTO LLVM emission failed: aivscope latch has no predecessor\n"; + return failure(); + } + + auto *dummy = llvm::SplitBlockPredecessors( + latch, preds, "aivscope.dummy", static_cast(nullptr), + static_cast(nullptr), nullptr, /*PreserveLCSSA=*/false); + if (!dummy) { + diagOS << "VPTO LLVM emission failed: failed to normalize aivscope latch " + "predecessors\n"; + return failure(); + } + + if (!satisfiesAIVectorScopeLatchPostcondition(loop)) { + diagOS << "VPTO LLVM emission failed: normalized aivscope latch still does " + "not satisfy the single-predecessor/single-successor contract\n"; + return failure(); + } + return success(); +} + +static FailureOr extractQuotedLLVMFnAttr(llvm::StringRef ir, + llvm::StringRef key) { + std::string pattern = "\""; + pattern += key.str(); + pattern += "\"=\""; + size_t start = ir.find(pattern); + if (start == llvm::StringRef::npos) + return failure(); + start += pattern.size(); + size_t end = ir.find('"', start); + if (end == llvm::StringRef::npos || end <= start) + return failure(); + return ir.slice(start, end).str(); +} + +static FailureOr +queryDefaultTargetAttrs(const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + static llvm::StringMap cache; + + if (options.targetTriple.empty() || options.march.empty() || + options.aicoreArch.empty()) { + diagOS << "VPTO LLVM emission failed: missing target query options\n"; + return failure(); + } + + std::string cacheKey = + options.targetTriple + "|" + 
options.march + "|" + options.aicoreArch; + if (auto it = cache.find(cacheKey); it != cache.end()) + return it->second; + + auto bisheng = llvm::sys::findProgramByName("bisheng"); + if (!bisheng) { + diagOS << "VPTO LLVM emission failed: unable to find 'bisheng' in PATH\n"; + return failure(); + } + const std::string &bishengPath = *bisheng; + + llvm::SmallString<64> inputPath; + llvm::SmallString<64> outputPath; + int inputFD = -1; + int outputFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "c", inputFD, inputPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query input: " + << ec.message() << "\n"; + return failure(); + } + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "ll", outputFD, outputPath)) { + llvm::sys::fs::remove(inputPath); + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + diagOS << "VPTO LLVM emission failed: cannot create bisheng query output: " + << ec.message() << "\n"; + return failure(); + } + + auto cleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(inputPath); + llvm::sys::fs::remove(outputPath); + }); + + { + llvm::raw_fd_ostream inputOS(inputFD, /*shouldClose=*/false); + inputOS << "void f(void) {}\n"; + } + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + llvm::sys::Process::SafelyCloseFileDescriptor(outputFD); + + llvm::SmallString<128> stderrPath; + int stderrFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "stderr", stderrFD, + stderrPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query stderr: " + << ec.message() << "\n"; + return failure(); + } + auto stderrCleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(stderrPath); + }); + llvm::sys::Process::SafelyCloseFileDescriptor(stderrFD); + + llvm::SmallVector argStorage = { + bishengPath, + ("--target=" + options.targetTriple), + ("-march=" + options.march), + ("--cce-aicore-arch=" + 
options.aicoreArch), + "--cce-aicore-only", + "-x", + "c", + inputPath.str().str(), + "-S", + "-emit-llvm", + "-o", + outputPath.str().str(), + }; + llvm::SmallVector args; + args.reserve(argStorage.size()); + for (const std::string &arg : argStorage) + args.push_back(arg); + + std::string execErr; + bool execFailed = false; + int rc = llvm::sys::ExecuteAndWait( + bishengPath, args, std::nullopt, + {std::nullopt, std::nullopt, llvm::StringRef(stderrPath)}, 0, 0, + &execErr, &execFailed); + + auto stderrBuffer = llvm::MemoryBuffer::getFile(stderrPath); + llvm::StringRef stderrText = + stderrBuffer ? stderrBuffer.get()->getBuffer() : llvm::StringRef(); + + if (execFailed || rc != 0) { + diagOS << "VPTO LLVM emission failed: bisheng target query failed\n"; + diagOS << "Command:"; + for (llvm::StringRef arg : args) + diagOS << " " << arg; + diagOS << "\n"; + if (!execErr.empty()) + diagOS << execErr << "\n"; + if (!stderrText.empty()) + diagOS << stderrText << "\n"; + return failure(); + } + + auto outputBuffer = llvm::MemoryBuffer::getFile(outputPath); + if (!outputBuffer) { + diagOS << "VPTO LLVM emission failed: cannot read bisheng query output\n"; + return failure(); + } + + FailureOr targetCPU = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-cpu"); + FailureOr targetFeatures = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-features"); + if (failed(targetCPU) || failed(targetFeatures)) { + diagOS << "VPTO LLVM emission failed: cannot parse bisheng target attrs\n"; + diagOS << outputBuffer.get()->getBuffer() << "\n"; + return failure(); + } + + QueriedTargetAttrs attrs{*targetCPU, *targetFeatures}; + cache[cacheKey] = attrs; + return attrs; +} + +} // namespace + +LogicalResult normalizePtoMemRefSpaces(ModuleOp module, + llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + 
typeConverter.addConversion([&](MemRefType type) -> Type { + auto addrSpace = dyn_cast_or_null(type.getMemorySpace()); + if (!addrSpace) + return type; + return MemRefType::get( + type.getShape(), type.getElementType(), type.getLayout(), + IntegerAttr::get(IntegerType::get(context, 64), + static_cast(addrSpace.getAddressSpace()))); + }); + typeConverter.addTypeAttributeConversion( + [](MemRefType, pto::AddressSpaceAttr attr) -> Attribute { + return IntegerAttr::get(IntegerType::get(attr.getContext(), 64), + static_cast(attr.getAddressSpace())); + }); + auto materializeMemRefCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); + }; + typeConverter.addSourceMaterialization(materializeMemRefCast); + typeConverter.addTargetMaterialization(materializeMemRefCast); + typeConverter.addArgumentMaterialization(materializeMemRefCast); + + ConversionTarget target(*context); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); 
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: memref address-space normalization " + "failed\n"; + return failure(); + } + + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasPtoMemRefMemorySpace(castOp->getOperandTypes()) && + !hasPtoMemRefMemorySpace(castOp->getResultTypes())) + return; + Type convertedResultType = + typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && + convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } + + WalkResult leftover = module.walk([&](Operation *op) { + if (hasPtoMemRefMemorySpace(op->getOperandTypes()) || + hasPtoMemRefMemorySpace(op->getResultTypes())) { + diagOS << "VPTO LLVM emission failed: residual PTO memref address space " + "on op " + << op->getName().getStringRef() << "\n"; + op->print(diagOS); + diagOS << "\n"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (leftover.wasInterrupted()) + return failure(); + return success(); +} + +void materializeVecScopeCarrierLoops(ModuleOp module) { + MLIRContext *ctx = module.getContext(); + (void)ctx->getOrLoadDialect(); + (void)ctx->getOrLoadDialect(); + ensureAIVScopeDummyDecl(module); + + SmallVector scopes; + module.walk([&](pto::VecScopeOp vecScope) { scopes.push_back(vecScope); }); + + IRRewriter rewriter(module.getContext()); + for (pto::VecScopeOp vecScope : llvm::reverse(scopes)) { + if (!vecScope || 
vecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(vecScope); + auto loc = vecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &vecScopeBody = vecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + carrierBody->getOperations().splice(Block::iterator(yield), + vecScopeBody.getOperations(), + vecScopeBody.begin(), + vecScopeBody.end()); + rewriter.setInsertionPoint(yield); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + rewriter.eraseOp(vecScope); + } + + SmallVector strictScopes; + module.walk([&](pto::StrictVecScopeOp strictVecScope) { + strictScopes.push_back(strictVecScope); + }); + + for (pto::StrictVecScopeOp strictVecScope : llvm::reverse(strictScopes)) { + if (!strictVecScope || strictVecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(strictVecScope); + auto loc = strictVecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &strictBody = strictVecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + + IRMapping mapping; + for (auto [blockArg, capture] : + llvm::zip(strictBody.getArguments(), strictVecScope.getCaptures())) + mapping.map(blockArg, capture); + + rewriter.setInsertionPoint(yield); + for (Operation &nested : strictBody.getOperations()) + rewriter.clone(nested, mapping); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + + rewriter.eraseOp(strictVecScope); + } +} + +LogicalResult attachAIVectorScopeMetadata(llvm::Module &llvmModule, + llvm::raw_ostream &diagOS) { + llvm::Function *dummyCallee = llvmModule.getFunction(kAIVScopeDummyCallee); + if (!dummyCallee) + return success(); + + for (llvm::Function 
&function : llvmModule) { + if (function.isDeclaration()) + continue; + llvm::DominatorTree dt(function); + llvm::LoopInfo loopInfo(dt); + + llvm::SmallVector dummyCalls; + for (llvm::BasicBlock &block : function) { + for (llvm::Instruction &inst : block) { + auto *call = dyn_cast(&inst); + if (call && call->getCalledFunction() == dummyCallee) + dummyCalls.push_back(call); + } + } + + for (llvm::CallInst *dummyCall : dummyCalls) { + llvm::BasicBlock *markedBlock = dummyCall->getParent(); + llvm::Loop *loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() << " does not belong to an LLVM loop\n"; + return failure(); + } + + if (markedBlock == loop->getLoopLatch() && + dummyCall != markedBlock->getTerminator()) { + markedBlock->splitBasicBlock(dummyCall->getIterator(), "aivscope.latch"); + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + markedBlock = dummyCall->getParent(); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: split aivscope latch in " + << function.getName() + << " no longer belongs to an LLVM loop\n"; + return failure(); + } + } + + if (failed(ensureDummyPredForAIVectorScopeLatch(loop, diagOS))) + return failure(); + + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() + << " lost its loop after latch normalization\n"; + return failure(); + } + + llvm::BasicBlock *latch = loop->getLoopLatch(); + auto *branch = dyn_cast_or_null( + latch ? 
latch->getTerminator() : nullptr); + if (!branch || branch->isConditional()) { + diagOS << "VPTO LLVM emission failed: normalized aivscope loop in " + << function.getName() + << " does not have an unconditional latch backedge\n"; + return failure(); + } + + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Metadata *ops[] = { + nullptr, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, "llvm.loop.aivector_scope"))}; + auto *loopID = llvm::MDNode::getDistinct(ctx, ops); + loopID->replaceOperandWith(0, loopID); + branch->setMetadata(llvm::LLVMContext::MD_loop, loopID); + dummyCall->eraseFromParent(); + } + } + + if (dummyCallee->use_empty()) + dummyCallee->eraseFromParent(); + return success(); +} + +void attachHIVMKernelAnnotations(llvm::Module &llvmModule) { + llvm::NamedMDNode *annotations = + llvmModule.getOrInsertNamedMetadata("hivm.annotations"); + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx); + llvm::Constant *one = llvm::ConstantInt::get(i32Ty, 1); + + auto addAnnotation = [&](llvm::Function &function, llvm::StringRef kind) { + llvm::Metadata *ops[] = { + llvm::ValueAsMetadata::get(&function), + llvm::MDString::get(ctx, kind), + llvm::ConstantAsMetadata::get(one)}; + annotations->addOperand(llvm::MDNode::get(ctx, ops)); + }; + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + if (function.getLinkage() != llvm::GlobalValue::ExternalLinkage) + continue; + + llvm::StringRef name = function.getName(); + if (name.contains(".extracted") || name.contains(".vector.thread")) + continue; + + addAnnotation(function, "kernel"); + addAnnotation(function, "kernel_with_simd"); + } +} + +LogicalResult +applyQueriedTargetAttrs(ModuleOp module, const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + FailureOr attrs = queryDefaultTargetAttrs(options, diagOS); + if (failed(attrs)) { + if (options.defaultTargetCPU.empty() || + 
options.defaultTargetFeatures.empty()) + return failure(); + diagOS << "VPTO LLVM emission: falling back to configured default target " + "attributes\n"; + attrs = QueriedTargetAttrs{options.defaultTargetCPU, + options.defaultTargetFeatures}; + } + + MLIRContext *ctx = module.getContext(); + StringAttr cpuAttr = StringAttr::get(ctx, attrs->targetCPU); + LLVM::TargetFeaturesAttr featureAttr = + LLVM::TargetFeaturesAttr::get(ctx, attrs->targetFeatures); + module.walk([&](LLVM::LLVMFuncOp funcOp) { + funcOp.setTargetCpuAttr(cpuAttr); + funcOp.setTargetFeaturesAttr(featureAttr); + }); + return success(); +} + +} // namespace mlir::pto diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 06ee9843c..1dee91ac6 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -506,3 +506,875 @@ def _install_op_aliases(): return added __all__.extend(_install_op_aliases()) + +# ----------------------------------------------------------------------------- +# Experimental VPTO Python DSL (`@pto.vkernel`) +# ----------------------------------------------------------------------------- +import ast as _ast +import inspect as _inspect +import textwrap as _textwrap +from dataclasses import dataclass as _dataclass + + +class _VKernelType: + def render(self): + raise NotImplementedError + + +@_dataclass(frozen=True) +class _VKernelScalarType(_VKernelType): + name: str + + def render(self): + return self.name + + +@_dataclass(frozen=True) +class _VKernelPtrType(_VKernelType): + elem: _VKernelType + space: str + + def render(self): + return f"!pto.ptr<{self.elem.render()}, {self.space}>" + + +@_dataclass(frozen=True) +class _VKernelVRegType(_VKernelType): + lanes: int + elem: _VKernelType + + def render(self): + return f"!pto.vreg<{self.lanes}x{self.elem.render()}>" + + +@_dataclass(frozen=True) +class _VKernelConstBinding: + value: object + + +@_dataclass(frozen=True) +class _VKernelStructDef(_VKernelType): + name: str + fields: tuple + + def 
render(self): + raise _VKernelCompileError(f"{self.name} is a template-only surface type; use .jit(...) to specialize it") + + def __call__(self, **kwargs): + return _VKernelStructBinding(self, dict(kwargs)) + + +@_dataclass(frozen=True) +class _VKernelStructBinding: + schema: _VKernelStructDef + values: dict + + +@_dataclass(frozen=True) +class _VKStaticSequence: + values: tuple + + +@_dataclass(frozen=True) +class _VKStructValue: + schema: _VKernelStructDef + fields: dict + + +i1 = _VKernelScalarType("i1") +i8 = _VKernelScalarType("i8") +i16 = _VKernelScalarType("i16") +i32 = _VKernelScalarType("i32") +i64 = _VKernelScalarType("i64") +f16 = _VKernelScalarType("f16") +bf16 = _VKernelScalarType("bf16") +f32 = _VKernelScalarType("f32") +_vk_index = _VKernelScalarType("index") +mask = _VKernelScalarType("!pto.mask") +align = _VKernelScalarType("!pto.align") + + +def ptr(elem_type, space): + return _VKernelPtrType(elem_type, space) + + +def vreg(lanes, elem_type): + return _VKernelVRegType(lanes, elem_type) + + +def const(value): + return _VKernelConstBinding(value) + + +def struct(cls): + annotations = dict(getattr(cls, "__annotations__", {})) + if not annotations: + raise _VKernelCompileError("@pto.struct requires annotated fields") + fields = [] + for name, field_ty in annotations.items(): + if field_ty not in (ptr, const): + raise _VKernelCompileError( + f"unsupported field annotation for {cls.__name__}.{name}: {field_ty!r}" + ) + fields.append((name, field_ty)) + return _VKernelStructDef(cls.__name__, tuple(fields)) + + +@struct +class Tile: + ub_ptr: ptr + shape: const + + +tile = Tile + + +class _VKernelCompileError(Exception): + pass + + +@_dataclass +class _VKValue: + name: str | None = None + type: _VKernelType | None = None + literal: object | None = None + + def render_type(self): + if self.type is None: + raise _VKernelCompileError(f"unresolved type for {self.name}") + return self.type.render() + + +def _project_result(group, index, ty): + return 
_VKValue(f"{group.name}#{index}", ty) + + +def _load_standard_dialects(): + try: + from mlir.dialects import arith as _mlir_arith # noqa: F401 + from mlir.dialects import func as _mlir_func # noqa: F401 + from mlir.dialects import scf as _mlir_scf # noqa: F401 + except ImportError as exc: + raise RuntimeError("mlir standard dialect python bindings are required for vkernel parsing") from exc + + +class _VKernelContext: + def __init__(self): + self.ssa_counter = 0 + self.arg_counter = 0 + + def new_ssa(self): + name = f"%{self.ssa_counter}" + self.ssa_counter += 1 + return name + + def new_arg(self): + name = f"%arg{self.arg_counter}" + self.arg_counter += 1 + return name + + +def _type_key(ty): + return ty.render() if ty is not None else None + + +def _types_equal(lhs, rhs): + if lhs is None or rhs is None: + return lhs is rhs + return lhs.render() == rhs.render() + + +def _ensure_type(value, expected): + if value.type is None: + value.type = expected + return + if not _types_equal(value.type, expected): + raise _VKernelCompileError( + f"type mismatch for {value.name}: expected {expected.render()}, got {value.type.render()}" + ) + + +def _literal_text(value): + if isinstance(value, bool): + return "true" if value else "false" + return str(value) + + +def _coerce_surface_type(value): + if value is bool: + return i1 + if value is float: + return f32 + return value + + +def _ptr_elem_bytes(ptr_type): + if not isinstance(ptr_type, _VKernelPtrType): + raise _VKernelCompileError("elem_bytes requires a ptr type") + elem_name = ptr_type.elem.render() + table = { + "i8": 1, + "i16": 2, + "i32": 4, + "i64": 8, + "f16": 2, + "bf16": 2, + "f32": 4, + } + if elem_name not in table: + raise _VKernelCompileError(f"unsupported elem_bytes for {elem_name}") + return table[elem_name] + + +def _ptr_vector_lanes(ptr_type): + return 256 // _ptr_elem_bytes(ptr_type) + + +class _VKernelBuilder: + def __init__(self, py_fn, fn_def, target, kernel_name, specialization=None): + self.py_fn = 
py_fn + self.fn_def = fn_def + self.target = target + self.kernel_name = kernel_name + self.ctx = _VKernelContext() + self.specialization = specialization or {} + + def _emit(self, lines, indent, text): + lines.append(" " * indent + text) + + def _eval_type_expr(self, node): + expr = _ast.Expression(body=node) + globals_dict = dict(self.py_fn.__globals__) + globals_dict.update(globals()) + value = eval(compile(expr, self.py_fn.__code__.co_filename, "eval"), + globals_dict, {}) + value = _coerce_surface_type(value) + if not isinstance(value, _VKernelType): + raise _VKernelCompileError(f"unsupported vkernel type annotation: {value!r}") + return value + + def _new_value(self, ty=None): + return _VKValue(self.ctx.new_ssa(), ty) + + def _new_arg_value(self, ty=None): + return _VKValue(self.ctx.new_arg(), ty) + + def _materialize_value(self, value, lines, indent, expected_type=None): + if expected_type is not None: + _ensure_type(value, expected_type) + if value.name is not None: + return value + if value.literal is None: + raise _VKernelCompileError("value has no SSA name and cannot be materialized") + if value.type is None: + raise _VKernelCompileError("literal requires type context") + value.name = self.ctx.new_ssa() + lit = _literal_text(value.literal) + if isinstance(value.literal, bool): + self._emit(lines, indent, f"{value.name} = arith.constant {lit}") + else: + self._emit(lines, indent, f"{value.name} = arith.constant {lit} : {value.type.render()}") + return value + + def _literal_value(self, node, lines, indent, expected_type): + value = _VKValue(type=expected_type, literal=node.value) + if expected_type is None: + return value + return self._materialize_value(value, lines, indent) + + def _lower_attribute(self, node, env, lines, indent, expected_type=None): + if isinstance(node.value, _ast.Name): + if node.value.id not in env: + raise _VKernelCompileError(f"unknown name '{node.value.id}'") + base = env[node.value.id] + else: + base = 
self._lower_expr(node.value, env, lines, indent) + if isinstance(base, _VKStructValue): + if node.attr not in base.fields: + raise _VKernelCompileError(f"unsupported struct attribute '{node.attr}'") + field = base.fields[node.attr] + if isinstance(field, _VKValue): + return self._materialize_value(field, lines, indent, expected_type) + return field + if isinstance(base, _VKValue) and isinstance(base.type, _VKernelPtrType): + if node.attr == "elem_bytes": + return _VKValue(type=expected_type, literal=_ptr_elem_bytes(base.type)) + raise _VKernelCompileError(f"unsupported attribute access '{node.attr}'") + + def _lower_subscript(self, node, env, lines, indent, expected_type=None): + base = self._lower_expr(node.value, env, lines, indent) + if not isinstance(base, _VKStaticSequence): + raise _VKernelCompileError("subscript base must be a static sequence") + if not isinstance(node.slice, _ast.Constant) or not isinstance(node.slice.value, int): + raise _VKernelCompileError("only constant integer subscripts are supported") + index = node.slice.value + if index < 0 or index >= len(base.values): + raise _VKernelCompileError("subscript out of range") + value = base.values[index] + if not isinstance(value, _VKValue): + value = _VKValue(type=expected_type, literal=value) + return self._materialize_value(value, lines, indent, expected_type) if expected_type is not None else value + + def _lower_binop(self, node, env, lines, indent, expected_type=None): + lhs = self._lower_expr(node.left, env, lines, indent) + rhs = self._lower_expr(node.right, env, lines, indent) + if lhs.literal is not None and rhs.literal is not None: + if isinstance(node.op, _ast.Mult): + result = lhs.literal * rhs.literal + elif isinstance(node.op, _ast.FloorDiv): + result = lhs.literal // rhs.literal + else: + raise _VKernelCompileError(f"unsupported binary operator: {type(node.op).__name__}") + return _VKValue(type=expected_type, literal=result) + raise _VKernelCompileError("non-constant binary 
expressions are not supported yet") + + def _lower_expr(self, node, env, lines, indent, expected_type=None): + if isinstance(node, _ast.Name): + if node.id not in env: + raise _VKernelCompileError(f"unknown name '{node.id}'") + value = env[node.id] + if isinstance(value, (_VKStructValue, _VKStaticSequence)): + raise _VKernelCompileError(f"name '{node.id}' is not a scalar/SSA value") + if ( + isinstance(value, _VKValue) + and value.name is None + and value.literal is not None + and expected_type is not None + ): + return self._materialize_value( + _VKValue(type=expected_type, literal=value.literal), + lines, + indent, + ) + return self._materialize_value(value, lines, indent, expected_type) + if isinstance(node, _ast.Constant): + return self._literal_value(node, lines, indent, expected_type) + if isinstance(node, _ast.Attribute): + return self._lower_attribute(node, env, lines, indent, expected_type) + if isinstance(node, _ast.Subscript): + return self._lower_subscript(node, env, lines, indent, expected_type) + if isinstance(node, _ast.BinOp): + return self._lower_binop(node, env, lines, indent, expected_type) + if isinstance(node, _ast.Call): + results = self._lower_call(node, env, lines, indent, expected_types=[expected_type] if expected_type else None) + if len(results) != 1: + raise _VKernelCompileError("expression expected single result") + return results[0] + raise _VKernelCompileError(f"unsupported expression: {type(node).__name__}") + + def _lower_call_name(self, node): + if isinstance(node, _ast.Attribute) and isinstance(node.value, _ast.Name) and node.value.id == "pto": + return node.attr + raise _VKernelCompileError("only pto.* calls are supported") + + def _infer_expr_type(self, node, env): + if isinstance(node, _ast.Name): + if node.id not in env: + raise _VKernelCompileError(f"unknown name '{node.id}'") + value = env[node.id] + return value.type if isinstance(value, _VKValue) else None + if isinstance(node, _ast.Attribute): + try: + value = 
self._lower_attribute(node, env, [], 0) + except _VKernelCompileError: + return None + return value.type if isinstance(value, _VKValue) else None + if isinstance(node, _ast.Constant): + return None + return None + + def _format_typed_operands(self, values): + return ", ".join(v.name for v in values), ", ".join(v.render_type() for v in values) + + def _lower_call(self, node, env, lines, indent, expected_types=None): + opname = self._lower_call_name(node.func) + + if opname in ("set_loop_size_outtoub", "set_loop_size_ubtoout"): + ops = [self._lower_expr(arg, env, lines, indent, i64) for arg in node.args] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.{opname} {operands} : {types}") + return [] + + if opname == "castptr": + if len(node.args) != 2: + raise _VKernelCompileError("pto.castptr expects 2 arguments") + result_type = self._eval_type_expr(node.args[1]) + addr = self._lower_expr(node.args[0], env, lines, indent, i64) + result = self._new_value(result_type) + self._emit(lines, indent, f"{result.name} = pto.castptr {addr.name} : {addr.render_type()} -> {result.render_type()}") + return [result] + + if opname == "copy_gm_to_ubuf": + expected = [None, None, i64, i64, i64, i64, i64, i1, i64, i64, i64] + ops = [self._lower_expr(arg, env, lines, indent, expected[i]) for i, arg in enumerate(node.args)] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.copy_gm_to_ubuf {operands} : {types}") + return [] + + if opname == "copy_ubuf_to_gm": + expected = [None, None, i64, i64, i64, i64, i64, i64] + ops = [self._lower_expr(arg, env, lines, indent, expected[i]) for i, arg in enumerate(node.args)] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.copy_ubuf_to_gm {operands} : {types}") + return [] + + if opname in ("set_flag", "wait_flag"): + attrs = [] + for arg in node.args: + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise 
_VKernelCompileError(f"pto.{opname} expects string literals") + attrs.append(arg.value) + self._emit(lines, indent, f'pto.{opname}["{attrs[0]}", "{attrs[1]}", "{attrs[2]}"]') + return [] + + if opname == "barrier": + arg = node.args[0] + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError("pto.barrier expects a string literal") + self._emit(lines, indent, f"pto.barrier #pto.pipe<{arg.value}>") + return [] + + if opname == "plt_b32": + src = self._lower_expr(node.args[0], env, lines, indent, i32) + res0 = self._new_value(mask) + res1 = self._new_value(i32) + self._emit(lines, indent, f"{res0.name}, {res1.name} = pto.plt_b32 {src.name} : i32 -> !pto.mask, i32") + return [res0, res1] + + if opname == "pset_b32": + arg = node.args[0] + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError("pto.pset_b32 expects a string literal") + res = self._new_value(mask) + self._emit(lines, indent, f'{res.name} = pto.pset_b32 "{arg.value}" : !pto.mask') + return [res] + + if opname == "vlds": + ptr_value = self._lower_expr(node.args[0], env, lines, indent) + if not isinstance(ptr_value.type, _VKernelPtrType): + raise _VKernelCompileError("pto.vlds expects a ptr operand") + offset = self._lower_expr(node.args[1], env, lines, indent, _vk_index) + result = self._new_value(vreg(_ptr_vector_lanes(ptr_value.type), ptr_value.type.elem)) + self._emit(lines, indent, + f"{result.name} = pto.vlds {ptr_value.name}[{offset.name}] : {ptr_value.render_type()} -> {result.render_type()}") + return [result] + + if opname == "vabs": + vec_value = self._lower_expr(node.args[0], env, lines, indent) + mask_value = self._lower_expr(node.args[1], env, lines, indent, mask) + result = self._new_value(vec_value.type) + self._emit(lines, indent, + f"{result.name} = pto.vabs {vec_value.name}, {mask_value.name} : {vec_value.render_type()}, {mask_value.render_type()} -> {result.render_type()}") + return [result] + + 
if opname == "vsts": + vec_value = self._lower_expr(node.args[0], env, lines, indent) + ptr_value = self._lower_expr(node.args[1], env, lines, indent) + offset = self._lower_expr(node.args[2], env, lines, indent, _vk_index) + mask_value = self._lower_expr(node.args[3], env, lines, indent, mask) + self._emit(lines, indent, + f"pto.vsts {vec_value.name}, {ptr_value.name}[{offset.name}], {mask_value.name} : {vec_value.render_type()}, {ptr_value.render_type()}, {mask_value.render_type()}") + return [] + + raise _VKernelCompileError(f"unsupported pto op in vkernel: {opname}") + + def _collect_assigned_names(self, statements): + names = set() + + class Visitor(_ast.NodeVisitor): + def visit_Assign(self, node): + for target in node.targets: + self._collect_target(target) + + def _collect_target(self, target): + if isinstance(target, _ast.Name): + names.add(target.id) + elif isinstance(target, _ast.Tuple): + for elt in target.elts: + self._collect_target(elt) + + visitor = Visitor() + for stmt in statements: + if isinstance(stmt, (_ast.With, _ast.For, _ast.If)): + continue + visitor.visit(stmt) + return names + + def _compile_block(self, statements, env, indent): + lines = [] + current_env = dict(env) + + for stmt in statements: + if isinstance(stmt, _ast.Assign): + if len(stmt.targets) != 1: + raise _VKernelCompileError("multiple assignment targets are not supported") + target = stmt.targets[0] + if isinstance(target, _ast.Name): + value = self._lower_expr(stmt.value, current_env, lines, indent) + current_env[target.id] = value + elif isinstance(target, _ast.Tuple): + results = self._lower_call(stmt.value, current_env, lines, indent) + if len(results) != len(target.elts): + raise _VKernelCompileError("tuple assignment arity mismatch") + for elt, value in zip(target.elts, results): + if not isinstance(elt, _ast.Name): + raise _VKernelCompileError("tuple assignment only supports names") + current_env[elt.id] = value + else: + raise _VKernelCompileError("unsupported 
assignment target") + continue + + if isinstance(stmt, _ast.AnnAssign): + if stmt.value is None: + raise _VKernelCompileError("annotation-only assignment is not supported") + if not isinstance(stmt.target, _ast.Name): + raise _VKernelCompileError("annotated assignment only supports names") + target_type = self._eval_type_expr(stmt.annotation) + value = self._lower_expr(stmt.value, current_env, lines, indent, target_type) + current_env[stmt.target.id] = value + continue + + if isinstance(stmt, _ast.Expr): + if isinstance(stmt.value, _ast.Call): + self._lower_call(stmt.value, current_env, lines, indent) + else: + self._lower_expr(stmt.value, current_env, lines, indent) + continue + + if isinstance(stmt, _ast.Return): + if stmt.value is not None: + raise _VKernelCompileError("only empty return is supported") + self._emit(lines, indent, "return") + continue + + if isinstance(stmt, _ast.With): + if len(stmt.items) != 1: + raise _VKernelCompileError("only single with item is supported") + item = stmt.items[0] + name = self._lower_call_name(item.context_expr.func) + if name not in ("strict_vecscope", "vecscope"): + raise _VKernelCompileError("unsupported with context") + if name == "strict_vecscope": + body_lines, body_result = self._compile_strict_vecscope(item, stmt.body, current_env, indent) + else: + body_lines, body_result = self._compile_vecscope(stmt.body, current_env, indent) + lines.extend(body_lines) + current_env.update(body_result) + continue + + if isinstance(stmt, _ast.For): + loop_lines, updated_env = self._compile_for(stmt, current_env, indent) + lines.extend(loop_lines) + current_env = updated_env + continue + + if isinstance(stmt, _ast.If): + if_lines, updated_env = self._compile_if(stmt, current_env, indent) + lines.extend(if_lines) + current_env = updated_env + continue + + raise _VKernelCompileError(f"unsupported statement: {type(stmt).__name__}") + + return lines, current_env + + def _compile_vecscope(self, body, outer_env, indent): + body_lines, _ = 
self._compile_block(body, dict(outer_env), indent + 1) + lines = [] + self._emit(lines, indent, "pto.vecscope {") + lines.extend(body_lines) + self._emit(lines, indent, "}") + return lines, {} + + def _compile_strict_vecscope(self, item, body, outer_env, indent): + if not isinstance(item.optional_vars, _ast.Tuple): + raise _VKernelCompileError("pto.strict_vecscope requires tuple binding in 'as'") + if len(item.context_expr.args) != len(item.optional_vars.elts): + raise _VKernelCompileError("strict_vecscope capture arity must match bound block arguments") + arg_names = [] + inner_env = {} + for elt in item.optional_vars.elts: + if not isinstance(elt, _ast.Name): + raise _VKernelCompileError("pto.strict_vecscope bindings must be names") + arg = self._new_arg_value() + arg_names.append((elt.id, arg)) + inner_env[elt.id] = arg + + for expr, (_, arg) in zip(item.context_expr.args, arg_names): + inferred_type = self._infer_expr_type(expr, outer_env) + if inferred_type is not None: + arg.type = inferred_type + + lines = [] + body_lines, body_env = self._compile_block(body, inner_env, indent + 1) + captures = [] + for name, arg in arg_names: + if arg.type is None and name in body_env and body_env[name].type is not None: + arg.type = body_env[name].type + for expr, (_, arg) in zip(item.context_expr.args, arg_names): + if arg.type is None: + raise _VKernelCompileError("strict_vecscope block argument type could not be inferred") + capture = self._lower_expr(expr, outer_env, lines, indent, expected_type=arg.type) + captures.append(capture) + capture_operands = ", ".join(value.name for value in captures) + block_args = ", ".join(f"{arg.name}: {arg.render_type()}" for _, arg in arg_names) + func_type = ", ".join(arg.render_type() for _, arg in arg_names) + + self._emit(lines, indent, f"pto.strict_vecscope({capture_operands}) {{") + self._emit(lines, indent, f"^bb0({block_args}):") + lines.extend(body_lines) + self._emit(lines, indent, f"}} : ({func_type}) -> ()") + return lines, 
{} + + def _compile_for(self, stmt, outer_env, indent): + if not isinstance(stmt.target, _ast.Name): + raise _VKernelCompileError("for target must be a single name") + if not isinstance(stmt.iter, _ast.Call) or not isinstance(stmt.iter.func, _ast.Name) or stmt.iter.func.id != "range": + raise _VKernelCompileError("only Python range(...) loops are supported") + if len(stmt.iter.args) != 3: + raise _VKernelCompileError("range expects exactly 3 arguments in vkernel") + + lines = [] + lb = self._lower_expr(stmt.iter.args[0], outer_env, lines, indent, _vk_index) + ub = self._lower_expr(stmt.iter.args[1], outer_env, lines, indent, _vk_index) + step = self._lower_expr(stmt.iter.args[2], outer_env, lines, indent, _vk_index) + + loop_env = dict(outer_env) + iv = self._new_arg_value(_vk_index) + loop_env[stmt.target.id] = iv + candidate_carried = [] + for name in self._collect_assigned_names(stmt.body): + if name in outer_env and name != stmt.target.id: + iter_arg = self._new_arg_value(outer_env[name].type) + loop_env[name] = iter_arg + candidate_carried.append((name, outer_env[name], iter_arg)) + + body_lines, body_env = self._compile_block(stmt.body, loop_env, indent + 1) + carried = [] + for name, before, iter_arg in candidate_carried: + after = body_env.get(name) + if after is not None and after is not iter_arg: + carried.append((name, before, after)) + + result_prefix = "" + yield_line = None + if carried: + results = [after.render_type() for _, _, after in carried] + result_value = self._new_value() + result_prefix = f"{result_value.name}:{len(carried)} = " + iter_arg_map = {name: iter_arg for name, _, iter_arg in candidate_carried} + carried_with_initials = [] + for name, before, after in carried: + before = self._materialize_value(before, lines, indent, after.type) + carried_with_initials.append((name, before, after)) + carried = carried_with_initials + iter_args = ", ".join( + f"{iter_arg_map[name].name} = {before.name}" for name, before, _ in carried + ) + 
self._emit( + lines, + indent, + f"{result_prefix}scf.for {iv.name} = {lb.name} to {ub.name} step {step.name} iter_args({iter_args}) -> ({', '.join(results)}) {{", + ) + yield_line = f"scf.yield {', '.join(after.name for _, _, after in carried)} : {', '.join(results)}" + else: + self._emit(lines, indent, f"scf.for {iv.name} = {lb.name} to {ub.name} step {step.name} {{") + lines.extend(body_lines) + if yield_line: + self._emit(lines, indent + 1, yield_line) + self._emit(lines, indent, "}") + + updated_env = dict(outer_env) + if carried: + for idx, (name, _, after) in enumerate(carried): + updated_env[name] = _project_result(result_value, idx, after.type) + return lines, updated_env + + def _compile_if(self, stmt, outer_env, indent): + lines = [] + cond = self._lower_expr(stmt.test, outer_env, lines, indent, i1) + then_lines, then_env = self._compile_block(stmt.body, dict(outer_env), indent + 1) + else_lines, else_env = self._compile_block(stmt.orelse, dict(outer_env), indent + 1) + updated = [] + for name, before in outer_env.items(): + then_val = then_env.get(name, before) + else_val = else_env.get(name, before) + if then_val is not before or else_val is not before: + if not _types_equal(then_val.type, else_val.type): + raise _VKernelCompileError(f"if merge type mismatch for '{name}'") + updated.append((name, then_val, else_val)) + + if updated: + result = self._new_value() + types = ", ".join(val.type.render() for _, val, _ in updated) + self._emit(lines, indent, f"{result.name}:{len(updated)} = scf.if {cond.name} -> ({types}) {{") + lines.extend(then_lines) + self._emit(lines, indent + 1, f"scf.yield {', '.join(val.name for _, val, _ in updated)} : {types}") + self._emit(lines, indent, "} else {") + lines.extend(else_lines) + self._emit(lines, indent + 1, f"scf.yield {', '.join(val.name for _, _, val in updated)} : {types}") + self._emit(lines, indent, "}") + updated_env = dict(outer_env) + for idx, (name, then_val, _) in enumerate(updated): + updated_env[name] = 
_project_result(result, idx, then_val.type) + return lines, updated_env + + self._emit(lines, indent, f"scf.if {cond.name} {{") + lines.extend(then_lines) + self._emit(lines, indent, "} else {") + lines.extend(else_lines) + self._emit(lines, indent, "}") + return lines, dict(outer_env) + + def build_text(self): + lines = [f'module attributes {{pto.target_arch = "{self.target}"}} {{'] + arg_types = [] + env = {} + for arg in self.fn_def.args.args: + arg_ty = _coerce_surface_type(self.py_fn.__annotations__.get(arg.arg)) + if arg_ty is None: + raise _VKernelCompileError(f"missing type annotation for argument '{arg.arg}'") + if not isinstance(arg_ty, _VKernelType): + raise _VKernelCompileError(f"unsupported type annotation for argument '{arg.arg}'") + if isinstance(arg_ty, _VKernelStructDef): + if arg.arg not in self.specialization: + raise _VKernelCompileError( + f"template argument '{arg.arg}: {arg_ty.name}' requires .jit(...) specialization" + ) + binding = self.specialization[arg.arg] + if not isinstance(binding, _VKernelStructBinding) or binding.schema != arg_ty: + raise _VKernelCompileError( + f"specialization for '{arg.arg}' must be a {arg_ty.name}(...) binding" + ) + struct_fields = {} + for field_name, field_kind in arg_ty.fields: + if field_name not in binding.values: + raise _VKernelCompileError( + f"missing field '{field_name}' in specialization for '{arg.arg}'" + ) + field_value = binding.values[field_name] + if field_kind is ptr: + if not isinstance(field_value, _VKernelPtrType): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must be a pto.ptr(...) 
type object" + ) + arg_val = self._new_arg_value(field_value) + arg_types.append(f"{arg_val.name}: {field_value.render()}") + struct_fields[field_name] = arg_val + continue + if field_kind is const: + if not isinstance(field_value, _VKernelConstBinding): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must use pto.const(...)" + ) + static_value = field_value.value + if not isinstance(static_value, (list, tuple)) or not all( + isinstance(v, int) for v in static_value + ): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must be a list/tuple of ints" + ) + struct_fields[field_name] = _VKStaticSequence( + tuple(_VKValue(literal=v) for v in static_value) + ) + continue + raise _VKernelCompileError( + f"unsupported struct field kind for {arg_ty.name}.{field_name}" + ) + env[arg.arg] = _VKStructValue(arg_ty, struct_fields) + continue + arg_val = self._new_arg_value(arg_ty) + arg_types.append(f"{arg_val.name}: {arg_ty.render()}") + env[arg.arg] = arg_val + self._emit(lines, 1, f"func.func @{self.kernel_name}({', '.join(arg_types)}) {{") + body_lines, _ = self._compile_block(self.fn_def.body, env, 2) + lines.extend(body_lines) + if not any(line.strip() == "return" for line in body_lines): + self._emit(lines, 2, "return") + self._emit(lines, 1, "}") + lines.append("}") + return "\n".join(lines) + "\n" + + +class VKernelHandle: + def __init__(self, py_fn, target="a5", name=None, verify=True, specialization=None): + self._py_fn = py_fn + self._target = target + self._name = name or py_fn.__name__ + self._verify = verify + self._specialization = specialization or {} + self._cached_text = None + + def _load_ast(self): + source = _textwrap.dedent(_inspect.getsource(self._py_fn)) + module = _ast.parse(source) + for node in module.body: + if isinstance(node, _ast.FunctionDef) and node.name == self._py_fn.__name__: + return node + raise _VKernelCompileError(f"failed to locate function AST for {self._py_fn.__name__}") + + def mlir_text(self): + if 
self._cached_text is None: + builder = _VKernelBuilder( + self._py_fn, + self._load_ast(), + self._target, + self._name, + specialization=self._specialization, + ) + self._cached_text = builder.build_text() + return self._cached_text + + def mlir_module(self): + with _ods_ir.Context() as ctx: + _load_standard_dialects() + register_dialect(ctx, load=True) + return _ods_ir.Module.parse(self.mlir_text(), ctx) + + def verify(self): + mod = self.mlir_module() + mod.operation.verify() + return True + + def dump(self): + print(self.mlir_text(), end="") + + def emit(self, path): + with open(path, "w", encoding="utf-8") as f: + f.write(self.mlir_text()) + + def jit(self, **kwargs): + return VKernelHandle( + self._py_fn, + target=self._target, + name=self._name, + verify=self._verify, + specialization=kwargs, + ) + + def __str__(self): + return self.mlir_text() + + +def vkernel(py_fn=None, *, target="a5", name=None, verify=True): + def wrap(fn): + return VKernelHandle(fn, target=target, name=name, verify=verify) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +__all__.extend([ + "vkernel", + "VKernelHandle", + "struct", + "Tile", + "tile", + "const", + "ptr", + "vreg", + "i1", "i8", "i16", "i32", "i64", + "f16", "bf16", "f32", + "mask", "align", +]) diff --git a/scripts/batch_compile_output_cpp.sh b/scripts/batch_compile_output_cpp.sh new file mode 100755 index 000000000..13426a9f6 --- /dev/null +++ b/scripts/batch_compile_output_cpp.sh @@ -0,0 +1,464 @@ +#!/usr/bin/env bash + +set -u + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&& pwd)"
+ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}"
+
+DEFAULT_SOURCE_DIR="${PTO_SOURCE_DIR:-${REPO_ROOT}}"
+SRC_ROOT="${PTOAS_OUT_DIR:-${DEFAULT_SOURCE_DIR}/build/output}"
+BUILD_ROOT="${DEFAULT_SOURCE_DIR}/build/output_asm"
+LOG_DIR="${DEFAULT_SOURCE_DIR}/build/output_log"
+
+COMPILER="${COMPILER:-}"
+PTO_ISA_PATH="${PTO_ISA_PATH:-${PTO_ISA_ROOT:-}}"
+EXTRA_ARGS=()
+
+JOBS="${JOBS:-$(nproc)}"
+AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}"
+MEM_BASE_DEFINE="${MEM_BASE_DEFINE:-REGISTER_BASE}"
+ENABLE_DEFAULT_ARGS=1
+
+print_usage() {
+  cat <<'EOF'
+批量编译 output 目录下所有 .cpp 文件为 .S,并汇总结果。
+
+用法:
+  scripts/batch_compile_output_cpp.sh \
+    [--compiler <编译器路径>] \
+    [--pto-isa-path <PTO-ISA根路径>] \
+    [--compile-arg <单个参数>]... \
+    [--jobs <并行数>] \
+    [--aicore-arch <AICore架构>] \
+    [--mem-base-define <宏名>] \
+    [--src-root <源码目录>] \
+    [--build-root <产物目录>] \
+    [--log-dir <日志目录>]
+
+参数说明:
+  --compiler, -c       编译器路径。默认优先使用环境变量 COMPILER,
+                       其次使用 PATH 中的 bisheng 或
+                       ${ASCEND_HOME_PATH}/bin/bisheng
+  --pto-isa-path, -p   PTO-ISA 根路径。默认优先使用环境变量
+                       PTO_ISA_PATH / PTO_ISA_ROOT。脚本会自动检测 include 目录:
+                       1) <PTO-ISA根路径>/include
+                       2) <PTO-ISA根路径>/tests/common (存在时自动加入)
+                       3) <PTO-ISA根路径>
+  --compile-arg        额外编译参数,可重复传入
+  --jobs, -j           并行编译任务数,默认: nproc
+  --aicore-arch        默认: dav-c310-vec
+  --mem-base-define    默认: REGISTER_BASE (可改为 MEMORY_BASE)
+  --no-default-args    不使用脚本内置默认参数(仅使用 --compile-arg)
+  --src-root           要扫描的 .cpp 根目录,默认: $PTOAS_OUT_DIR
+                       或 $PTO_SOURCE_DIR/build/output
+  --build-root         .S 产物目录,默认: $PTO_SOURCE_DIR/build/output_asm
+  --log-dir            编译日志目录,默认: $PTO_SOURCE_DIR/build/output_log
+  --help, -h           显示帮助
+
+推荐先执行:
+  source scripts/ptoas_env.sh
+
+默认编译参数来源:
+  由 test/npu_validation/scripts/generate_testcase.py 中
+  CMAKE_CCE_COMPILE_OPTIONS + target_compile_options() 提取:
+  -xcce -fenable-matrix --cce-aicore-enable-tl --cce-aicore-only -fPIC -Xhost-start -Xhost-end
+  -mllvm -cce-aicore-stack-size=0x8000
+  -mllvm -cce-aicore-function-stack-size=0x8000
+  -mllvm -cce-aicore-record-overflow=true
+  -mllvm -cce-aicore-addr-transform
+  -mllvm -cce-aicore-dcci-insert-for-scalar=false
+  --cce-aicore-arch=<架构> -D<宏名> -std=c++17
+EOF
+}
+
+die() {
+  echo "[ERROR] $*"
>&2 + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --compiler | -c) + [[ $# -ge 2 ]] || die "--compiler 缺少参数" + COMPILER="$2" + shift 2 + ;; + --pto-isa-path | -p) + [[ $# -ge 2 ]] || die "--pto-isa-path 缺少参数" + PTO_ISA_PATH="$2" + shift 2 + ;; + --compile-arg) + [[ $# -ge 2 ]] || die "--compile-arg 缺少参数" + EXTRA_ARGS+=("$2") + shift 2 + ;; + --jobs | -j) + [[ $# -ge 2 ]] || die "--jobs 缺少参数" + JOBS="$2" + shift 2 + ;; + --aicore-arch) + [[ $# -ge 2 ]] || die "--aicore-arch 缺少参数" + AICORE_ARCH="$2" + shift 2 + ;; + --mem-base-define) + [[ $# -ge 2 ]] || die "--mem-base-define 缺少参数" + MEM_BASE_DEFINE="$2" + shift 2 + ;; + --no-default-args) + ENABLE_DEFAULT_ARGS=0 + shift + ;; + --src-root) + [[ $# -ge 2 ]] || die "--src-root 缺少参数" + SRC_ROOT="$2" + shift 2 + ;; + --build-root) + [[ $# -ge 2 ]] || die "--build-root 缺少参数" + BUILD_ROOT="$2" + shift 2 + ;; + --log-dir) + [[ $# -ge 2 ]] || die "--log-dir 缺少参数" + LOG_DIR="$2" + shift 2 + ;; + --help | -h) + print_usage + exit 0 + ;; + *) + die "未知参数: $1 (使用 --help 查看用法)" + ;; + esac +done + +if [[ -z "${COMPILER}" ]]; then + if command -v bisheng >/dev/null 2>&1; then + COMPILER="$(command -v bisheng)" + elif [[ -n "${ASCEND_HOME_PATH:-}" && -x "${ASCEND_HOME_PATH}/bin/bisheng" ]]; then + COMPILER="${ASCEND_HOME_PATH}/bin/bisheng" + fi +elif [[ "${COMPILER}" != */* ]] && command -v "${COMPILER}" >/dev/null 2>&1; then + COMPILER="$(command -v "${COMPILER}")" +fi + +[[ -n "${COMPILER}" ]] || die "未找到编译器,请先 source scripts/ptoas_env.sh,或通过 --compiler/COMPILER 指定 bisheng 路径" +[[ -n "${PTO_ISA_PATH}" ]] || die "未找到 PTO-ISA 路径,请通过 --pto-isa-path、PTO_ISA_PATH 或 PTO_ISA_ROOT 指定" +[[ -x "${COMPILER}" ]] || die "编译器不可执行: ${COMPILER}" +[[ -d "${SRC_ROOT}" ]] || die "源码目录不存在: ${SRC_ROOT}" +[[ -d "${PTO_ISA_PATH}" ]] || die "PTO-ISA 路径不存在: ${PTO_ISA_PATH}" +[[ "${JOBS}" =~ ^[1-9][0-9]*$ ]] || die "--jobs 必须为正整数" + +if [[ -z "${LOG_DIR}" ]]; then + LOG_DIR="${BUILD_ROOT}/logs" +fi + +mkdir -p "${BUILD_ROOT}" "${LOG_DIR}" || 
die "创建目录失败" + +INCLUDE_DIRS=() +if [[ -f "${PTO_ISA_PATH}/include/pto/pto-inst.hpp" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}/include") +fi +if [[ -d "${PTO_ISA_PATH}/tests/common" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}/tests/common") +fi +if [[ -f "${PTO_ISA_PATH}/pto/pto-inst.hpp" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}") +fi +[[ ${#INCLUDE_DIRS[@]} -gt 0 ]] || die "未找到 pto/pto-inst.hpp,请检查 --pto-isa-path" + +if [[ -n "${ASCEND_HOME_PATH:-}" && -d "${ASCEND_HOME_PATH}/include" ]]; then + INCLUDE_DIRS+=("${ASCEND_HOME_PATH}/include") +fi +ASCEND_DRIVER_PATH="${ASCEND_DRIVER_PATH:-/usr/local/Ascend/driver}" +if [[ -d "${ASCEND_DRIVER_PATH}/kernel/inc" ]]; then + INCLUDE_DIRS+=("${ASCEND_DRIVER_PATH}/kernel/inc") +fi + +DEFAULT_ARGS=() +if [[ ${ENABLE_DEFAULT_ARGS} -eq 1 ]]; then + DEFAULT_ARGS=( + "-xcce" + "-fenable-matrix" + "--cce-aicore-enable-tl" + "--cce-aicore-only" + "-fPIC" + "-Xhost-start" + "-Xhost-end" + "-mllvm" "-cce-aicore-stack-size=0x8000" + "-mllvm" "-cce-aicore-function-stack-size=0x8000" + "-mllvm" "-cce-aicore-record-overflow=true" + "-mllvm" "-cce-aicore-addr-transform" + "-mllvm" "-cce-aicore-dcci-insert-for-scalar=false" + "--cce-aicore-arch=${AICORE_ARCH}" + "-D${MEM_BASE_DEFINE}" + "-std=c++17" + ) + if [[ "${AICORE_ARCH}" == dav-l310* || "${AICORE_ARCH}" == dav-l311* ]]; then + FILTERED_DEFAULT_ARGS=() + i=0 + while [[ ${i} -lt ${#DEFAULT_ARGS[@]} ]]; do + if [[ "${DEFAULT_ARGS[${i}]}" == "-mllvm" ]] && [[ $((i + 1)) -lt ${#DEFAULT_ARGS[@]} ]] && + [[ "${DEFAULT_ARGS[$((i + 1))]}" == "-cce-aicore-stack-size=0x8000" ]]; then + i=$((i + 2)) + continue + fi + FILTERED_DEFAULT_ARGS+=("${DEFAULT_ARGS[${i}]}") + i=$((i + 1)) + done + DEFAULT_ARGS=("${FILTERED_DEFAULT_ARGS[@]}") + fi +fi + +declare -a CPP_FILES=() +while IFS= read -r -d '' file; do + CPP_FILES+=("${file}") +done < <(find "${SRC_ROOT}" -type f -name "*.cpp" -print0 | sort -z) + +TOTAL_COUNT=${#CPP_FILES[@]} +[[ ${TOTAL_COUNT} -gt 0 ]] || die "未在 ${SRC_ROOT} 下找到 .cpp 文件" 
+ +STATUS_FILE="$(mktemp "${BUILD_ROOT}/compile_status.XXXXXX")" || die "创建状态文件失败" +trap 'rm -f "${STATUS_FILE}"' EXIT + +record_compile_status() { + local status="$1" + local rel_path="$2" + printf '%s\t%s\n' "${status}" "${rel_path}" >>"${STATUS_FILE}" +} + +cleanup_work_dir() { + local work_dir="$1" + [[ -n "${work_dir}" ]] && rm -rf -- "${work_dir}" +} + +get_log_failure_reason() { + local log_path="$1" + local excerpt + + excerpt="$(grep -E -i 'error:|fatal:|undefined reference|undefined symbol|undeclared identifier|exception|traceback|failed' "${log_path}" | tail -n 5 || true)" + if [[ -z "${excerpt}" ]]; then + excerpt="$(tail -n 10 "${log_path}" 2>/dev/null || true)" + fi + printf '%s' "${excerpt}" +} + +find_generated_output() { + local work_dir="$1" + local src_stem="$2" + local candidate + + for candidate in \ + "${work_dir}/${src_stem}.o" \ + "${work_dir}/${src_stem}.S" \ + "${work_dir}/${src_stem}.s"; do + if [[ -f "${candidate}" ]]; then + printf '%s\n' "${candidate}" + return 0 + fi + done + + find "${work_dir}" -maxdepth 1 -type f \( -name "*.o" -o -name "*.S" -o -name "*.s" \) | head -n 1 +} + +write_rebuild_cmd() { + local cmd_path="$1" + local asm_path="$2" + local src_stem="$3" + shift 3 + local -a cmd=("$@") + local cmd_text="" + local arg + + for arg in "${cmd[@]}"; do + printf -v cmd_text '%s %q' "${cmd_text}" "${arg}" + done + cmd_text="${cmd_text# }" + + { + echo "#!/usr/bin/env bash" + echo + echo "set -euo pipefail" + echo + printf 'ASM_PATH=%q\n' "${asm_path}" + printf 'SRC_STEM=%q\n' "${src_stem}" + printf 'WORK_ROOT=%q\n' "${BUILD_ROOT}" + echo + echo 'WORK_DIR="$(mktemp -d "${WORK_ROOT}/tmp_rebuild.XXXXXX")"' + echo 'trap '\''rm -rf -- "${WORK_DIR}"'\'' EXIT' + echo + echo 'cd "${WORK_DIR}"' + echo "${cmd_text}" + echo + echo 'GENERATED_FILE=""' + echo 'for candidate in "${WORK_DIR}/${SRC_STEM}.o" "${WORK_DIR}/${SRC_STEM}.S" "${WORK_DIR}/${SRC_STEM}.s"; do' + echo ' if [[ -f "${candidate}" ]]; then' + echo ' 
GENERATED_FILE="${candidate}"' + echo ' break' + echo ' fi' + echo 'done' + echo + echo 'if [[ -z "${GENERATED_FILE}" ]]; then' + echo ' GENERATED_FILE="$(find "${WORK_DIR}" -maxdepth 1 -type f \( -name "*.o" -o -name "*.S" -o -name "*.s" \) | head -n 1)"' + echo 'fi' + echo + echo 'if [[ -z "${GENERATED_FILE}" || ! -f "${GENERATED_FILE}" ]]; then' + echo ' echo "[ERROR] 编译成功但未找到输出文件,期望类型: .o/.S/.s" >&2' + echo ' exit 1' + echo 'fi' + echo + echo 'mkdir -p "$(dirname -- "${ASM_PATH}")"' + echo 'mv -f -- "${GENERATED_FILE}" "${ASM_PATH}"' + printf 'echo "已更新: %s"\n' "${asm_path}" + } >"${cmd_path}" || return 1 + + chmod +x "${cmd_path}" +} + +compile_one() { + local src="$1" + local rel_path asm_path log_path cmd_path src_base src_stem work_dir generated_file + local -a cmd=() + + rel_path="${src#"${SRC_ROOT}/"}" + asm_path="${BUILD_ROOT}/${rel_path%.cpp}.S" + log_path="${LOG_DIR}/${rel_path%.cpp}.log" + cmd_path="$(dirname -- "${log_path}")/cmd.sh" + src_base="$(basename -- "${src}")" + src_stem="${src_base%.cpp}" + + mkdir -p "$(dirname -- "${asm_path}")" "$(dirname -- "${log_path}")" || { + record_compile_status "FAIL" "${rel_path}" + return 0 + } + + cmd=("${COMPILER}") + if [[ ${#DEFAULT_ARGS[@]} -gt 0 ]]; then + cmd+=("${DEFAULT_ARGS[@]}") + fi + if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then + cmd+=("${EXTRA_ARGS[@]}") + fi + local inc + for inc in "${INCLUDE_DIRS[@]}"; do + cmd+=("-I${inc}") + done + cmd+=("-c" "${src}") + + if ! write_rebuild_cmd "${cmd_path}" "${asm_path}" "${src_stem}" "${cmd[@]}"; then + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + echo "[BUILD] ${rel_path}" + work_dir="$(mktemp -d "${BUILD_ROOT}/tmp_compile.XXXXXX")" || { + record_compile_status "FAIL" "${rel_path}" + return 0 + } + + if ! 
(cd "${work_dir}" && "${cmd[@]}") >"${log_path}" 2>&1; then + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + generated_file="$(find_generated_output "${work_dir}" "${src_stem}")" + + if [[ -z "${generated_file}" || ! -f "${generated_file}" ]]; then + { + echo + echo "[ERROR] 编译成功但未找到输出文件,期望类型: .o/.S/.s" + echo "[ERROR] 临时目录: ${work_dir}" + } >>"${log_path}" + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + if mv -f -- "${generated_file}" "${asm_path}"; then + cleanup_work_dir "${work_dir}" + record_compile_status "OK" "${rel_path}" + else + { + echo + echo "[ERROR] 输出重命名失败: ${generated_file} -> ${asm_path}" + } >>"${log_path}" + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + fi +} + +START_TIME="$(date +%s)" + +echo "[INFO] 编译器: ${COMPILER}" +echo "[INFO] 源目录: ${SRC_ROOT}" +echo "[INFO] 产物目录(.S): ${BUILD_ROOT}" +echo "[INFO] 日志目录: ${LOG_DIR}" +echo "[INFO] PTO-ISA: ${PTO_ISA_PATH}" +echo "[INFO] 并行度: ${JOBS}" +echo "[INFO] include: ${INCLUDE_DIRS[*]}" +if [[ ${ENABLE_DEFAULT_ARGS} -eq 1 ]]; then + echo "[INFO] 默认参数(来自 generate_testcase.py): ${DEFAULT_ARGS[*]}" +else + echo "[INFO] 默认参数: 已禁用 (--no-default-args)" +fi +if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then + echo "[INFO] 额外参数: ${EXTRA_ARGS[*]}" +fi +echo "[INFO] 文件总数: ${TOTAL_COUNT}" +echo + +running_jobs=0 +for src in "${CPP_FILES[@]}"; do + compile_one "${src}" & + running_jobs=$((running_jobs + 1)) + if [[ ${running_jobs} -ge ${JOBS} ]]; then + wait -n + running_jobs=$((running_jobs - 1)) + fi +done + +wait + +SUCCESS_COUNT="$(awk -F'\t' '$1=="OK"{c++} END{print c+0}' "${STATUS_FILE}")" +FAIL_COUNT="$(awk -F'\t' '$1=="FAIL"{c++} END{print c+0}' "${STATUS_FILE}")" + +declare -a FAILED_FILES=() +while IFS= read -r failed; do + [[ -n "${failed}" ]] && FAILED_FILES+=("${failed}") +done < <(awk -F'\t' '$1=="FAIL"{print $2}' "${STATUS_FILE}") + +END_TIME="$(date +%s)" +ELAPSED="$((END_TIME - 
START_TIME))" + +echo +echo "========== 编译汇总 ==========" +echo "总文件数 : ${TOTAL_COUNT}" +echo "成功数 : ${SUCCESS_COUNT}" +echo "失败数 : ${FAIL_COUNT}" +echo "耗时(秒) : ${ELAPSED}" + +if [[ ${FAIL_COUNT} -gt 0 ]]; then + failure_reason="" + echo + echo "失败文件列表:" + for f in "${FAILED_FILES[@]}"; do + echo " - ${f} (log: ${LOG_DIR}/${f%.cpp}.log)" + failure_reason="$(get_log_failure_reason "${LOG_DIR}/${f%.cpp}.log")" + if [[ -n "${failure_reason}" ]]; then + while IFS= read -r line; do + [[ -n "${line}" ]] || continue + echo " reason: ${line}" + done <<<"${failure_reason}" + fi + done + exit 1 +fi + +echo "[INFO] 全部编译成功" +exit 0 diff --git a/scripts/compile_pto_to_vpto_llvm.sh b/scripts/compile_pto_to_vpto_llvm.sh new file mode 100755 index 000000000..2d15c86b6 --- /dev/null +++ b/scripts/compile_pto_to_vpto_llvm.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PTO_FILE="${1:-}" +OUT_DIR_ARG="${2:-}" + +PTOAS_BIN="${PTOAS_BIN:-${ROOT_DIR}/build/tools/ptoas/ptoas}" +PTOAS_FLAGS="${PTOAS_FLAGS:---pto-arch a5}" +VPTO_FLAGS="${VPTO_FLAGS:---pto-backend=vpto --vpto-emit-hivm-llvm}" +AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" +BISHENG_BIN="" +BISHENG_FLAGS="${BISHENG_FLAGS:-}" +LLVM_IR="" +DEVICE_OBJ="" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +on_error() { + local exit_code="$1" + if [[ -n "${LLVM_IR}" && -f "${LLVM_IR}" ]]; then + echo "Retained LLVM IR: ${LLVM_IR}" >&2 + fi + if [[ -n "${DEVICE_OBJ}" ]]; then + echo "Expected device object: ${DEVICE_OBJ}" >&2 + fi + exit "${exit_code}" +} + +trap 'on_error $?' 
 ERR

+usage() {
+  cat <<EOF
+Usage: $(basename "$0") <input.pto> [output_dir]
+
+Environment overrides:
+  PTOAS_BIN          path to ptoas
+  PTOAS_FLAGS        default: --pto-arch a5
+  VPTO_FLAGS         default: --pto-backend=vpto --vpto-emit-hivm-llvm
+  ASCEND_HOME_PATH   default: \$HOME/cann
+  BISHENG_BIN        path to bisheng (default: \$ASCEND_HOME_PATH/bin/bisheng)
+  BISHENG_FLAGS      extra flags passed to bisheng when compiling .ll to .o
+  AICORE_ARCH        default: dav-c310-vec
+
+Example:
+  $(basename "$0") test/samples/PyPTOIRParser/paged_attention_example_kernel_online_update.pto
+EOF
+}
+
+[[ -n "${PTO_FILE}" ]] || {
+  usage
+  exit 1
+}
+
+[[ "${PTO_FILE}" == *.pto ]] || die "input must be a .pto file: ${PTO_FILE}"
+[[ -f "${PTO_FILE}" ]] || die "missing input file: ${PTO_FILE}"
+
+set +u
+source "${ROOT_DIR}/scripts/ptoas_env.sh"
+set -u
+
+if [[ -n "${ASCEND_HOME_PATH}" && -f "${ASCEND_HOME_PATH}/set_env.sh" ]]; then
+  set +u
+  source "${ASCEND_HOME_PATH}/set_env.sh" >/dev/null 2>&1
+  set -u
+fi
+
+BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}"
+
+[[ -x "${PTOAS_BIN}" ]] || die "PTOAS_BIN is not executable: ${PTOAS_BIN}"
+command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}"
+
+pto_abs="$(cd "$(dirname "${PTO_FILE}")" && pwd)/$(basename "${PTO_FILE}")"
+pto_base="$(basename "${PTO_FILE}" .pto)"
+
+if [[ -n "${OUT_DIR_ARG}" ]]; then
+  OUT_DIR="${OUT_DIR_ARG}"
+else
+  OUT_DIR="${ROOT_DIR}/build/vpto_quick/${pto_base}"
+fi
+
+mkdir -p "${OUT_DIR}"
+OUT_DIR="$(cd "${OUT_DIR}" && pwd)"
+
+LLVM_IR="${OUT_DIR}/${pto_base}.ll"
+DEVICE_OBJ="${OUT_DIR}/${pto_base}.o"
+
+log "step 1/2: lower PTO to VPTO LLVM IR"
+"${PTOAS_BIN}" ${PTOAS_FLAGS} ${VPTO_FLAGS} \
+  "${pto_abs}" \
+  -o "${LLVM_IR}"
+
+log "step 2/2: compile LLVM IR to device object"
+"${BISHENG_BIN}" \
+  --target=hiipu64-hisilicon-cce \
+  -march="${AICORE_ARCH}" \
+  --cce-aicore-arch="${AICORE_ARCH}" \
+  --cce-aicore-only \
+  ${BISHENG_FLAGS} \
+  -c -x ir "${LLVM_IR}" \
+  -o "${DEVICE_OBJ}"
+
+log "done"
+echo "LLVM IR: ${LLVM_IR}"
+echo "Device object: ${DEVICE_OBJ}"
diff --git 
a/scripts/ptoas_env.sh b/scripts/ptoas_env.sh new file mode 100644 index 000000000..95dcd9a8d --- /dev/null +++ b/scripts/ptoas_env.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# PTOAS runtime environment bootstrap. +# Usage: +# source scripts/ptoas_env.sh +# +# Optional overrides before sourcing: +# export WORKSPACE_DIR=/path/to/workspace +# export LLVM_BUILD_DIR=/path/to/llvm-project/build-shared +# export PTO_SOURCE_DIR=/path/to/PTOAS +# export PTO_INSTALL_DIR=/path/to/PTOAS/install +# export PTO_PYTHON_BIN=/path/to/python3 + +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + echo "This script must be sourced: source scripts/ptoas_env.sh" + exit 1 +fi + +_PTOAS_ENV_SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +_PTOAS_REPO_DIR="$(cd -- "${_PTOAS_ENV_SCRIPT_DIR}/.." && pwd)" + +# Default layout: +# / +# ├── PTOAS/ +# └── llvm-project/ +export PTO_SOURCE_DIR="${PTO_SOURCE_DIR:-${_PTOAS_REPO_DIR}}" +export WORKSPACE_DIR="${WORKSPACE_DIR:-$(cd -- "${PTO_SOURCE_DIR}/.." && pwd)}" +export LLVM_SOURCE_DIR="${LLVM_SOURCE_DIR:-${WORKSPACE_DIR}/llvm-project}" +export LLVM_BUILD_DIR="${LLVM_BUILD_DIR:-${LLVM_SOURCE_DIR}/build-shared}" +export PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" +export PTO_ISA_PATH="${PTO_ISA_PATH:-${WORKSPACE_DIR}/pto-isa}" +export ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" + +export MLIR_PYTHON_ROOT="${MLIR_PYTHON_ROOT:-${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core}" +export PTO_PYTHON_ROOT="${PTO_PYTHON_ROOT:-${PTO_INSTALL_DIR}}" +export PTO_PYTHON_BUILD_ROOT="${PTO_PYTHON_BUILD_ROOT:-${PTO_SOURCE_DIR}/build/python}" +export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) +export PTOAS_FLAGS="${PTOAS_FLAGS:-}" +export PTOAS_OUT_DIR=$PTO_SOURCE_DIR/build/output + +_ptoas_prepend_path() { + local var_name="$1" + local value="$2" + local current="${!var_name:-}" + if [[ -z "${value}" ]]; then + return 0 + fi + if [[ ! 
-e "${value}" ]]; then + return 0 + fi + if [[ ":${current}:" == *":${value}:"* ]]; then + return 0 + fi + if [[ -z "${current}" ]]; then + printf -v "${var_name}" '%s' "${value}" + else + printf -v "${var_name}" '%s:%s' "${value}" "${current}" + fi + export "${var_name}" +} + +_ptoas_prepend_path PYTHONPATH "${MLIR_PYTHON_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_BUILD_ROOT}" + +_ptoas_prepend_path LD_LIBRARY_PATH "${LLVM_BUILD_DIR}/lib" +_ptoas_prepend_path LD_LIBRARY_PATH "${PTO_INSTALL_DIR}/lib" +_ptoas_prepend_path LD_LIBRARY_PATH "${PTO_SOURCE_DIR}/build/lib" + +_ptoas_prepend_path PATH "${PTO_SOURCE_DIR}/build/tools/ptoas" + +if [[ -n "${PTO_PYTHON_BIN:-}" && -x "${PTO_PYTHON_BIN}" ]]; then + alias ptoas-python="${PTO_PYTHON_BIN}" +fi + +echo "[ptoas_env] PTO_SOURCE_DIR=${PTO_SOURCE_DIR}" +echo "[ptoas_env] LLVM_BUILD_DIR=${LLVM_BUILD_DIR}" +echo "[ptoas_env] PTO_INSTALL_DIR=${PTO_INSTALL_DIR}" +echo "[ptoas_env] PTO_ISA_PATH=${PTO_ISA_PATH}" +echo "[ptoas_env] ASCEND_HOME_PATH=${ASCEND_HOME_PATH}" +echo "[ptoas_env] PATH/PYTHONPATH/LD_LIBRARY_PATH updated" + +unset _PTOAS_ENV_SCRIPT_DIR +unset _PTOAS_REPO_DIR diff --git a/test/dsl/abs.py b/test/dsl/abs.py new file mode 100644 index 000000000..7c67e5959 --- /dev/null +++ b/test/dsl/abs.py @@ -0,0 +1,34 @@ +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="abs_kernel_2d") +def abs_kernel_2d(inp: pto.ptr(pto.f32, "gm"), out: pto.ptr(pto.f32, "gm")): + ub_in = pto.castptr(0, pto.ptr(pto.f32, "ub")) + ub_out = pto.castptr(4096, pto.ptr(pto.f32, "ub")) + + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(inp, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + pto.wait_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + + with pto.vecscope(): + remaining: pto.i32 = 1024 + for offset in range(0, 1024, 64): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(ub_in, offset) + 
vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, ub_out, offset, mask) + + pto.set_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + pto.wait_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_gm(ub_out, out, 0, 32, 128, 0, 128, 128) + pto.barrier("PIPE_ALL") + + return + + +if __name__ == "__main__": + print(abs_kernel_2d.mlir_text(), end="") diff --git a/test/dsl/strict_vecscope.py b/test/dsl/strict_vecscope.py new file mode 100644 index 000000000..e882df3d8 --- /dev/null +++ b/test/dsl/strict_vecscope.py @@ -0,0 +1,42 @@ +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="abs_strict_vecscope_kernel_2d") +def abs_strict_vecscope_kernel_2d( + inp: pto.ptr(pto.f32, "gm"), out: pto.ptr(pto.f32, "gm") +): + ub_in = pto.castptr(0, pto.ptr(pto.f32, "ub")) + ub_out = pto.castptr(4096, pto.ptr(pto.f32, "ub")) + + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(inp, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + pto.wait_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + + with pto.strict_vecscope(ub_in, ub_out, 0, 1024, 64, 1024) as ( + src, + dst, + lb, + ub, + step, + remaining, + ): + for offset in range(lb, ub, step): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(src, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, dst, offset, mask) + + pto.set_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + pto.wait_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_gm(ub_out, out, 0, 32, 128, 0, 128, 128) + pto.barrier("PIPE_ALL") + + return + + +if __name__ == "__main__": + print(abs_strict_vecscope_kernel_2d.mlir_text(), end="") diff --git a/test/dsl/template_abs.py b/test/dsl/template_abs.py new file mode 100644 index 000000000..87b330e32 --- /dev/null +++ b/test/dsl/template_abs.py @@ -0,0 +1,48 @@ +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="template_abs_kernel") +def 
template_abs_kernel(src: pto.Tile, dst: pto.Tile): + total = src.shape[0] * src.shape[1] + step = 256 // src.ub_ptr.elem_bytes + + with pto.strict_vecscope(src.ub_ptr, dst.ub_ptr, 0, total, step, total) as ( + vin, + vout, + lb, + ub, + vec_step, + remaining, + ): + for offset in range(lb, ub, vec_step): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(vin, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, vout, offset, mask) + + +template_abs_kernel_f32 = template_abs_kernel.jit( + src=pto.Tile( + ub_ptr=pto.ptr(pto.f32, "ub"), + shape=pto.const([32, 32]), + ), + dst=pto.Tile( + ub_ptr=pto.ptr(pto.f32, "ub"), + shape=pto.const([32, 32]), + ), +) + +template_abs_kernel_f16 = template_abs_kernel.jit( + src=pto.Tile( + ub_ptr=pto.ptr(pto.f16, "ub"), + shape=pto.const([32, 32]), + ), + dst=pto.Tile( + ub_ptr=pto.ptr(pto.f16, "ub"), + shape=pto.const([32, 32]), + ), +) + + +if __name__ == "__main__": + print(template_abs_kernel_f32.mlir_text(), end="") diff --git a/test/lit.cfg.py b/test/lit.cfg.py new file mode 100644 index 000000000..95e17569a --- /dev/null +++ b/test/lit.cfg.py @@ -0,0 +1,85 @@ +import os +import lit.formats + +config.name = "PTOAS" +config.test_format = lit.formats.ShTest(execute_external=True) + +# Keep discovery focused on lit-style tests. 
+config.suffixes = [".mlir", ".pto"] +config.excludes = [ + "CMakeLists.txt", + "README.md", + "lit.cfg.py", + "resources", +] + +config.test_source_root = os.path.dirname(__file__) + + +def _resolve_build_root(): + env_build_dir = os.environ.get("PTOAS_BUILD_DIR") + if env_build_dir: + return os.path.abspath(env_build_dir) + + repo_root = os.path.abspath(os.path.join(config.test_source_root, "..")) + return os.path.join(repo_root, "build") + + +build_root = _resolve_build_root() +config.test_exec_root = os.path.join(build_root, "test") +os.makedirs(config.test_exec_root, exist_ok=True) + + +def _resolve_llvm_bin_dir(): + env_build_dir = os.environ.get("LLVM_BUILD_DIR") + candidates = [] + if env_build_dir: + candidates.append(os.path.join(os.path.abspath(env_build_dir), "bin")) + + repo_root = os.path.abspath(os.path.join(config.test_source_root, "..")) + candidates.append( + os.path.abspath( + os.path.join(repo_root, "..", "llvm-project", "build-shared", "bin") + ) + ) + + for candidate in candidates: + if os.path.isdir(candidate): + return candidate + return "" + + +def _resolve_ptoas_bin(): + env_bin = os.environ.get("PTOAS_BIN") + if env_bin: + return env_bin + + candidate = os.path.join(build_root, "tools", "ptoas", "ptoas") + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + + return "ptoas" + + +def _prepend_path(path_var, entry): + if not entry: + return path_var + if not path_var: + return entry + return entry + os.pathsep + path_var + + +ptoas_bin = _resolve_ptoas_bin() +ptoas_dir = os.path.dirname(ptoas_bin) if os.path.isabs(ptoas_bin) else "" +llvm_bin_dir = _resolve_llvm_bin_dir() + +path_env = config.environment.get("PATH", os.environ.get("PATH", "")) +if llvm_bin_dir: + path_env = _prepend_path(path_env, llvm_bin_dir) +if ptoas_dir: + path_env = _prepend_path(path_env, ptoas_dir) +config.environment["PATH"] = path_env + +# Keep RUN lines using bare `ptoas` stable regardless of shell cwd. 
+if os.path.isabs(ptoas_bin): + config.substitutions.append(("ptoas", ptoas_bin)) diff --git a/test/vpto/cases/kernels/online-softmax-update/compare.py b/test/vpto/cases/kernels/online-softmax-update/compare.py new file mode 100644 index 000000000..e6af92b4a --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# case: kernels/online-softmax-update +# family: kernels +# target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# scenarios: online-softmax-update, 16x128-f32, oldmax-oldsum-qk-to-newmax-newsum-expmax-out + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(abs_diff)) + print( + f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " + f"(golden={float(golden[idx])}, out={float(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + ok = True + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v5.bin", "v5.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v6.bin", "v6.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v7.bin", "v7.bin", np.float32, 1e-4) and ok + if not ok: + print("[ERROR] compare failed") + 
sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/golden.py b/test/vpto/cases/kernels/online-softmax-update/golden.py new file mode 100644 index 000000000..ea41425eb --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# case: kernels/online-softmax-update +# family: kernels +# target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 24 +COLS = 128 +SEED = 19 +SEQ = 73 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + seq = SEQ + oldmax = rng.uniform(-3.0, 1.5, size=(ROWS,)).astype(np.float32) + oldsum = rng.uniform(0.5, 4.0, size=(ROWS,)).astype(np.float32) + qk = rng.normal(loc=0.0, scale=1.5, size=(ROWS, COLS)).astype(np.float32) + + qk_active = qk[:, :seq] + qk_rowmax = np.max(qk_active, axis=1) + newmax = np.maximum(qk_rowmax, oldmax) + tmp_active = np.exp(qk_active - newmax[:, None], dtype=np.float32) + cursum = np.sum(tmp_active, axis=1, dtype=np.float32) + raw_expmax = np.exp(oldmax - newmax, dtype=np.float32) + newsum = raw_expmax * oldsum + cursum + expmax = (raw_expmax * oldsum) / newsum + out = np.zeros((ROWS, COLS), dtype=np.float32) + out[:, :seq] = tmp_active / newsum[:, None] + + zeros_state = np.zeros((ROWS,), dtype=np.float32) + zeros_out = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + oldmax.tofile(output_dir / "v1.bin") + oldsum.tofile(output_dir / "v2.bin") + qk.reshape(-1).tofile(output_dir / "v3.bin") + zeros_state.tofile(output_dir / "v4.bin") + 
zeros_state.tofile(output_dir / "v5.bin") + zeros_state.tofile(output_dir / "v6.bin") + zeros_out.reshape(-1).tofile(output_dir / "v7.bin") + np.array([seq], dtype=np.int32).tofile(output_dir / "v8.bin") + np.array([ROWS], dtype=np.int32).tofile(output_dir / "v9.bin") + newmax.tofile(output_dir / "golden_v4.bin") + newsum.tofile(output_dir / "golden_v5.bin") + expmax.tofile(output_dir / "golden_v6.bin") + out.astype(np.float32, copy=False).reshape(-1).tofile(output_dir / "golden_v7.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto new file mode 100644 index 000000000..9d49bc6cb --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -0,0 +1,164 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5"} { + func.func @online_softmax_update_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr, + %arg4: !pto.ptr, + %arg5: !pto.ptr, + %arg6: !pto.ptr, + %arg7: i32, + %arg8: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = 
arith.constant 128 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c8448_i64 = arith.constant 8448 : i64 + %c16640_i64 = arith.constant 16640 : i64 + %c16768_i64 = arith.constant 16768 : i64 + %c16896_i64 = arith.constant 16896 : i64 + + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %false = arith.constant false + + %block = pto.get_block_idx + %block_idx = arith.index_cast %block : i64 to index + %row_base = arith.muli %block_idx, %c8 : index + %qk_base = arith.muli %row_base, %c128 : index + %block_rows_i32 = arith.index_cast %c8 : index to i32 + %row_base_i32 = arith.index_cast %row_base : index to i32 + %remaining_rows = arith.subi %arg8, %row_base_i32 : i32 + %has_rows = arith.cmpi sgt, %remaining_rows, %c0_i32 : i32 + %too_many_rows = arith.cmpi sgt, %remaining_rows, %c8_i32 : i32 + %row_count_i32 = arith.select %too_many_rows, %c8_i32, %remaining_rows : i32 + %row_count = arith.index_cast %row_count_i32 : i32 to index + %row_count_i64 = arith.extui %row_count_i32 : i32 to i64 + %gm_oldmax = pto.addptr %arg0, %row_base : !pto.ptr -> !pto.ptr + %gm_oldsum = pto.addptr %arg1, %row_base : !pto.ptr -> !pto.ptr + %gm_qk = pto.addptr %arg2, %qk_base : !pto.ptr -> !pto.ptr + %gm_qk_hi = pto.addptr %gm_qk, %c64 : !pto.ptr -> !pto.ptr + %gm_newmax = pto.addptr %arg3, %row_base : !pto.ptr -> !pto.ptr + %gm_newsum = pto.addptr %arg4, %row_base : !pto.ptr -> !pto.ptr + %gm_expmax = pto.addptr %arg5, %row_base : !pto.ptr -> !pto.ptr + %gm_out = pto.addptr %arg6, %qk_base : !pto.ptr -> !pto.ptr + + %ub_oldmax = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_oldsum = pto.castptr %c128_i64 : 
i64 -> !pto.ptr + %ub_qk = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_qk_hi = pto.addptr %ub_qk, %c64 : !pto.ptr -> !pto.ptr + %ub_out = pto.castptr %c8448_i64 : i64 -> !pto.ptr + %ub_newmax = pto.castptr %c16640_i64 : i64 -> !pto.ptr + %ub_newsum = pto.castptr %c16768_i64 : i64 -> !pto.ptr + %ub_expmax = pto.castptr %c16896_i64 : i64 -> !pto.ptr + + scf.if %has_rows { + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_oldmax, %ub_oldmax, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_oldsum, %ub_oldsum, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_qk, %ub_qk, %c0_i64, %row_count_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c512_i64, %c512_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %gm_qk_hi, %ub_qk_hi, %c0_i64, %row_count_i64, %c256_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c512_i64, %c512_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + scf.for %row = %c0 to %row_count step %c1 { + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, 
%running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdiff %running_max, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdiff %vec, %merged_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + scf.yield %running_max, %running_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + %raw_expmax = pto.vexpdiff %oldmax_bc, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + 
%scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + %zero = pto.vsub %final_max, %final_max, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_base = arith.addi %row_qk, %chunk : index + pto.vsts %zero, %ub_out[%chunk_base], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + scf.if %has_chunk { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdiff %vec, %final_max, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_newmax, %gm_newmax, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + 
pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_newsum, %gm_newsum, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_expmax, %gm_expmax, %c0_i64, %c1_i64, %c32_i64, %c0_i64, %c32_i64, %c32_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %gm_out, %c0_i64, %row_count_i64, %c512_i64, %c0_i64, %c512_i64, %c512_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/kernels/online-softmax-update/launch.cpp b/test/vpto/cases/kernels/online-softmax-update/launch.cpp new file mode 100644 index 000000000..5cf6c4e2f --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/launch.cpp @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && 
defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#include +#include + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +namespace pto { +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +} // namespace pto +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ AICORE void online_softmax_update_kernel_2d( + __gm__ float *v1, __gm__ float *v2, __gm__ float *v3, + __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, + __gm__ float *v7, int32_t v8, int32_t v9); + +void LaunchOnline_softmax_update_kernel_2d(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream) { + const int32_t blockRows = 8; + const int32_t blocks = (v9 + blockRows - 1) / blockRows; + online_softmax_update_kernel_2d<<>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ float *)v3, + (__gm__ float *)v4, (__gm__ float *)v5, (__gm__ float *)v6, + (__gm__ float *)v7, v8, v9); +} diff --git a/test/vpto/cases/kernels/online-softmax-update/main.cpp b/test/vpto/cases/kernels/online-softmax-update/main.cpp new file mode 100644 index 000000000..6282f13a8 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/main.cpp @@ -0,0 +1,153 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +namespace pto { +struct 
MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +} // namespace pto +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchOnline_softmax_update_kernel_2d(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream); + +int main() { + constexpr size_t elemCountSeq = 1; + constexpr size_t elemCountRows = 1; + size_t fileSizeSeq = elemCountSeq * sizeof(int32_t); + size_t fileSizeRows = elemCountRows * sizeof(int32_t); + size_t elemCountState = 0; + size_t elemCountOut = 0; + size_t fileSizeState = 0; + size_t fileSizeOut = 0; + float *v1Host = nullptr, *v2Host = nullptr, *v3Host = nullptr; + float *v4Host = nullptr, *v5Host = nullptr, *v6Host = nullptr; + float *v7Host = nullptr; + float *v1Device = nullptr, *v2Device = nullptr, *v3Device = nullptr; + float *v4Device = nullptr, *v5Device = nullptr, *v6Device = nullptr; + float *v7Device = nullptr; + int32_t v8Host = 0, v9Host = 0; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ReadFile("./v8.bin", fileSizeSeq, &v8Host, fileSizeSeq); + ReadFile("./v9.bin", fileSizeRows, &v9Host, fileSizeRows); + + elemCountState = static_cast(v9Host); + elemCountOut = static_cast(v9Host) * 128; + 
fileSizeState = elemCountState * sizeof(float); + fileSizeOut = elemCountOut * sizeof(float); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v5Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v6Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v7Host), fileSizeOut)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v5Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v6Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v7Device, fileSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeState, v1Host, fileSizeState); + ReadFile("./v2.bin", fileSizeState, v2Host, fileSizeState); + ReadFile("./v3.bin", fileSizeOut, v3Host, fileSizeOut); + ReadFile("./v4.bin", fileSizeState, v4Host, fileSizeState); + ReadFile("./v5.bin", fileSizeState, v5Host, fileSizeState); + ReadFile("./v6.bin", fileSizeState, v6Host, fileSizeState); + ReadFile("./v7.bin", fileSizeOut, v7Host, fileSizeOut); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeState, v1Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeState, v2Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeOut, v3Host, fileSizeOut, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSizeState, v4Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + 
ACL_CHECK(aclrtMemcpy(v5Device, fileSizeState, v5Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v6Device, fileSizeState, v6Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v7Device, fileSizeOut, v7Host, fileSizeOut, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchOnline_softmax_update_kernel_2d(v1Device, v2Device, v3Device, + v4Device, v5Device, v6Device, + v7Device, v8Host, v9Host, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSizeState, v4Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v5Host, fileSizeState, v5Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v6Host, fileSizeState, v6Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v7Host, fileSizeOut, v7Device, fileSizeOut, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", v4Host, fileSizeState); + WriteFile("./v5.bin", v5Host, fileSizeState); + WriteFile("./v6.bin", v6Host, fileSizeState); + WriteFile("./v7.bin", v7Host, fileSizeOut); + +cleanup: + aclrtFree(v1Device); aclrtFree(v2Device); aclrtFree(v3Device); + aclrtFree(v4Device); aclrtFree(v5Device); aclrtFree(v6Device); aclrtFree(v7Device); + aclrtFreeHost(v1Host); aclrtFreeHost(v2Host); aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); aclrtFreeHost(v5Host); aclrtFreeHost(v6Host); aclrtFreeHost(v7Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d 
(%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + + return rc; +} diff --git a/test/vpto/cases/kernels/online-softmax-update/stub.cpp b/test/vpto/cases/kernels/online-softmax-update/stub.cpp new file mode 100644 index 000000000..003519801 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/stub.cpp @@ -0,0 +1,23 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.copy_gm_to_ubuf, pto.copy_ubuf_to_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdiff, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ AICORE void online_softmax_update_kernel_2d( + __gm__ float *v1, __gm__ float *v2, __gm__ float *v3, + __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, + __gm__ float *v7, int32_t v8, int32_t v9) { + (void)v1; (void)v2; (void)v3; (void)v4; + (void)v5; (void)v6; (void)v7; (void)v8; (void)v9; +} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e0c49c4cd..b60019675 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -7,9 +7,12 @@ // See LICENSE in the root of the software repository for the full text of the License. 
#include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -42,6 +45,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/MemoryBuffer.h" #include using namespace mlir; @@ -205,12 +209,87 @@ static llvm::cl::opt ptoBuildLevel( llvm::cl::value_desc("level1|level2|level3"), llvm::cl::init("level2")); +static llvm::cl::opt ptoBackend( + "pto-backend", + llvm::cl::desc("Final PTOAS backend: emitc or vpto (default: emitc)"), + llvm::cl::value_desc("emitc|vpto"), llvm::cl::init("emitc")); + +static llvm::cl::opt emitVPTO( + "emit-vpto", + llvm::cl::desc("Write final post-pass VPTO IR to -o"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoPrintIR( + "vpto-print-ir", + llvm::cl::desc("Print post-pass VPTO backend IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoLoweringStrategy( + "vpto-lowering-strategy", + llvm::cl::desc("VPTO vector lowering strategy: post-update or no-post-update"), + llvm::cl::value_desc("post-update|no-post-update"), + llvm::cl::init("post-update")); + +static llvm::cl::opt dumpVPTOIR( + "dump-vpto-ir", + llvm::cl::desc("Print post-pass VPTO backend IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt ptoPrintSeamIR( + "pto-print-seam-ir", + llvm::cl::desc("Print shared pre-backend seam IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt ptoSeamIRFile( + "pto-seam-ir-file", + llvm::cl::desc("Write shared pre-backend seam IR to a file"), + llvm::cl::value_desc("path"), + llvm::cl::init("")); + +static llvm::cl::opt vptoPrintIntrinsics( + "vpto-print-intrinsics", + llvm::cl::desc("Print VPTO intrinsic selection 
decisions to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoEmitHIVMOfficialLLVM( + "vpto-emit-hivm-llvm", + llvm::cl::desc("After lowering to VPTO IR, emit textual LLVM/HIVM via " + "the official LLVM dialect export path"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoEmitHIVMOfficialBitcode( + "vpto-emit-hivm-bc", + llvm::cl::desc("After lowering to VPTO IR, emit LLVM bitcode via the " + "official LLVM dialect export path"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoAllowUnresolved( + "vpto-allow-unresolved", + llvm::cl::desc("Emit explicit unresolved VPTO comments instead of failing"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoUnresolvedReport( + "vpto-unresolved-report", + llvm::cl::desc("Write unresolved VPTO mappings to a sidecar report"), + llvm::cl::value_desc("path"), llvm::cl::init("")); + +static llvm::cl::opt hivmUnresolvedReport( + "hivm-unresolved-report", + llvm::cl::desc("Write unresolved HIVM mappings to a sidecar report"), + llvm::cl::value_desc("path"), + llvm::cl::init("")); + enum class PTOBuildLevel { Level1, Level2, Level3, }; +enum class PTOBackend { + EmitC, + VPTO, +}; + static PTOBuildLevel defaultBuildLevel() { return PTOBuildLevel::Level2; } @@ -256,6 +335,94 @@ static bool parseAutoSyncTailHint(llvm::StringRef hintStr, std::string &normaliz return false; } +static bool parseBackend(llvm::StringRef backendStr, PTOBackend &out) { + std::string s = backendStr.str(); + for (char &c : s) + c = static_cast(std::tolower(static_cast(c))); + if (s == "emitc") { + out = PTOBackend::EmitC; + return true; + } + if (s == "vpto") { + out = PTOBackend::VPTO; + return true; + } + return false; +} + +static LogicalResult emitSharedPreBackendSeamIR(ModuleOp module, + llvm::StringRef outputPath) { + if (outputPath.empty()) + return success(); + + if (outputPath == "-") { + module->print(llvm::outs()); + llvm::outs() << "\n"; + llvm::outs().flush(); + return success(); + } + + std::error_code ec; + 
llvm::ToolOutputFile outputFile(outputPath, ec, llvm::sys::fs::OF_None); + if (ec) { + llvm::errs() << "Error: failed to open seam IR file '" << outputPath + << "': " << ec.message() << "\n"; + return failure(); + } + + module->print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return success(); +} + +static bool containsVPTOOpPrefix(llvm::StringRef line, + llvm::StringRef opPrefix) { + size_t searchFrom = 0; + while (searchFrom < line.size()) { + size_t pos = line.find(opPrefix, searchFrom); + if (pos == llvm::StringRef::npos) + return false; + + if (pos == 0) + return true; + + unsigned char before = static_cast(line[pos - 1]); + if (std::isspace(before) || before == '(' || before == '=' || + before == ',') + return true; + + searchFrom = pos + 1; + } + return false; +} + +static bool containsVPTOIR(llvm::StringRef input) { + llvm::StringRef rest = input; + while (!rest.empty()) { + auto split = rest.split('\n'); + llvm::StringRef line = split.first.trim(); + if (!line.starts_with("//") && + (line.contains("!pto.vec<") || line.contains("!pto.mask") || + line.contains("!pto.align") || + containsVPTOOpPrefix(line, "pto.copy_") || + containsVPTOOpPrefix(line, "pto.set_loop") || + containsVPTOOpPrefix(line, "pto.v") || + containsVPTOOpPrefix(line, "pto.plt_") || + containsVPTOOpPrefix(line, "pto.pset_") || + containsVPTOOpPrefix(line, "pto.psts") || + containsVPTOOpPrefix(line, "pto.pdintlv_") || + containsVPTOOpPrefix(line, "pto.set_flag") || + containsVPTOOpPrefix(line, "pto.wait_flag") || + containsVPTOOpPrefix(line, "pto.pipe_barrier") || + containsVPTOOpPrefix(line, "pto.get_buf") || + containsVPTOOpPrefix(line, "pto.rls_buf"))) + return true; + rest = split.second; + } + return false; +} + // -------------------------------------------------------------------------- // Post-process C++ output: rewrite marker calls into Tile member calls. 
// @@ -928,6 +1095,85 @@ static void rewriteScalarConstantDecls(std::string &cpp) { cpp.swap(out); } +static LogicalResult prepareVPTOForEmission(ModuleOp module) { + if (failed(convertVPTOEmissionBoundaryToPtr(module, &llvm::errs()))) { + llvm::errs() << "Error: VPTO emission boundary canonicalization failed.\n"; + return failure(); + } + + PassManager prepPM(module->getContext()); + prepPM.enableVerifier(); + prepPM.addNestedPass(createPTOVPTOExpandBridgeOpsPass()); + prepPM.addPass(createCSEPass()); + prepPM.addPass(pto::createPTOValidateVPTOEmissionIRPass()); + if (failed(prepPM.run(module))) { + llvm::errs() << "Error: VPTO emission preparation failed.\n"; + return failure(); + } + + return success(); +} + +static LogicalResult lowerPTOToVPTOBackend(ModuleOp module) { + PassManager backendPM(module.getContext()); + backendPM.addPass(pto::createLowerPTOToVPTOPass()); + backendPM.addPass(mlir::createCSEPass()); + if (failed(backendPM.run(module))) { + llvm::errs() << "Error: backend lowering pass execution failed.\n"; + return failure(); + } + return success(); +} + +static pto::VPTOEmissionOptions buildVPTOEmissionOptions() { + pto::VPTOEmissionOptions options; + options.dumpVPTOIR = false; + options.printIntrinsicSelections = vptoPrintIntrinsics; + options.allowUnresolved = vptoAllowUnresolved; + options.unresolvedReportPath = + !hivmUnresolvedReport.empty() ? 
hivmUnresolvedReport : vptoUnresolvedReport; + options.targetTriple = "hiipu64-hisilicon-cce"; + options.march = "dav-c310-vec"; + options.aicoreArch = "dav-c310-vec"; + options.defaultTargetCPU = "dav-c310-vec"; + options.defaultTargetFeatures = + "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," + "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," + "+MOVX8,+SPR7bits,+SyncV,+dav-c310-vec"; + return options; +} + +static int emitPreparedVPTOBackendResult(ModuleOp module, + llvm::ToolOutputFile &outputFile) { + if (emitVPTO || (!vptoEmitHIVMOfficialLLVM && !vptoEmitHIVMOfficialBitcode)) { + module.print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return 0; + } + + pto::VPTOEmissionOptions options = buildVPTOEmissionOptions(); + LogicalResult emissionStatus = + vptoEmitHIVMOfficialBitcode + ? pto::translateVPTOModuleToLLVMBitcode(module, outputFile.os(), + options, llvm::errs()) + : pto::translateVPTOModuleToLLVMText(module, outputFile.os(), + options, llvm::errs()); + if (failed(emissionStatus)) { + llvm::errs() << "Error: Failed to emit VPTO text.\n"; + return 1; + } + outputFile.keep(); + return 0; +} + +static int emitVPTOBackendResult(ModuleOp module, + llvm::ToolOutputFile &outputFile) { + if (failed(prepareVPTOForEmission(module))) + return 1; + return emitPreparedVPTOBackendResult(module, outputFile); +} + int main(int argc, char **argv) { DialectRegistry registry; registry.insert(); @@ -963,6 +1209,36 @@ int main(int argc, char **argv) { // Parse command line options llvm::cl::ParseCommandLineOptions(argc, argv, "PTO Assembler (ptoas)\n"); + PTOBackend effectiveBackend = PTOBackend::EmitC; + if (!parseBackend(ptoBackend, effectiveBackend)) { + llvm::errs() << "Error: invalid --pto-backend='" << ptoBackend + << "'. 
Expected 'emitc' or 'vpto'.\n";
    return 1;
  }

  // --vpto-emit-hivm-llvm and --vpto-emit-hivm-bc are mutually exclusive.
  if (vptoEmitHIVMOfficialLLVM && vptoEmitHIVMOfficialBitcode) {
    llvm::errs() << "Error: --vpto-emit-hivm-llvm and --vpto-emit-hivm-bc "
                    "cannot be used together.\n";
    return 1;
  }

  // --emit-vpto (print VPTO IR) conflicts with any HIVM emission flag.
  if (emitVPTO &&
      (vptoEmitHIVMOfficialLLVM || vptoEmitHIVMOfficialBitcode)) {
    llvm::errs() << "Error: --emit-vpto cannot be used together with HIVM "
                    "emission flags.\n";
    return 1;
  }

  // Every VPTO/seam-IR-specific flag is rejected unless --pto-backend=vpto
  // is active.
  if (effectiveBackend != PTOBackend::VPTO &&
      (vptoEmitHIVMOfficialLLVM || vptoEmitHIVMOfficialBitcode || emitVPTO ||
       vptoPrintIntrinsics || vptoAllowUnresolved ||
       !vptoUnresolvedReport.empty() || !hivmUnresolvedReport.empty() ||
       ptoPrintSeamIR || !ptoSeamIRFile.empty())) {
    llvm::errs() << "Error: VPTO-specific flags require "
                    "--pto-backend=vpto.\n";
    return 1;
  }

  // Read whole input first (so we can auto-detect .ptobc by magic).
  auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (!fileOrErr) {
// … lines elided here by the original diff hunk (@@ -987,6 +1263,8 @@) …
  // NOTE(review): OwningOpRef's template argument (presumably ModuleOp) was
  // stripped from this chunk by the extraction — confirm against the
  // original source.
  OwningOpRef module;
  llvm::StringRef buf = (*fileOrErr)->getBuffer();
  // .ptobc inputs are detected by their 6-byte "PTOBC\0" magic header.
  const bool isPTOBC = (buf.size() >= 6 && std::memcmp(buf.data(), "PTOBC\0", 6) == 0);
  // Detect textual VPTO IR so an already-lowered input can skip the pipeline.
  const bool inputIsVPTOIR = containsVPTOIR(buf);

  auto normalizeArch = [](llvm::StringRef archValue) {
    std::string normalized = archValue.str();
    for (char &c : normalized)
// … lines elided here by the original diff hunk (@@ -1130,6 +1408,16 @@) …
    return 1;
  }

  // Fast path: the input is already VPTO IR, so emit it directly.  The
  // shared pre-backend seam IR does not exist in this mode, hence the
  // seam-IR flags are rejected.
  if (effectiveBackend == PTOBackend::VPTO && inputIsVPTOIR) {
    if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) {
      llvm::errs() << "Error: shared pre-backend seam IR is unavailable when "
                      "the input is already VPTO IR.\n";
      return 1;
    }

    return emitVPTOBackendResult(*module, outputFile);
  }

  // Main PassManager
  PassManager pm(&context);
// … lines elided here by the original diff hunk (@@ -1157,6 +1445,28 @@) …
  // NOTE(review): the nested-pass anchor type was stripped from this chunk
  // by the extraction — confirm against the original source.
  pm.addNestedPass(pto::createPTOInsertSyncPass());
  pm.addPass(createCSEPass());

  // Record the target architecture on the module for downstream consumers.
  module->getOperation()->setAttr("pto.target_arch",
                                  mlir::StringAttr::get(&context, arch));

  if (effectiveBackend == PTOBackend::VPTO) {
    // Run the shared pipeline first, then branch off into VPTO lowering and
    // emission instead of the EmitC path below.
    if (failed(pm.run(*module))) {
      llvm::errs() << "Error: Pass execution failed.\n";
      return 1;
    }

    // Optionally dump and/or persist the shared pre-backend seam IR before
    // backend lowering destroys it.
    if (ptoPrintSeamIR) {
      module->print(llvm::errs());
      llvm::errs() << "\n";
    }
    if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile)))
      return 1;

    if (failed(lowerPTOToVPTOBackend(*module)))
      return 1;
    return emitVPTOBackendResult(*module, outputFile);
  }

  if (arch == "a3") {
    pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3));
  } else {