feat(ir): add tile.mscatter op for per-element scatter-store to GM (#936)
(…w-native-sys#921) Add tile.mscatter operation mapping to the pto.mscatter instruction: mem[idx[i, j]] = src[i, j]

- C++ op registration with type deduction and validation
- PTO codegen emitting partition_view + pto.mscatter
- Python IR and DSL wrappers with pl.mscatter export
- Unit tests covering basic usage and error paths
- ST runtime tests (skipped: PTOAS lacks an NPU mscatter implementation)
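The per-element semantics `mem[idx[i, j]] = src[i, j]` can be sketched with a small NumPy reference model (a hypothetical helper for illustration, not part of the codebase; behavior for duplicate indices is not specified in this PR, and the NumPy model simply keeps the last write):

```python
import numpy as np

def mscatter_ref(src: np.ndarray, idx: np.ndarray, out: np.ndarray) -> np.ndarray:
    """Reference semantics of tile.mscatter: out[idx[i, j]] = src[i, j]."""
    assert idx.shape == src.shape, "idx must match src element-for-element"
    assert idx.dtype == np.int32, "indices are INT32"
    result = out.copy()
    result[idx.ravel()] = src.ravel()  # scatter each src element to its index
    return result

src = np.arange(6, dtype=np.float32).reshape(2, 3)      # [[0,1,2],[3,4,5]]
idx = np.array([[5, 4, 3], [2, 1, 0]], dtype=np.int32)  # reversed targets
out = mscatter_ref(src, idx, np.zeros(6, dtype=np.float32))
# out is [5, 4, 3, 2, 1, 0]: each element landed at its scattered position
```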
📝 Walkthrough

A new scatter memory operation, `tile.mscatter`, is added end-to-end: IR registration with type deduction, PTO backend codegen, Python IR and DSL wrappers, and unit/runtime tests.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant DSL as DSL Layer<br/>(pl.mscatter)
    participant IR as IR Layer<br/>(tile.mscatter)
    participant TypeSystem as Type System<br/>(DeduceTileMscatterType)
    participant Backend as Backend Codegen
    participant PTO as PTO Lowering
    User->>DSL: Call pl.mscatter(src_tile, idx_tile, output)
    DSL->>IR: Invoke _ir_ops.mscatter(src, idx, output)
    IR->>TypeSystem: Deduce output tensor type
    TypeSystem->>TypeSystem: Validate src dtype∈{FP16,FP32,INT16,INT32}<br/>Validate idx dtype==INT32<br/>Validate idx.rank==src.rank<br/>Validate output.dtype==src.dtype
    TypeSystem-->>IR: Return output_tensor type
    IR-->>DSL: Return Tensor(call_expr)
    DSL-->>User: Return scattered result tensor
    Note over Backend: During Compilation
    IR->>Backend: Process tile.mscatter call<br/>(src, idx, output_tensor)
    Backend->>PTO: Emit pto.partition_view<br/>(output_tensor)
    PTO-->>Backend: Return partition view
    Backend->>PTO: Emit pto.mscatter<br/>(src, idx, partition_view)
    PTO-->>Backend: Complete scatter operation
    Backend-->>IR: Update result to reference<br/>output_tensor
```
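The validation chain in the diagram can be mirrored as a plain-Python sketch (illustrative names only, not the actual type-system API):

```python
ALLOWED_SRC_DTYPES = {"FP16", "FP32", "INT16", "INT32"}

def deduce_mscatter_type(src_dtype, src_rank, idx_dtype, idx_rank, out_dtype):
    """Mirror of the checks the diagram attributes to DeduceTileMscatterType."""
    if src_dtype not in ALLOWED_SRC_DTYPES:
        raise TypeError(f"src dtype {src_dtype} not in {sorted(ALLOWED_SRC_DTYPES)}")
    if idx_dtype != "INT32":
        raise TypeError("idx dtype must be INT32")
    if idx_rank != src_rank:
        raise TypeError(f"idx rank {idx_rank} must match src rank {src_rank}")
    if out_dtype != src_dtype:
        raise TypeError("output_tensor dtype must match src dtype")
    return out_dtype  # the call's result type is the output tensor's type

result = deduce_mscatter_type("FP32", 2, "INT32", 2, "FP32")  # passes: "FP32"
```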
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/st/runtime/test_mscatter.py (1)
249-572: Consider adding INT16 test coverage for completeness. The IR definition in `memory.cpp` lists INT16 as a supported dtype for `tile.mscatter`, but the test matrix only covers FP32, FP16, and INT32. Adding an INT16 test case would complete the dtype coverage. Since these tests are currently skipped pending PTOAS support, this can be addressed later.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/st/runtime/test_mscatter.py` around lines 249 - 572, add INT16 test coverage by creating test classes mirroring the INT32 cases (e.g., MscatterINT16SeqTestCase, MscatterINT16RevTestCase, MscatterINT16RandPermTestCase and a larger MscatterINT16_16x64RandPermTestCase) that follow the pattern in MscatterINT32SeqTestCase and the MscatterFP16/FP32 classes: use DataType.INT16 in define_tensors with the same init helpers (_init_randint_8x32, _init_sequential_8x32, _init_reversed_8x32, _init_random_perm_16x64, etc.), return the corresponding program names (MscatterINT16_8x32Program, MscatterINT16_16x64Program) from get_program, and implement compute_expected to create a torch.zeros(..., dtype=torch.int16) and assign out[tensors["idx_tensor"].flatten().long()] = tensors["src_tensor"].flatten(); keep the tests marked/skipped consistent with the existing PTOAS-skipped tests.

tests/ut/ir/operators/test_tile_ops.py (1)
2009-2142: Consider adding INT16/INT32 happy-path cases for `tile.mscatter`. Current positive tests cover FP32/FP16 only, while the op contract also allows INT16/INT32. Adding those two cases would lock in the full supported dtype matrix.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ut/ir/operators/test_tile_ops.py` around lines 2009 - 2142, add two positive tests to TestTileMscatterOps that mirror the existing FP16/FP32 cases but use INT16 and INT32 dtypes: create span, rows, cols, tensor_n ConstInt values; build src_type = ir.TileType([rows, cols], DataType.INT16) and src_type = ir.TileType([rows, cols], DataType.INT32) respectively, idx_type = ir.TileType([rows, cols], DataType.INT32), tensor_type = ir.TensorType([tensor_n], DataType.INT16) and DataType.INT32; create src_var, idx_var, out_var, call = tile.mscatter(src_var, idx_var, out_var), assert call.op.name == "tile.mscatter" and assert isinstance(call.type, ir.TensorType) and call.type.dtype equals the corresponding DataType (INT16 or INT32). Ensure test names are distinct (e.g., test_tile_mscatter_int16 and test_tile_mscatter_int32) and follow the same pattern as test_tile_mscatter_fp16.
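For the suggested INT16 cases, the expected-output computation could mirror the existing dtypes. A NumPy sketch of that reference follows (the real tests use torch; the helper name here is hypothetical):

```python
import numpy as np

def compute_expected_int16(src: np.ndarray, idx: np.ndarray, out_size: int) -> np.ndarray:
    """Expected output for an INT16 scatter test: zeros, then out[idx] = src."""
    out = np.zeros(out_size, dtype=np.int16)
    out[idx.flatten().astype(np.int64)] = src.flatten()
    return out

src = np.array([[1, 2], [3, 4]], dtype=np.int16)
idx = np.array([[3, 1], [0, 2]], dtype=np.int32)
expected = compute_expected_int16(src, idx, 4)
# expected == [3, 2, 4, 1]
```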
📒 Files selected for processing (8)

- python/pypto/ir/op/tile_ops.py
- python/pypto/language/__init__.py
- python/pypto/language/op/__init__.py
- python/pypto/language/op/tile_ops.py
- src/backend/common/pto_ops_common.cpp
- src/ir/op/tile_ops/memory.cpp
- tests/st/runtime/test_mscatter.py
- tests/ut/ir/operators/test_tile_ops.py
```cpp
CHECK(idx_type->shape_.size() == src_type->shape_.size())
    << "The operator " << op_name << " requires idx rank to match src rank (" << src_type->shape_.size()
    << "), but got " << idx_type->shape_.size();
```
Enforce idx shape equality with src in tile.mscatter type deduction.
Only rank is validated right now. That allows invalid same-rank/different-shape pairs, which break the op’s per-element semantics and should be rejected at IR validation time.
🔧 Suggested fix

```diff
 CHECK(idx_type->shape_.size() == src_type->shape_.size())
     << "The operator " << op_name << " requires idx rank to match src rank (" << src_type->shape_.size()
     << "), but got " << idx_type->shape_.size();
+for (size_t i = 0; i < src_type->shape_.size(); ++i) {
+  CHECK(idx_type->shape_[i]->ToString() == src_type->shape_[i]->ToString())
+      << "The operator " << op_name << " requires idx shape to match src shape at dim " << i
+      << ", but got src dim " << src_type->shape_[i]->ToString()
+      << " and idx dim " << idx_type->shape_[i]->ToString();
+}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/ir/op/tile_ops/memory.cpp` around lines 565 - 567, The current type check
in tile.mscatter only validates rank equality (using idx_type->shape_.size() vs
src_type->shape_.size()) but allows same-rank different-shape tensors; update
the validation in memory.cpp (the tile.mscatter type-deduction / CHECK around
idx_type and src_type) to assert full shape equality by comparing
idx_type->shape_ and src_type->shape_ elementwise (or via direct equality) and
emit a clear error referencing op_name when they differ so mismatched same-rank
shapes are rejected at IR validation time.
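The gap the comment describes is easy to see in plain Python: a rank-only check accepts shape pairs that have no valid element-wise pairing (the shapes below are illustrative):

```python
def rank_only_check(src_shape, idx_shape):
    """What the current CHECK enforces: equal rank only."""
    return len(src_shape) == len(idx_shape)

def full_shape_check(src_shape, idx_shape):
    """What the suggested fix enforces: equal shape, dimension by dimension."""
    return tuple(src_shape) == tuple(idx_shape)

# Same rank (2) but different shapes: src[i, j] has no idx[i, j] partner
# for j >= 16, so the op's per-element semantics are undefined.
src_shape, idx_shape = (8, 32), (8, 16)
accepted_today = rank_only_check(src_shape, idx_shape)    # True
accepted_fixed = full_shape_check(src_shape, idx_shape)   # False
```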
Code Review
This pull request implements the mscatter operation, enabling scatter-store functionality from tiles to tensors using per-element indices. The changes span the IR definition, Python DSL wrappers, and C++ backend codegen, supported by new unit and runtime tests. Feedback suggests adding a rank check for the output tensor to ensure it is not a scalar and correcting the import source for the DSL wrapper to align with project conventions.
```cpp
auto tensor_type = As<TensorType>(args[2]->GetType());
CHECK(tensor_type) << "The operator " << op_name << " requires third argument to be a TensorType, but got "
                   << args[2]->GetType()->TypeName();
CHECK(tensor_type->dtype_ == src_type->dtype_)
    << "The operator " << op_name << " requires output_tensor dtype (" << tensor_type->dtype_.ToString()
    << ") to match src dtype (" << src_type->dtype_.ToString() << ")";
```
The mscatter operation implies indexing into the output_tensor, which is not well-defined for a scalar (rank-0) tensor. To prevent potential issues in downstream codegen, it would be safer to add a check to ensure output_tensor has a rank of at least 1. This is consistent with how target tensors are handled in other tile operations.
```diff
 auto tensor_type = As<TensorType>(args[2]->GetType());
 CHECK(tensor_type) << "The operator " << op_name << " requires third argument to be a TensorType, but got "
                    << args[2]->GetType()->TypeName();
+CHECK(!tensor_type->shape_.empty()) << "The operator " << op_name
+                                    << " requires a non-scalar output_tensor (rank >= 1)";
 CHECK(tensor_type->dtype_ == src_type->dtype_)
     << "The operator " << op_name << " requires output_tensor dtype (" << tensor_type->dtype_.ToString()
     << ") to match src dtype (" << src_type->dtype_.ToString() << ")";
```
References
- The target tensor of a tile operation is consistently expected to be the 3rd argument (index 2).
```python
from .op.tile_ops import (
    mscatter as mscatter,
)
```
Import mscatter from the language.op.system_ops module instead of ir.op.tile_ops. The system_ops module provides the appropriate DSL-level wrappers that accept Tile objects, which is the standard practice in this repository.
```python
from .op.system_ops import mscatter
```

References

- Import tile operations from the `language.op.system_ops` module, as it provides DSL-level wrappers that accept Tile objects, rather than importing directly from the lower-level `ir.op.tile_ops` module.
Summary
Add `tile.mscatter` operation that maps to the PTOAS `pto.mscatter` instruction for per-element scatter-store from a UB tile to a GM tensor.

Changes

- IR (`src/ir/op/tile_ops/memory.cpp`): `tile.mscatter` with type deduction validating src (FP16/FP32/INT16/INT32), idx (INT32, same rank as src), and output_tensor (TensorType, same dtype as src)
- Backend (`src/backend/common/pto_ops_common.cpp`): emits `pto.partition_view` + `pto.mscatter ins(src, idx) outs(pview)` with row_major layout constraints on the inputs
- Python IR (`python/pypto/ir/op/tile_ops.py`): `tile.mscatter(src, idx, output_tensor)`
- DSL (`python/pypto/language/op/tile_ops.py`): `pl.mscatter(src_tile, idx_tile, out_tensor)`, exported at the top-level `pl` namespace
- Unit tests (`tests/ut/ir/operators/test_tile_ops.py`): 7 tests covering basic usage (FP32/FP16) and error paths (wrong dtype, rank mismatch, arg count)
- ST tests (`tests/st/runtime/test_mscatter.py`): comprehensive test matrix (FP32/FP16/INT32, 8x32/16x64, sequential/reversed/random/strided indices), all skipped pending the PTOAS NPU implementation

Note on ST Tests
All ST tests are marked `pytest.mark.skip` because PTOAS currently lacks a real NPU implementation for `pto.mscatter`: it falls back to `TSTORE`, which ignores the index tile entirely (PTOToEmitC.cpp:4402). The PyPTO-side IR generation is verified correct via unit tests; enable the ST tests once PTOAS ships a proper MSCATTER lowering.

Testing
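The fallback gap can be modeled in NumPy (a hypothetical reference, not the PTOAS code): a contiguous `TSTORE`-style write that ignores the index tile only matches a true scatter when the indices happen to be the identity permutation:

```python
import numpy as np

src = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32)
idx = np.array([3, 2, 1, 0], dtype=np.int32)  # reversed permutation

# True mscatter semantics: out[idx[i]] = src[i]
scatter_out = np.zeros(4, dtype=np.float32)
scatter_out[idx] = src                         # -> [40, 30, 20, 10]

# TSTORE-style fallback: plain contiguous store, index tile ignored
tstore_out = src.copy()                        # -> [10, 20, 30, 40]

# Any non-identity index pattern exposes the difference, which is why
# the ST tests stay skipped until a real lowering exists.
outputs_match = np.array_equal(scatter_out, tstore_out)  # False
```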
Fixes #921