
feat(ir): add tile.mscatter op for per-element scatter-store to GM #936

Open

Little-oil wants to merge 1 commit into hw-native-sys:main from Little-oil:issue-921-add-mscatter-op

Conversation

@Little-oil
Contributor

Summary

Add tile.mscatter operation that maps to the PTOAS pto.mscatter instruction for per-element scatter-store from UB tile to GM tensor:

output_tensor[idx[i, j]] = src[i, j]
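As a plain-Python sketch of these semantics (an illustrative reference, not code from this PR; the helper name is hypothetical):

```python
def mscatter_ref(src, idx, out):
    """Reference semantics of tile.mscatter: out[idx[i][j]] = src[i][j].

    src, idx: 2-D lists of equal shape; out: flat list standing in for the GM tensor.
    idx holds flat element offsets into out.
    """
    assert len(idx) == len(src) and all(len(r) == len(s) for r, s in zip(idx, src))
    for i, row in enumerate(src):
        for j, val in enumerate(row):
            out[idx[i][j]] = val  # per-element scatter-store
    return out

src = [[10, 20], [30, 40]]
idx = [[3, 1], [0, 2]]
print(mscatter_ref(src, idx, [0] * 4))  # [30, 20, 40, 10]
```

Note that, as with the hardware instruction, duplicate indices leave the result dependent on element order; the real op's behavior for colliding indices is not specified here.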

Changes

  • C++ op registration (src/ir/op/tile_ops/memory.cpp): tile.mscatter with type deduction validating src (FP16/FP32/INT16/INT32), idx (INT32, same rank as src), and output_tensor (TensorType, same dtype as src)
  • PTO codegen (src/backend/common/pto_ops_common.cpp): Emits pto.partition_view + pto.mscatter ins(src, idx) outs(pview) with row_major layout constraints on inputs
  • Python IR wrapper (python/pypto/ir/op/tile_ops.py): tile.mscatter(src, idx, output_tensor)
  • Python DSL wrapper (python/pypto/language/op/tile_ops.py): pl.mscatter(src_tile, idx_tile, out_tensor) exported at top-level pl namespace
  • Unit tests (tests/ut/ir/operators/test_tile_ops.py): 7 tests covering basic usage (FP32/FP16), error paths (wrong dtype, rank mismatch, arg count)
  • ST runtime tests (tests/st/runtime/test_mscatter.py): Comprehensive test matrix (FP32/FP16/INT32, 8x32/16x64, sequential/reversed/random/strided indices) — all skipped pending PTOAS NPU implementation

Note on ST Tests

All ST tests are marked pytest.mark.skip because PTOAS currently lacks a real NPU implementation for pto.mscatter — it falls back to TSTORE which ignores the index tile entirely (PTOToEmitC.cpp:4402). The PyPTO-side IR generation is verified correct via unit tests; enable ST tests once PTOAS ships proper MSCATTER lowering.
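For reference, the host-side oracle the ST tests check against can be sketched in NumPy (names here are illustrative; the actual tests use a torch equivalent of the same flatten-and-index assignment):

```python
import numpy as np

def mscatter_expected(src, idx, out_size):
    """Host-side oracle for tile.mscatter: out[idx[i, j]] = src[i, j]."""
    out = np.zeros(out_size, dtype=src.dtype)
    out[idx.ravel()] = src.ravel()  # flat indices select destination elements
    return out

src = np.arange(8, dtype=np.float32).reshape(2, 4)
idx = np.array([[7, 6, 5, 4], [3, 2, 1, 0]], dtype=np.int32)
print(mscatter_expected(src, idx, 8))  # [7. 6. 5. 4. 3. 2. 1. 0.]
```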

Testing

  • Unit tests pass (7/7)
  • Full test suite passes (3402 passed, 0 failed)
  • Code review completed
  • Pre-commit hooks pass (clang-format, cpplint, ruff, pyright)

Fixes #921

…w-native-sys#921)

Add tile.mscatter operation mapping to pto.mscatter instruction:
  mem[idx[i, j]] = src[i, j]

- C++ op registration with type deduction and validation
- PTO codegen emitting partition_view + pto.mscatter
- Python IR and DSL wrappers with pl.mscatter export
- Unit tests covering basic usage and error paths
- ST runtime tests (skipped: PTOAS lacks NPU mscatter impl)
@coderabbitai

coderabbitai bot commented Apr 9, 2026

📝 Walkthrough

A new scatter memory operation tile.mscatter is implemented across the compiler stack, from IR-level operator definition through Python DSL API to backend codegen lowering. The operation validates operand types and counts, deduces output type from input tensors, and lowers to PTO partition_view + mscatter instructions in the backend.

Changes

Cohort / File(s) / Summary

  • Python API Layer (python/pypto/ir/op/tile_ops.py, python/pypto/language/op/tile_ops.py): Added IR-level mscatter(src, idx, output_tensor, span=None) and DSL-level mscatter(src: Tile, idx: Tile, output_tensor: Tensor) functions wrapping the new operation.
  • Module Exports (python/pypto/language/__init__.py, python/pypto/language/op/__init__.py): Re-exported and promoted the mscatter symbol into the public language and language.op namespaces.
  • IR Type System (src/ir/op/tile_ops/memory.cpp): Added the DeduceTileMscatterType validator checking argument counts, types (FP16/FP32/INT16/INT32 for src, INT32 for idx), and rank consistency, then registered the tile.mscatter operator with memory specs.
  • Backend Codegen (src/backend/common/pto_ops_common.cpp): Implemented custom lowering from tile.mscatter to pto.partition_view (full tensor partition with zero offsets) followed by pto.mscatter with src/idx inputs; removed the op from the simple-ops table and registered it with row-major layout constraints on src and idx.
  • Test Coverage (tests/ut/ir/operators/test_tile_ops.py, tests/st/runtime/test_mscatter.py): Added IR-level unit tests validating type deduction and error conditions, plus runtime tests covering multiple dtypes (FP32/FP16/INT32) and tile shapes (8x32, 16x64) with various index patterns; runtime tests are currently skipped pending backend completion.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant DSL as DSL Layer<br/>(pl.mscatter)
    participant IR as IR Layer<br/>(tile.mscatter)
    participant TypeSystem as Type System<br/>(DeduceTileMscatterType)
    participant Backend as Backend Codegen
    participant PTO as PTO Lowering

    User->>DSL: Call pl.mscatter(src_tile, idx_tile, output)
    DSL->>IR: Invoke _ir_ops.mscatter(src, idx, output)
    IR->>TypeSystem: Deduce output tensor type
    TypeSystem->>TypeSystem: Validate src dtype∈{FP16,FP32,INT16,INT32}<br/>Validate idx dtype==INT32<br/>Validate idx.rank==src.rank<br/>Validate output.dtype==src.dtype
    TypeSystem-->>IR: Return output_tensor type
    IR-->>DSL: Return Tensor(call_expr)
    DSL-->>User: Return scattered result tensor
    
    Note over Backend: During Compilation
    IR->>Backend: Process tile.mscatter call<br/>(src, idx, output_tensor)
    Backend->>PTO: Emit pto.partition_view<br/>(output_tensor)
    PTO-->>Backend: Return partition view
    Backend->>PTO: Emit pto.mscatter<br/>(src, idx, partition_view)
    PTO-->>Backend: Complete scatter operation
    Backend-->>IR: Update result to reference<br/>output_tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • lyfne123

Poem

🐰 A scatter-op hops into view,
Through tiles and types it journeys true,
From DSL down to PTO's call,
Three arguments dance through it all—
Indices guide the data's flight,
As mscatter fills the output's might! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 25.22%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): The title clearly and concisely summarizes the main change: adding a new tile.mscatter operation for scatter-store functionality to global memory.
  • Description check (✅ Passed): The description is well-detailed, providing context about the new operation, implementation details across C++/Python layers, testing status, and notes on pending PTOAS support.
  • Linked Issues check (✅ Passed): The PR addresses all requirements from issue #921: C++ op registration, PTO codegen, Python IR/DSL wrappers, unit tests, and runtime tests with proper skip markers pending PTOAS implementation.
  • Out of Scope Changes check (✅ Passed): All changes are directly scoped to adding tile.mscatter support: IR registration, codegen, Python APIs, and corresponding tests. No unrelated modifications are present.





@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/st/runtime/test_mscatter.py (1)

249-572: Consider adding INT16 test coverage for completeness.

The IR definition in memory.cpp lists INT16 as a supported dtype for tile.mscatter, but the test matrix only covers FP32, FP16, and INT32. Adding an INT16 test case would complete the dtype coverage.

Since these tests are currently skipped pending PTOAS support, this can be addressed later.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/runtime/test_mscatter.py` around lines 249 - 572, Add INT16 test
coverage by creating test classes mirroring the INT32 cases (e.g.,
MscatterINT16SeqTestCase, MscatterINT16RevTestCase,
MscatterINT16RandPermTestCase and a larger MscatterINT16_16x64RandPermTestCase)
that follow the pattern in MscatterINT32SeqTestCase and MscatterFP16/FP32
classes: use DataType.INT16 in define_tensors with the same init helpers
(_init_randint_8x32, _init_sequential_8x32, _init_reversed_8x32,
_init_random_perm_16x64, etc.), return the corresponding program names
(MscatterINT16_8x32Program, MscatterINT16_16x64Program) from get_program, and
implement compute_expected to create a torch.zeros(..., dtype=torch.int16) and
assign out[tensors["idx_tensor"].flatten().long()] =
tensors["src_tensor"].flatten(); keep the tests marked/skipped consistent with
existing PTOAS-skipped tests.
tests/ut/ir/operators/test_tile_ops.py (1)

2009-2142: Consider adding INT16/INT32 happy-path cases for tile.mscatter.

Current positive tests cover FP32/FP16 only, while the op contract also allows INT16/INT32. Adding those two cases would lock in the full supported dtype matrix.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/ir/operators/test_tile_ops.py` around lines 2009 - 2142, Add two
positive tests to TestTileMscatterOps that mirror the existing FP16/FP32 cases
but use INT16 and INT32 dtypes: create span, rows, cols, tensor_n ConstInt
values; build src_type = ir.TileType([rows, cols], DataType.INT16) and src_type
= ir.TileType([rows, cols], DataType.INT32) respectively, idx_type =
ir.TileType([rows, cols], DataType.INT32), tensor_type =
ir.TensorType([tensor_n], DataType.INT16) and DataType.INT32; create src_var,
idx_var, out_var, call = tile.mscatter(src_var, idx_var, out_var), assert
call.op.name == "tile.mscatter" and assert isinstance(call.type, ir.TensorType)
and call.type.dtype equals the corresponding DataType (INT16 or INT32). Ensure
test names are distinct (e.g., test_tile_mscatter_int16 and
test_tile_mscatter_int32) and follow the same pattern as
test_tile_mscatter_fp16.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bf02517c-aad5-4625-9e0f-2b333a52d563

📥 Commits

Reviewing files that changed from the base of the PR and between 523c08f and 432a20e.

📒 Files selected for processing (8)
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/__init__.py
  • python/pypto/language/op/tile_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/memory.cpp
  • tests/st/runtime/test_mscatter.py
  • tests/ut/ir/operators/test_tile_ops.py

Comment on lines +565 to +567
CHECK(idx_type->shape_.size() == src_type->shape_.size())
<< "The operator " << op_name << " requires idx rank to match src rank (" << src_type->shape_.size()
<< "), but got " << idx_type->shape_.size();

⚠️ Potential issue | 🟠 Major

Enforce idx shape equality with src in tile.mscatter type deduction.

Only rank is validated right now. That allows invalid same-rank/different-shape pairs, which break the op’s per-element semantics and should be rejected at IR validation time.

🔧 Suggested fix
   CHECK(idx_type->shape_.size() == src_type->shape_.size())
       << "The operator " << op_name << " requires idx rank to match src rank (" << src_type->shape_.size()
       << "), but got " << idx_type->shape_.size();
+  for (size_t i = 0; i < src_type->shape_.size(); ++i) {
+    CHECK(idx_type->shape_[i]->ToString() == src_type->shape_[i]->ToString())
+        << "The operator " << op_name << " requires idx shape to match src shape at dim " << i
+        << ", but got src dim " << src_type->shape_[i]->ToString()
+        << " and idx dim " << idx_type->shape_[i]->ToString();
+  }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/op/tile_ops/memory.cpp` around lines 565 - 567, The current type check
in tile.mscatter only validates rank equality (using idx_type->shape_.size() vs
src_type->shape_.size()) but allows same-rank different-shape tensors; update
the validation in memory.cpp (the tile.mscatter type-deduction / CHECK around
idx_type and src_type) to assert full shape equality by comparing
idx_type->shape_ and src_type->shape_ elementwise (or via direct equality) and
emit a clear error referencing op_name when they differ so mismatched same-rank
shapes are rejected at IR validation time.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements the mscatter operation, enabling scatter-store functionality from tiles to tensors using per-element indices. The changes span the IR definition, Python DSL wrappers, and C++ backend codegen, supported by new unit and runtime tests. Feedback suggests adding a rank check for the output tensor to ensure it is not a scalar and correcting the import source for the DSL wrapper to align with project conventions.

Comment on lines +570 to +575
auto tensor_type = As<TensorType>(args[2]->GetType());
CHECK(tensor_type) << "The operator " << op_name << " requires third argument to be a TensorType, but got "
<< args[2]->GetType()->TypeName();
CHECK(tensor_type->dtype_ == src_type->dtype_)
<< "The operator " << op_name << " requires output_tensor dtype (" << tensor_type->dtype_.ToString()
<< ") to match src dtype (" << src_type->dtype_.ToString() << ")";
Contributor


medium

The mscatter operation implies indexing into the output_tensor, which is not well-defined for a scalar (rank-0) tensor. To prevent potential issues in downstream codegen, it would be safer to add a check to ensure output_tensor has a rank of at least 1. This is consistent with how target tensors are handled in other tile operations.

Suggested change (insert a non-scalar check between the TensorType check and the dtype check):

auto tensor_type = As<TensorType>(args[2]->GetType());
CHECK(tensor_type) << "The operator " << op_name << " requires third argument to be a TensorType, but got "
                   << args[2]->GetType()->TypeName();
CHECK(!tensor_type->shape_.empty()) << "The operator " << op_name
                                    << " requires a non-scalar output_tensor (rank >= 1)";
CHECK(tensor_type->dtype_ == src_type->dtype_)
    << "The operator " << op_name << " requires output_tensor dtype (" << tensor_type->dtype_.ToString()
    << ") to match src dtype (" << src_type->dtype_.ToString() << ")";
References
  1. The target tensor of a tile operation is consistently expected to be the 3rd argument (index 2).

Comment on lines +133 to +135
from .op.tile_ops import (
mscatter as mscatter,
)
Contributor


medium

Import mscatter from the language.op.system_ops module instead of ir.op.tile_ops. The system_ops module provides the appropriate DSL-level wrappers that accept Tile objects, which is the standard practice in this repository.

from .op.system_ops import mscatter
References
  1. Import tile operations from the language.op.system_ops module, as it provides DSL-level wrappers that accept Tile objects, rather than importing directly from the lower-level ir.op.tile_ops module.



Development

Successfully merging this pull request may close these issues.

[New Op] Add op for pto.mscatter instruction

1 participant