feat(op): implement tensor.scatter_ element-level scatter #898

Little-oil wants to merge 3 commits into hw-native-sys:main from
Conversation
…w-native-sys#677) Add scatter_ op following PyTorch torch.Tensor.scatter_ semantics. Decomposes into nested ForStmt + tile.read/tile.write loops in the ConvertTensorToTileOps pass. Includes ND→2D index flattening, PTO barrier codegen, ConstInt dtype dispatch, and unused Out param reuse for implicit tile.store insertion.
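For reference, the PyTorch-style semantics that the pass lowers into nested loops can be sketched with plain Python lists. This is an illustrative 2D model (the names here are not the DSL API), showing exactly which element each `(i, j)` position of `index` redirects:

```python
def scatter_ref(dst, dim, index, src):
    """Reference semantics of scatter_ for 2D nested lists.

    For each (i, j) in index:
      dim == 0: dst[index[i][j]][j] = src[i][j]
      dim == 1: dst[i][index[i][j]] = src[i][j]
    """
    out = [row[:] for row in dst]  # copy; the real op mutates in place
    for i in range(len(index)):
        for j in range(len(index[0])):
            if dim == 0:
                out[index[i][j]][j] = src[i][j]
            else:
                out[i][index[i][j]] = src[i][j]
    return out

zeros = [[0] * 3 for _ in range(3)]
result = scatter_ref(zeros, 0, [[0, 1, 2], [2, 0, 1]], [[1, 2, 3], [4, 5, 6]])
# result == [[1, 5, 0], [0, 2, 6], [4, 0, 3]]
```

The two nested `for` loops correspond to the `ForStmt` nest the conversion pass emits, with the body mapping to a `tile.read` of index/src and a `tile.write` at the redirected coordinate.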
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings.
📝 Walkthrough

This PR adds a new tensor operation, `tensor.scatter_`.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant API as Language API
    participant IR as IR Builder
    participant Type as Type Deduction
    participant Conv as Tensor→Tile Conversion
    participant CG as PTO Codegen
    participant BE as PTO Backend
    API->>IR: scatter_(input, dim, index, src, reduce?)
    IR->>IR: parse args, normalize dim, wrap scalar src
    IR->>Type: create `tensor.scatter_` Call
    Type->>Type: validate ranks, index dtype, dim range -> TensorType
    Conv->>Conv: detect `tensor.scatter_`, load index/src as tiles if needed
    Conv->>Conv: build nested ForStmt loops, read index/src, compute write indices
    Conv->>Conv: tile.write (with optional reduce using tile.read)
    Conv->>BE: insert system.bar_v per write and system.bar_all around loop
    CG->>CG: emit ConstInt constants per integer dtype (i32/i64/typed arith.constant)
    CG->>BE: emit MLIR SSA ops
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~80 minutes

Possibly related issues
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Code Review
This pull request introduces the scatter_ operation, which implements element-level scatter semantics similar to PyTorch. The implementation includes the core IR operation, Python DSL bindings, and a conversion pass that lowers the tensor-level operation into nested loops with tile.read and tile.write. Additionally, the PR improves N-D tile handling in the flattening pass and enhances PTO codegen for integer constants and system barriers. Feedback focuses on ensuring the reduce argument is correctly propagated through the Python IR layer, registering it as an official attribute in the C++ op definition, and refactoring duplicated index-flattening logic into a helper method.
def scatter_(
    input: Expr,
    *args: Expr | int | float,
    dim: int | Expr | None = None,
    index: Expr | None = None,
    src: Expr | float | int | None = None,
    span: Span | None = None,
) -> Call:
    """Element-level scatter into tensor along a dimension.

    For each position (i₀,…,iₙ) in index, sets:
        input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ]

    Follows PyTorch ``torch.Tensor.scatter_`` semantics.

    Accepts call forms:
    - scatter_(input, dim, index, src)
    - scatter_(input, dim, index, src=1.0)

    Args:
        input: Destination tensor (N-D).
        dim: Dimension along which to scatter.
        index: Index tensor (same rank as input, integer dtype).
        src: Source tensor (same shape as index) or scalar value.
        span: Optional source span for debugging (auto-captured if not provided).

    Returns:
        Call expression returning the updated input tensor.
    """
    if len(args) == 3 and dim is None and index is None and src is None:
        dim, index, src = args
    elif len(args) == 2 and dim is not None and index is None and src is None:
        index, src = args
    elif len(args) == 1 and dim is None and index is not None and src is not None:
        dim = args[0]
    elif len(args) != 0:
        raise TypeError(
            "scatter_ expects (input, dim, index, src), "
            "(input, index, src, dim=...), or (input, dim, index=..., src=...)"
        )

    if dim is None or index is None or src is None:
        raise TypeError("scatter_ requires input, dim, index, and src")

    actual_span = _get_span_or_capture(span)
    if isinstance(dim, ConstInt):
        dim_val = int(dim.value)
    elif isinstance(dim, int):
        dim_val = dim
    else:
        raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")

    if not isinstance(index, Expr):
        raise TypeError(f"index must be Expr, got {type(index)}")

    # src can be Expr or scalar (int/float → ConstFloat)
    if isinstance(src, (int, float)):
        src = ConstFloat(float(src), DataType.FP32, actual_span)
    elif not isinstance(src, Expr):
        raise TypeError(f"src must be Expr or scalar, got {type(src)}")

    op_args: list[Expr] = [input, index, src]
    kwargs: dict[str, Any] = {"dim": dim_val}
    return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span)
The language layer pl.scatter_ calls this function with a reduce keyword argument, but this function's signature doesn't accept it. This will lead to a TypeError at runtime. The reduce argument should be added to the signature and passed to create_op_call. Additionally, ensure that user-provided arguments are validated at this level to provide clear error messages.
def scatter_(
    input: Expr,
    *args: Expr | int | float,
    dim: int | Expr | None = None,
    index: Expr | None = None,
    src: Expr | float | int | None = None,
    reduce: str | None = None,
    span: Span | None = None,
) -> Call:
    """Element-level scatter into tensor along a dimension.

    For each position (i₀,…,iₙ) in index, sets:
        input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ]

    Follows PyTorch ``torch.Tensor.scatter_`` semantics.

    Accepts call forms:
    - scatter_(input, dim, index, src)
    - scatter_(input, dim, index, src=1.0)

    Args:
        input: Destination tensor (N-D).
        dim: Dimension along which to scatter.
        index: Index tensor (same rank as input, integer dtype).
        src: Source tensor (same shape as index) or scalar value.
        reduce: Optional reduction mode ("add" or "multiply").
        span: Optional source span for debugging (auto-captured if not provided).

    Returns:
        Call expression returning the updated input tensor.
    """
    if len(args) == 3 and dim is None and index is None and src is None:
        dim, index, src = args
    elif len(args) == 2 and dim is not None and index is None and src is None:
        index, src = args
    elif len(args) == 1 and dim is None and index is not None and src is not None:
        dim = args[0]
    elif len(args) != 0:
        raise TypeError(
            "scatter_ expects (input, dim, index, src), "
            "(input, index, src, dim=...), or (input, dim, index=..., src=...)"
        )

    if dim is None or index is None or src is None:
        raise TypeError("scatter_ requires input, dim, index, and src")

    actual_span = _get_span_or_capture(span)
    if isinstance(dim, ConstInt):
        dim_val = int(dim.value)
    elif isinstance(dim, int):
        dim_val = dim
    else:
        raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")

    if not isinstance(index, Expr):
        raise TypeError(f"index must be Expr, got {type(index)}")

    # src can be Expr or scalar (int/float → ConstFloat)
    if isinstance(src, (int, float)):
        src = ConstFloat(float(src), DataType.FP32, actual_span)
    elif not isinstance(src, Expr):
        raise TypeError(f"src must be Expr or scalar, got {type(src)}")

    op_args: list[Expr] = [input, index, src]
    kwargs: dict[str, Any] = {"dim": dim_val}
    if reduce is not None:
        kwargs["reduce"] = reduce
    return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span)

References
- Validate user-provided arguments for DSL functions at the parser level to provide early and clear error messages, rather than relying solely on backend C++ validation.
Fixed: pl.scatter_() now passes reduce kwarg through both the IR and language layers.
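The reduce modes mirror `torch.Tensor.scatter_(..., reduce=...)`: instead of overwriting, the scattered value is accumulated into the destination (which is why the lowering pairs each `tile.write` with a `tile.read` when reduce is set). A minimal sketch of the intended 2D semantics, in illustrative Python rather than the DSL:

```python
def scatter_reduce_ref(dst, dim, index, src, reduce=None):
    """2D reference: reduce=None overwrites, "add" accumulates, "multiply" scales."""
    out = [row[:] for row in dst]
    for i in range(len(index)):
        for j in range(len(index[0])):
            ti, tj = (index[i][j], j) if dim == 0 else (i, index[i][j])
            if reduce == "add":
                out[ti][tj] += src[i][j]
            elif reduce == "multiply":
                out[ti][tj] *= src[i][j]
            else:  # plain scatter_: last write wins
                out[ti][tj] = src[i][j]
    return out

base = [[1, 2], [3, 4]]
added = scatter_reduce_ref(base, 1, [[0, 0], [1, 1]], [[10, 20], [30, 40]], reduce="add")
# both updates in each row hit the same column, so they accumulate: [[31, 2], [3, 74]]
```

Note that with duplicate indices the accumulating modes are order-insensitive for "add"/"multiply", while the overwrite mode depends on iteration order, matching PyTorch's documented caveat.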
REGISTER_OP("tensor.scatter_")
    .set_op_category("TensorOp")
    .set_description(
        "Element-level scatter: write src values into input at positions given by index along dim. "
        "For each element position (i₀,…,iₙ) in index, sets "
        "input[i₀]…[i_{d-1}][index[i₀…iₙ]][i_{d+1}]…[iₙ] = src[i₀…iₙ]. "
        "Supports arbitrary rank and any valid dim ∈ [-rank, rank). "
        "src can be a tensor (same shape as index) or a scalar value.")
    .add_argument("input", "Destination tensor (N-D)")
    .add_argument("index", "Index tensor (N-D, same rank as input) of integer dtype")
    .add_argument("src", "Source tensor (same shape as index) or scalar value")
    .set_attr<int>("dim")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceTensorScatterType(args, kwargs);
    });
The reduce functionality is a core part of tensor.scatter_, as implemented in the conversion pass and exposed in the Python layers. For consistency with the dim attribute and for better IR clarity, reduce should be registered as an optional string attribute of the op. This makes the op's definition self-contained and improves discoverability.
REGISTER_OP("tensor.scatter_")
    .set_op_category("TensorOp")
    .set_description(
        "Element-level scatter: write src values into input at positions given by index along dim. "
        "For each element position (i₀,…,iₙ) in index, sets "
        "input[i₀]…[i_{d-1}][index[i₀…iₙ]][i_{d+1}]…[iₙ] = src[i₀…iₙ]. "
        "Supports arbitrary rank and any valid dim ∈ [-rank, rank). "
        "src can be a tensor (same shape as index) or a scalar value. "
        "Also supports 'add' and 'multiply' reduction modes.")
    .add_argument("input", "Destination tensor (N-D)")
    .add_argument("index", "Index tensor (N-D, same rank as input) of integer dtype")
    .add_argument("src", "Source tensor (same shape as index) or scalar value")
    .set_attr<int>("dim")
    .set_attr<std::string>("reduce", /*is_optional=*/true, /*default_value=*/"none")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceTensorScatterType(args, kwargs);
    });
Fixed: scatter.cpp now stores the reduce attribute via set_attr.
// ---- tile.read/tile.write on >2D tiles: flatten ND indices to 2D ----
// tile.read(tile, (i0, i1, ..., in)) → tile.read(tile_2d, (merged_row, col))
//   where merged_row = i0 * d1 * d2 * ... * d_{n-2} + i1 * d2 * ... + i_{n-2}
//   and col = i_{n-1}
// tile.write(tile, (i0, ..., in), val) → tile.write(tile_2d, (merged_row, col), val)
if (op_name == "tile.read" || op_name == "tile.write") {
  auto orig_tile_type = As<TileType>(call->args_[0]->GetType());
  if (orig_tile_type && IsNdTile(orig_tile_type)) {
    std::vector<ExprPtr> new_args;
    new_args.reserve(call->args_.size());
    // args[0]: tile (substitute)
    new_args.push_back(Substitute(call->args_[0], ctx.var_map));
    // args[1]: indices tuple — flatten from ND to 2D
    auto idx_tuple = As<MakeTuple>(call->args_[1]);
    INTERNAL_CHECK(idx_tuple) << "tile.read/tile.write indices must be MakeTuple";
    const auto& nd_shape = orig_tile_type->shape_;
    const size_t rank = nd_shape.size();
    // Compute merged_row = sum of i_k * product(nd_shape[k+1..rank-2]) for k in [0..rank-2)
    ExprPtr merged_row;
    for (size_t k = 0; k + 1 < rank; ++k) {
      ExprPtr term = Substitute(idx_tuple->elements_[k], ctx.var_map);
      // Multiply by trailing dimensions (excluding last)
      for (size_t j = k + 1; j + 1 < rank; ++j) {
        term = MakeMul(term, nd_shape[j], span);
      }
      merged_row = merged_row ? MakeAdd(merged_row, term, span) : term;
    }
    ExprPtr col = Substitute(idx_tuple->elements_[rank - 1], ctx.var_map);
    new_args.push_back(std::make_shared<MakeTuple>(
        std::vector<ExprPtr>{merged_row, col}, span));
    // Remaining args (e.g., value for tile.write)
    for (size_t i = 2; i < call->args_.size(); ++i) {
      new_args.push_back(Substitute(call->args_[i], ctx.var_map));
    }
    auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
    if (op_name == "tile.read") {
      // tile.read returns scalar — assign to var
      auto new_var = std::make_shared<Var>(
          assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
      result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
      ctx.Insert(assign->var_, new_var);
    } else {
      // tile.write returns tile — assign to var and update mapping
      auto new_var = std::make_shared<Var>(
          assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
      result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
      ctx.Insert(assign->var_, new_var);
    }
    continue;
  }
}
The logic for flattening N-D indices to 2D for tile.read and tile.write is duplicated here and in the EvalStmt handler for tile.write around line 411. This duplicated logic for calculating merged_row should be extracted into a private helper method to improve maintainability and avoid future inconsistencies.
References
- Extract duplicate logic into a private helper method to improve maintainability and avoid future inconsistencies.
Fixed: extracted FlattenNdIndicesToTwoD helper to eliminate duplicated code.
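The ND→2D mapping that helper implements can be checked independently of the IR: an index (i₀,…,iₙ₋₁) into a tile of shape (d₀,…,dₙ₋₁) maps to row = Σₖ iₖ·∏(dₖ₊₁…dₙ₋₂) and col = iₙ₋₁. A small Python sketch of that math (illustrative only, mirroring the formula in the pass comments):

```python
from itertools import product

def flatten_nd_to_2d(idx, shape):
    """Map an N-D tile index to (row, col) in its 2D flattening.

    row = sum_k idx[k] * prod(shape[k+1 : -1]); col = idx[-1].
    """
    row = 0
    for k in range(len(shape) - 1):
        stride = 1
        for d in shape[k + 1:-1]:
            stride *= d
        row += idx[k] * stride
    return row, idx[-1]

# Sanity check against row-major linearization for a (2, 3, 4) tile
# flattened to (6, 4): the mapped (row, col) addresses the same element.
shape = (2, 3, 4)
for i0, i1, i2 in product(*(range(d) for d in shape)):
    row, col = flatten_nd_to_2d((i0, i1, i2), shape)
    assert row * shape[-1] + col == (i0 * 3 + i1) * 4 + i2
```

For rank 2 the merged row degenerates to i₀ (all strides are empty products), so the transform is a no-op on already-2D tiles.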
Actionable comments posted: 10
🧹 Nitpick comments (1)
tests/st/runtime/test_scatter.py (1)
268-319: Please add one 3D non-last-dim runtime case. This is the only ND hardware test, and dim=2 keeps the scattered index in the final column. The new ND→2D flattening logic is riskier when the scattered dimension participates in merged_row (dim=0 or dim=1 on 3D), so one case there would exercise the new math end to end.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/st/runtime/test_scatter.py` around lines 268 - 319, Add a sibling 3D test where the scatter dimension is not the last axis (e.g., dim=1) to exercise ND→2D flattening when the scattered dim participates in merged_row: copy Scatter3dDim2Program and Scatter3dDim2TestCase to create Scatter3dDim1Program and Scatter3dDim1TestCase, update the kernel/orchestrator to call pl.scatter_(..., dim=1, index=index, src=src), update compute_expected to use expected.scatter_(1, tensors["index"].long(), tensors["src"].float()), and ensure the test uses a proper index init constant (e.g., _IDX_3D_DIM1) with appropriate values selecting positions in the middle dimension.
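As a sanity model for the suggested dim=1 case, a generic N-D reference over nested Python lists (illustrative only, not the DSL or test harness) shows a 3D scatter touching the middle axis:

```python
from itertools import product
import copy

def scatter_nd_ref(dst, dim, index, src):
    """N-D scatter_ reference over nested Python lists."""
    out = copy.deepcopy(dst)

    def shape(a):
        return (len(a),) + shape(a[0]) if isinstance(a, list) else ()

    def get(a, idx):
        for i in idx:
            a = a[i]
        return a

    def put(a, idx, v):
        for i in idx[:-1]:
            a = a[i]
        a[idx[-1]] = v

    # every position of index redirects one write along `dim`
    for pos in product(*(range(d) for d in shape(index))):
        tgt = list(pos)
        tgt[dim] = get(index, pos)
        put(out, tuple(tgt), get(src, pos))
    return out
```

With dim=1 on a (2, 2, 2) tensor, the index value replaces the middle coordinate, which is exactly the case where the scattered dimension lands inside merged_row after flattening.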
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/en/user/02-operation_reference.md`:
- Line 46: The docs for scatter_ list src as `Tensor | float` but the Python API
accepts integer scalars too; update the operation signature and descriptive text
for `scatter_` so the scalar case includes integers (e.g., `Tensor | int |
float`) and clarify “src can be a tensor (same shape as `index`) or a scalar
(int or float)”; make this change on the `scatter_` entry so the documented API
surface matches the actual `scatter_` Python entry points.
In `@docs/zh-cn/user/02-operation_reference.md`:
- Line 43: The Chinese doc for scatter_ currently lists src as `Tensor | float`
but the implementation accepts integer scalars too; update the signature in the
`scatter_` entry to `Tensor | int | float` (matching the English reference and
the implementation) so the zh-CN scalar type is in sync with the actual API, and
ensure the descriptive sentence mentions that `src` may be a tensor or a scalar
(int or float).
In `@python/pypto/ir/op/tensor_ops.py`:
- Around line 1000-1004: The current coercion in tensor_ops.py turns integer
scalar inputs into ConstFloat(DataType.FP32) which is wrong; change the handling
for src so that ints become ConstInt (e.g. ConstInt(int(src), appropriate
integer dtype, actual_span)) instead of ConstFloat, leave floats as ConstFloat,
and keep the isinstance(src, Expr) check; after wrapping, validate and/or cast
the resulting ConstInt/ConstFloat against the destination/input tensor dtype
(use the input's dtype validation codepath already used elsewhere) so integer
literals are preserved and properly converted to the tensor's dtype rather than
always becoming FP32; update references to ConstFloat, ConstInt, DataType.FP32
and the src variable to implement this logic.
- Around line 945-952: The wrapper forwards a reduce=... kwarg into the IR
builder but the IR builder function scatter_ (in
python/pypto/ir/op/tensor_ops.py) does not accept it, causing a TypeError; fix
by adding a reduce parameter to scatter_ (e.g., reduce: Expr | str | None =
None) and propagate/validate it into the Call construction (or explicitly coerce
allowed values like "sum"/"mean"/"prod"), or alternatively modify the language
wrapper to stop forwarding reduce; update the scatter_ signature and any
internal packing/args handling so reduce is accepted, validated, and passed
through to whatever Call or node creation uses the other params (input, dim,
index, src, span) to keep pl.scatter_ working end-to-end.
In `@python/pypto/language/op/tensor_ops.py`:
- Around line 785-813: The scatter_ implementation currently only unwraps Tensor
sources, so when src is a DSL Scalar wrapper it gets forwarded as the wrapper
instead of the underlying IR expression; update scatter_ (function scatter_) to
detect and unwrap Scalar as well (i.e., if isinstance(src, Tensor) use
src.unwrap(), elif isinstance(src, Scalar) use src.unwrap(), else pass the raw
Python scalar) before calling _ir_ops.scatter_, and ensure the Scalar symbol is
imported where needed.
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 959-965: MakeSystemBarrierCodegenPTO currently ignores any
operands on the `op` and always emits a barrier; add a fast-fail guard that
validates the call is zero-argument and reports a clear error if not. Inside
MakeSystemBarrierCodegenPTO (using the `CallPtr& op`), check the
operand/argument count (e.g., `op->args().size()` or `op->numArgs()` /
`op->operands().size()` depending on your CallPtr API) and if it is non-zero
throw or log a fatal error (e.g., throw std::runtime_error with a message that
includes the op identity or text) instead of emitting the barrier; keep the
existing emit (`codegen.Emit("pto.barrier `#pto.pipe`<" + pipe_name + ">");`) only
when the argument count is zero.
In `@src/codegen/pto/pto_codegen.cpp`:
- Around line 909-915: GetExprAsCode() currently maps only INT32 ConstInt to
GetOrEmitI32Constant and sends every other integer to GetIndexConstant, causing
MLIR type mismatches versus GetExprTypeAnnotation(); update the ConstInt branch
in GetExprAsCode() to mirror the full typed-constant handling used in the
ConstInt visitor: inspect const_int->dtype() and dispatch to the appropriate
emitter (e.g., GetOrEmitI32Constant, GetOrEmitI64Constant,
GetOrEmitUI64Constant, GetOrEmitI8Constant, GetOrEmitUI8Constant, etc.) for each
DataType rather than defaulting to GetIndexConstant so the emitted constant type
matches the annotated MLIR type.
In `@src/ir/op/tensor_ops/scatter.cpp`:
- Around line 55-85: DeduceTensorScatterType currently accepts many invalid
src/index combos; update it (inside DeduceTensorScatterType near use of args,
index_type, src_is_scalar, src_type, kwargs and dim handling) to additionally
validate: for non-scalar src, require src_type->shape_.size() == rank (already
checked) and enforce that for every axis k != dim the src_type->shape_[k] ==
index_type->shape_[k] and also src_type->shape_[k] <= input_type->shape_[k]; for
the scatter axis require src_type->shape_[dim] <= input_type->shape_[dim]
(reject if larger); for scalar src (src_is_scalar) check the literal dtype is
implicitly convertible/writable into input_type->dtype_ (use the same dtype
comparison/convertibility logic used elsewhere) and reject incompatible scalar
dtypes; keep using the existing CHECK macro and error wording consistent with
other messages.
In `@src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp`:
- Around line 1763-1783: The code only checks unused_out_params.begin() for a
matching Out param, causing missed compatible params later in the set; change
the logic in the block handling unused_out_params so you iterate through all
entries in unused_out_params (e.g., a for-loop over unused_out_params) and for
each candidate verify dtype and shape equality (use As<TensorType>(*) and
As<ConstInt> on dimensions, comparing values as in the existing checks); when
you find the first matching entry, set out_param, set is_existing_param = true
and erase that entry from unused_out_params, then break out of the loop.
- Around line 723-733: The shortcut only handles As<Var>(new_result) but misses
when new_result is an IterArg (e.g. scatter_ returning the loop-carried
IterArg), so detect IterArg aliases as well: after the As<Var> branch, check if
new_result can be cast to IterArg (As<IterArg>(new_result)) and record the same
aliasing in var_remap_ (mapping the original op->var_ to that IterArg alias),
then remove the redundant assignment from stmts and return the prologue SeqStmts
exactly as done for the Var case so the synthetic assignment is suppressed for
IterArg results too.
---
Nitpick comments:
In `@tests/st/runtime/test_scatter.py`:
- Around line 268-319: Add a sibling 3D test where the scatter dimension is not
the last axis (e.g., dim=1) to exercise ND→2D flattening when the scattered dim
participates in merged_row: copy Scatter3dDim2Program and Scatter3dDim2TestCase
to create Scatter3dDim1Program and Scatter3dDim1TestCase, update the
kernel/orchestrator to call pl.scatter_(..., dim=1, index=index, src=src),
update compute_expected to use expected.scatter_(1, tensors["index"].long(),
tensors["src"].float()), and ensure the test uses a proper index init constant
(e.g., _IDX_3D_DIM1) with appropriate values selecting positions in the middle
dimension.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2cd416c7-1b1f-432b-a1f9-209749e81ef3
📒 Files selected for processing (18)
- CMakeLists.txt
- docs/en/user/02-operation_reference.md
- docs/zh-cn/user/02-operation_reference.md
- python/pypto/debug/torch_codegen.py
- python/pypto/ir/op/tensor_ops.py
- python/pypto/language/__init__.py
- python/pypto/language/op/__init__.py
- python/pypto/language/op/tensor_ops.py
- src/backend/common/pto_ops_common.cpp
- src/codegen/pto/pto_codegen.cpp
- src/codegen/pto/pto_scalar_expr_codegen.cpp
- src/ir/op/tensor_ops/scatter.cpp
- src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
- src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
- src/ir/transforms/op_conversion_registry.cpp
- tests/st/runtime/test_scatter.py
- tests/ut/ir/operators/test_tensor_ops.py
- tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
auto index_type = As<TensorType>(args[1]->GetType());
CHECK(index_type) << "tensor.scatter_: index must be TensorType, got " << args[1]->GetType()->TypeName();
CHECK(index_type->shape_.size() == rank)
    << "tensor.scatter_: index rank (" << index_type->shape_.size() << ") must match input rank (" << rank
    << ")";
CHECK(index_type->dtype_.IsInt()) << "tensor.scatter_: index dtype must be integer, got "
                                  << index_type->dtype_.ToString();

// src can be TensorType or scalar (ConstFloat / ConstInt)
bool src_is_scalar = As<ConstFloat>(args[2]) || As<ConstInt>(args[2]);
if (!src_is_scalar) {
  auto src_type = As<TensorType>(args[2]->GetType());
  CHECK(src_type) << "tensor.scatter_: src must be TensorType or scalar, got "
                  << args[2]->GetType()->TypeName();
  CHECK(src_type->shape_.size() == rank)
      << "tensor.scatter_: src rank (" << src_type->shape_.size() << ") must match input rank (" << rank
      << ")";
  CHECK(src_type->dtype_ == input_type->dtype_)
      << "tensor.scatter_: src dtype (" << src_type->dtype_.ToString() << ") must match input dtype ("
      << input_type->dtype_.ToString() << ")";
}

// Validate dim kwarg
for (const auto& [key, val] : kwargs) {
  if (key == "dim") {
    int dim_val = AnyCast<int>(val, "kwarg key: dim");
    int irank = static_cast<int>(rank);
    CHECK(dim_val >= -irank && dim_val < irank)
        << "tensor.scatter_: dim must be in [" << -irank << ", " << irank << ") for " << rank
        << "D input, got " << dim_val;
  }
Reject impossible scatter shapes and scalar/input dtype pairs here.
DeduceTensorScatterType currently accepts any tensor src with the right rank and any scalar literal at all. That means obviously invalid programs—e.g. index larger than input on a non-scatter axis, tensor src smaller than the index iteration space, or a scalar literal whose dtype cannot be written into input—pass type deduction and only fail later in the elementwise lowering/codegen path. Tighten those checks before returning the result TensorType.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/ir/op/tensor_ops/scatter.cpp` around lines 55 - 85,
DeduceTensorScatterType currently accepts many invalid src/index combos; update
it (inside DeduceTensorScatterType near use of args, index_type, src_is_scalar,
src_type, kwargs and dim handling) to additionally validate: for non-scalar src,
require src_type->shape_.size() == rank (already checked) and enforce that for
every axis k != dim the src_type->shape_[k] == index_type->shape_[k] and also
src_type->shape_[k] <= input_type->shape_[k]; for the scatter axis require
src_type->shape_[dim] <= input_type->shape_[dim] (reject if larger); for scalar
src (src_is_scalar) check the literal dtype is implicitly convertible/writable
into input_type->dtype_ (use the same dtype comparison/convertibility logic used
elsewhere) and reject incompatible scalar dtypes; keep using the existing CHECK
macro and error wording consistent with other messages.
Fixed: added dtype pair validation in scatter.cpp — float src requires float input, int src requires int input.
🧩 Analysis chain
@coderabbit-user
Thanks for the quick update — I see you’ve added the scalar dtype pair validation (float→float, int→int). That covers the dtype part of the request.
Before we can mark this thread as addressed, can you also add the remaining shape validations in DeduceTensorScatterType?
- For non-scalar src:
- For every axis k ≠ dim: enforce src_type->shape_[k] == index_type->shape_[k] and src_type->shape_[k] <= input_type->shape_[k].
- For the scatter axis dim: enforce src_type->shape_[dim] <= input_type->shape_[dim].
- For scalar src: you’ve added dtype checks; that’s good. If your project has a central “is-writable-to” helper for dtype convertibility, using it here would be ideal; otherwise the float↔int disallow rule you added is acceptable.
If helpful, here’s a targeted patch you can adapt inside DeduceTensorScatterType right after you parse dim and have input_type/index_type/src_is_scalar:
// Normalize dim to [0, rank)
int dim_norm = 0;
for (const auto& [key, val] : kwargs) {
  if (key == "dim") {
    int d = AnyCast<int>(val, "kwarg key: dim");
    int irank = static_cast<int>(rank);
    CHECK(d >= -irank && d < irank) << "tensor.scatter_: dim must be in [" << -irank << ", " << irank
                                    << ") for " << rank << "D input, got " << d;
    dim_norm = d < 0 ? d + irank : d;
    break;
  }
}
// Additional shape checks
if (!src_is_scalar) {
  auto src_type = As<TensorType>(args[2]->GetType());
  // already CHECKed rank and dtype match earlier
  for (size_t k = 0; k < rank; ++k) {
    const auto& idx_dim = index_type->shape_[k];
    const auto& in_dim = input_type->shape_[k];
    const auto& src_dim = src_type->shape_[k];
    if (static_cast<int>(k) != dim_norm) {
      CHECK(src_dim == idx_dim)
          << "tensor.scatter_: src shape must match index shape on non-scatter axes; "
          << "mismatch at axis " << k << " (src " << src_dim.ToString() << " vs index "
          << idx_dim.ToString() << ")";
      CHECK(src_dim <= in_dim)
          << "tensor.scatter_: src extent exceeds input extent at axis " << k << " ("
          << src_dim.ToString() << " > " << in_dim.ToString() << ")";
    } else {
      CHECK(src_dim <= in_dim)
          << "tensor.scatter_: src extent exceeds input extent on scatter axis " << dim_norm
          << " (" << src_dim.ToString() << " > " << in_dim.ToString() << ")";
    }
  }
} else {
  // scalar src: ensure scalar dtype is writable to input dtype
  if (As<ConstFloat>(args[2])) {
    CHECK(input_type->dtype_.IsFloat())
        << "tensor.scatter_: float scalar src requires float input dtype, got "
        << input_type->dtype_.ToString();
  } else if (As<ConstInt>(args[2])) {
    CHECK(input_type->dtype_.IsInt())
        << "tensor.scatter_: int scalar src requires int input dtype, got "
        << input_type->dtype_.ToString();
  }
}

Once these guards are in, invalid shapes will be rejected at type deduction instead of failing later in lowering/codegen. Let me know if you'd like me to expand this into a full diff for the file.
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Locate scatter.cpp =="
fd -a --strip-cwd-prefix 'scatter.cpp'
echo
echo "== Show DeduceTensorScatterType signature and surrounding logic =="
rg -n -C4 --type=cpp 'DeduceTensorScatterType\(' src/ir/op/tensor_ops/scatter.cpp || true
echo
echo "== Show checks around src/type handling (scalar vs tensor) =="
rg -n -C3 --type=cpp 'src_is_scalar|As<ConstFloat>|As<ConstInt>|src_type|dtype_' src/ir/op/tensor_ops/scatter.cpp || true
echo
echo "== Show dim kwarg handling =="
rg -n -C3 --type=cpp 'key == "dim"|dim_val' src/ir/op/tensor_ops/scatter.cpp || true
Length of output: 2740
@Little-oil, thanks — I've confirmed the scalar dtype checks are now in place (lines 75–83). That part is done ✅
However, the per-axis shape constraints for non-scalar src are still missing. The current code only checks that src rank matches and src dtype matches input, but does not verify the extents are compatible:
| Axis | Missing check |
|---|---|
| k ≠ dim | src_type->shape_[k] == index_type->shape_[k] |
| k ≠ dim | src_type->shape_[k] <= input_type->shape_[k] |
| k == dim | src_type->shape_[k] <= input_type->shape_[k] |
Without these, a caller can pass a src that is larger than input on the scatter axis (or mismatched with index on other axes) and the error won't surface until deep in the loop-based lowering in op_conversion_registry.cpp.
Here is a targeted patch to add inside the `if (!src_is_scalar)` block (after retrieving `dim_norm`):

```cpp
// Normalize dim once so we can do per-axis checks
int dim_norm = 0;
for (const auto& [key, val] : kwargs) {
  if (key == "dim") {
    int d = AnyCast<int>(val, "kwarg key: dim");
    dim_norm = d < 0 ? d + static_cast<int>(rank) : d;
    break;
  }
}
if (!src_is_scalar) {
  // ... existing rank/dtype checks ...
  // Per-axis shape compatibility
  for (size_t k = 0; k < rank; ++k) {
    if (static_cast<int>(k) != dim_norm) {
      CHECK(src_type->shape_[k] == index_type->shape_[k])
          << "tensor.scatter_: src shape must equal index shape on non-scatter axis " << k
          << " (src=" << src_type->shape_[k].ToString()
          << ", index=" << index_type->shape_[k].ToString() << ")";
      CHECK(src_type->shape_[k] <= input_type->shape_[k])
          << "tensor.scatter_: src extent exceeds input extent on non-scatter axis " << k
          << " (src=" << src_type->shape_[k].ToString()
          << ", input=" << input_type->shape_[k].ToString() << ")";
    } else {
      CHECK(src_type->shape_[k] <= input_type->shape_[k])
          << "tensor.scatter_: src extent exceeds input extent on scatter axis " << dim_norm
          << " (src=" << src_type->shape_[k].ToString()
          << ", input=" << input_type->shape_[k].ToString() << ")";
    }
  }
}
```

Note: you'll want to move the `dim_norm` computation before the `if (!src_is_scalar)` block, and de-duplicate it with the existing dim range-check loop (lines 86–94) so dim is normalized and range-checked in one pass.
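The per-axis constraints in the table above can also be mirrored by a small Python reference check (a standalone sketch for illustration — `check_scatter_shapes` is a hypothetical helper, not repo code):

```python
def check_scatter_shapes(input_shape, index_shape, src_shape, dim):
    """Validate the per-axis constraints tensor.scatter_ should enforce.

    On non-scatter axes: src must equal index and fit inside input.
    On the scatter axis: src only needs to fit inside input.
    """
    rank = len(input_shape)
    dim = dim + rank if dim < 0 else dim  # normalize negative dim
    assert len(index_shape) == rank and len(src_shape) == rank, "rank mismatch"
    for k in range(rank):
        if k != dim:
            assert src_shape[k] == index_shape[k], f"src/index mismatch on axis {k}"
            assert src_shape[k] <= input_shape[k], f"src exceeds input on axis {k}"
        else:
            assert src_shape[k] <= input_shape[k], f"src exceeds input on scatter axis {dim}"
```

For example, `check_scatter_shapes([4, 8], [2, 8], [2, 8], 0)` passes, while a `src` of shape `[2, 4]` against `index` of `[2, 8]` fails on the non-scatter axis.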
🧠 Learnings used
Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 431
File: src/ir/op/tensor_ops/broadcast.cpp:56-79
Timestamp: 2026-03-10T09:46:22.193Z
Learning: In `src/ir/op/tensor_ops/broadcast.cpp` and `src/ir/op/tile_ops/broadcast.cpp` (pypto project), the rank checks in `DeduceTensorRowExpandType`, `DeduceTensorColExpandType`, and their tile-level equivalents intentionally use `>= 2` (not `== 2`). This allows row_expand_mul and col_expand_mul to operate on higher-rank tensors (e.g., `[B, M, N]`), applying the broadcast over the last two dimensions. Do not flag this as a bug or suggest tightening to exact rank-2 checks.
Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/op/tensor_ops/matmul.cpp:133-138
Timestamp: 2026-04-01T08:56:12.231Z
Learning: In hw-native-sys/pypto `src/ir/op/tensor_ops/matmul.cpp`, `tensor.matmul` is strictly 2D-only and `tensor.batch_matmul` requires both operands to have rank >= 3. Mixed-rank broadcasting (e.g., 2D×3D) is intentionally out of scope and requires explicit unsqueeze by the user. Do not flag the dual rank >= 3 guards in `DeduceTensorBatchMatMulType` as a missing-broadcast-case bug.
Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/transforms/op_conversion_registry.cpp:307-312
Timestamp: 2026-04-02T02:12:12.162Z
Learning: In hw-native-sys/pypto `src/ir/transforms/op_conversion_registry.cpp` (around lines 307–312), the `tensor.batch_matmul` conversion lambda intentionally requires both `lhs_shape.size() >= 3` and `rhs_shape.size() >= 3`. This mirrors the rank >= 3 constraint in `DeduceTensorBatchMatMulType` in `src/ir/op/tensor_ops/matmul.cpp`. Mixed-rank 2D×3D broadcasting is explicitly out of scope for this PR; 2D inputs must use `tensor.matmul`. Do not flag these guards as an over-restriction of the batch matmul conversion.
Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 405
File: src/ir/op/sync_ops/cross_core.cpp:57-67
Timestamp: 2026-03-08T15:07:33.101Z
Learning: In `src/ir/op/sync_ops/cross_core.cpp` and `python/pypto/ir/op/system_ops.py` (pypto project), the zero-argument `system.tpop_from_aic` and `system.tpop_from_aiv` ops intentionally fall back to `UnknownType` via `create_op_call` when no `result_type`/`shape`/`dtype` is provided. The C++ `ExpandMixedKernel` pass emits these tpop calls without explicit types; the concrete result type is resolved from the LHS variable's type annotation during DSL parsing. Rejecting no-type calls at the Python IR layer would break the pass. The existing validation in `_resolve_tpop_type` (mutual exclusivity of `result_type` vs `shape`/`dtype`, both-or-neither for `shape`+`dtype`) is sufficient.
Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 834
File: src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp:455-503
Timestamp: 2026-04-01T16:17:03.789Z
Learning: In hw-native-sys/pypto `src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp`, `MatmulSlicePatternCollector::targets_` (and the old `PreScanSliceMatmulPatterns`) intentionally stores only one `MatmulSliceInfo` per slice result var — if the same slice feeds multiple matmul consumers, the last recorded consumer wins. This is a pre-existing limitation carried over from `PreScanSliceMatmulPatterns`; do not flag the absence of a multi-consumer guard in `CollectMatmulOperands` as a regression introduced by the refactoring. A follow-up improvement should add a guard that falls back to generic Vec-space slice conversion when consumers are incompatible.
Learnt from: lyfne123
Repo: hw-native-sys/pypto PR: 732
File: src/codegen/pto/pto_scalar_expr_codegen.cpp:221-226
Timestamp: 2026-03-26T07:19:54.729Z
Learning: In hw-native-sys/pypto `src/codegen/pto/pto_scalar_expr_codegen.cpp`, the `And`/`Or`/`Xor`/`BitAnd`/`BitOr`/`BitXor` `VisitExpr_` visitors call `VisitBinaryArithExpr` with identical `int_op` and `float_op` MLIR op strings (e.g., `"arith.andi"` for both). This is intentional and safe: `MakeAnd`/`MakeOr` factory functions enforce `DataType::BOOL` on operands, and `MakeBitAnd`/`MakeBitOr`/`MakeBitXor` enforce integer types via `PromoteIntBinaryOperands`. Float operands structurally cannot reach these visitors. Do not flag the identical int/float op strings as a missing float-guard bug in codegen PRs.
Learnt from: zhangqi-chen
Repo: hw-native-sys/pypto PR: 833
File: src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp:1688-1716
Timestamp: 2026-04-01T14:12:07.565Z
Learning: In hw-native-sys/pypto `src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp`, within Phase 3a of `TransformIncoreFunction`, the `sink_candidates` loop over IfStmt return vars is guaranteed to never produce two candidates sharing the same `ifstmt_rv_index`. A function's return statement will never reference the same IfStmt return var in multiple positions with different iter-arg mappings. Do not flag the absence of a duplicate-index guard in this loop as a correctness issue.
Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp:610-625
Timestamp: 2026-04-01T08:56:31.235Z
Learning: In hw-native-sys/pypto `src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp`, the `DetectDirectStore` fusion in `LowerBatchMatmul` does not need an explicit single-use check on the `tile.batch_matmul` result var. `ConvertTensorToTileOps` always emits `tile.batch_matmul → tile.store` as consecutive statements, so the batch_matmul result is invariably single-use in generated IR. In the fused path, `assign->var_` is intentionally left unmapped in `ctx.var_map` (only the store var is remapped via `ctx.Insert(lowering.store_orig_var, lowering.store_result_var)`), which is safe because no later references to the batch_matmul result exist. Do not flag the absence of a single-use guard in `DetectDirectStore` as a correctness issue.
Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/transforms/op_conversion_registry.cpp:314-332
Timestamp: 2026-04-01T09:49:00.207Z
Learning: In hw-native-sys/pypto `src/ir/transforms/op_conversion_registry.cpp`, the `tensor.batch_matmul` conversion lambda intentionally does NOT forward the `out_dtype` kwarg to `tile.batch_matmul`. The `out_dtype` is handled in two phases: (1) `DeduceTensorBatchMatMulType` in `src/ir/op/tensor_ops/matmul.cpp` records the requested `out_dtype` in the output `TensorType`; (2) `FlattenTileNdTo2D` in `src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` (around lines 753–774) detects when the tile.matmul accumulator dtype (FP32/INT32) differs from the expected result dtype and inserts a cast. This two-phase pattern is the same as `tensor.matmul`. Do not flag the missing `out_dtype` forwarding in the conversion as a bug.
Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 386
File: src/ir/op/tile_ops/reduction.cpp:160-170
Timestamp: 2026-03-06T08:17:18.156Z
Learning: In `src/ir/op/tile_ops/reduction.cpp`, the `REGISTER_OP("tile.sum")` registration intentionally includes a second positional argument `tmp_tile` (described as "Temporary tile (TileType)"), which is used as a hardware workspace tile for reduction. However, `DeduceTileReductionType` currently CHECKs for exactly 1 positional argument. The arity mismatch between the registration and the deduce function may need to be reconciled separately — either by updating `DeduceTileReductionType` to accept and handle `args[1]`, or by confirming the `tmp_tile` is not passed through the type-deduction path.
Learnt from: luohuan19
Repo: hw-native-sys/pypto PR: 786
File: src/codegen/tensor_op_codegen.cpp:61-95
Timestamp: 2026-03-30T17:23:12.620Z
Learning: In hw-native-sys/pypto `src/codegen/tensor_op_codegen.cpp`, the `tensor.create` handler (REGISTER_ORCHESTRATION_OP(tensor_create)) intentionally emits two declarations for every internal tensor: (1) `TensorCreateInfo <var>_ci(...)` used with `add_output(<var>_ci)` for runtime device-memory allocation, and (2) `Tensor <var> = make_tensor_external(nullptr, ...)` as a null-addr placeholder that receives a real address only after `<var> = outs.get_ref(i)`. The codegen tracks these variables in `tensor_create_var_names_` (orchestration_codegen.cpp:928) and guarantees they are always routed through `add_output` before any `add_input` use. DN tensors likewise start as null-addr internally. Do not flag `make_tensor_external(nullptr, ...)` in tensor.create as a use-before-initialize bug; the output-first ordering is enforced by the codegen.
Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 308
File: docs/zh-cn/dev/ir/05-operators.md:80-96
Timestamp: 2026-03-01T09:05:27.522Z
Learning: In the PyPTO C++ codebase, kwargs should be stored as std::vector<std::pair<std::string, std::any>> and looked up with range-based for loops (linear search). Do not switch to std::unordered_map for kwargs. Use the GetKwarg helper in src/ir/op/block_ops/memory.cpp as the standard reference for how this pattern is implemented.
Learnt from: lyfne123
Repo: hw-native-sys/pypto PR: 319
File: src/ir/transforms/pass_context.cpp:0-0
Timestamp: 2026-03-02T08:47:16.111Z
Learning: Do not rely on C++ filesystem for directory creation. cpplint disallows the <filesystem> header; ensure directories for report output are created by the Python layer (e.g., in compile.py using os.makedirs) and pass pre-created directory paths to C++ components like ReportInstrument via their constructors. This pattern should apply to all C++ source files that handle output directories.
- Add int src support to scatter_ signature (docs, DSL, IR validation)
- Add reduce kwarg forwarding in scatter_ DSL and IR layers
- Extract duplicate ND→2D flatten logic into FlattenNdIndicesToTwoD helper
- Add barrier argument validation in MakeSystemBarrierCodegenPTO
- Handle INT64 and general integer ConstInt in PTO codegen
- Fix scatter UT: add Out params to Before programs matching pass output
- Fix InOut param test: align assertion with post-rebase pass behavior
Actionable comments posted: 2
♻️ Duplicate comments (1)
python/pypto/ir/op/tensor_ops.py (1)
1002-1007: ⚠️ Potential issue | 🟡 Minor — Infer scalar dtype from input tensor dtype instead of hardcoding INT32/FP32.

Lines 1002–1007 create scalars with fixed dtypes regardless of the input tensor's actual dtype. If `input` is INT64 or FP16, the hardcoded INT32/FP32 scalar will cause a dtype precision mismatch. The C++ type deduction only validates category (IsInt/IsFloat), not exact dtype. Pass the input tensor's dtype to `ConstInt` and `ConstFloat` to ensure compatibility downstream.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/pypto/ir/op/tensor_ops.py` around lines 1002 - 1007, The scalar constructors currently hardcode DataType.INT32/DataType.FP32; instead, when wrapping Python scalars (the int/float branches) pass the target tensor's dtype so the scalar matches the input tensor precision—use the input tensor's dtype (e.g., input.dtype or the relevant tensor Expr's .dtype) as the dtype argument to ConstInt and ConstFloat rather than DataType.INT32/DataType.FP32; keep the Expr type check and the TypeError as-is.
🧹 Nitpick comments (2)
src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp (2)
663-699: tile.read/tile.write ND→2D index flattening is correctly implemented. The implementation:

- Only triggers for >2D tiles (via `IsNdTile` check)
- Validates indices are `MakeTuple` with `INTERNAL_CHECK`
- Flattens indices using the helper function
- Correctly handles both ops (read returns scalar, write returns tile)
However, the code at lines 684-696 has identical logic for both branches (both create a new var and insert into var_map). This could be simplified.
♻️ Simplify duplicated branches
```diff
  auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
- if (op_name == "tile.read") {
-   // tile.read returns scalar — assign to var
-   auto new_var =
-       std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
-   result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
-   ctx.Insert(assign->var_, new_var);
- } else {
-   // tile.write returns tile — assign to var and update mapping
-   auto new_var =
-       std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
-   result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
-   ctx.Insert(assign->var_, new_var);
- }
+ // Both tile.read (scalar) and tile.write (tile) produce a result to assign
+ auto new_var =
+     std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
+ result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
+ ctx.Insert(assign->var_, new_var);
  continue;
```
Verify each finding against the current code and only fix it if needed. In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 663 - 699, The two branches handling op_name == "tile.read" and else (tile.write) duplicate the same logic creating new_var, pushing AssignStmt, and calling ctx.Insert; collapse them into a single shared block after new_call is created: construct new_var with assign->var_->name_hint_/GetType()/span_, push the new AssignStmt to result, and call ctx.Insert(assign->var_, new_var); remove the separate per-op duplicate code paths (keep op_name checks only for any op-specific behavior if later needed). Use symbols new_call, assign, result, ctx.Insert, Var, and AssignStmt to locate and replace the duplicated code.
71-76: Edge case: rank-1 input returns uninitialized `merged_row`.

When `rank == 2`, the outer loop condition `k + 1 < rank` means `k < 1`, so `k` takes only value `0`. The inner loop `j = k + 1 = 1` checks `j + 1 < rank`, i.e. `2 < 2`, which is false, so no multiplication happens. This is correct behavior.

However, if `rank == 1` (single dimension), the loop body never executes and `merged_row` remains `nullptr`. While `IsNdTile` guards ensure `rank > 2` before calling this function, there's no explicit precondition check.

🛡️ Consider adding a defensive CHECK
```diff
  ExprPtr FlattenNdIndicesToTwoD(const MakeTuplePtr& idx_tuple, const std::vector<ExprPtr>& nd_shape,
                                 const std::unordered_map<const Var*, VarPtr>& var_map,
                                 const Span& span) {
    const size_t rank = nd_shape.size();
+   CHECK(rank >= 2) << "FlattenNdIndicesToTwoD requires rank >= 2, got " << rank;
    ExprPtr merged_row;
```
Verify each finding against the current code and only fix it if needed. In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 71 - 76, Add a defensive precondition and explicit handling for tiny ranks so merged_row can't remain null: at the start of the function that contains the loop (referencing merged_row, rank, idx_tuple and the loop using Substitute/MakeMul/MakeAdd) add a CHECK or assert that rank >= 2 (or return/handle rank == 1 explicitly), and if rank == 1 set merged_row to Substitute(idx_tuple->elements_[0], var_map) before the loop so the subsequent logic always has a valid merged_row; keep the existing loop intact for rank >= 2.
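For intuition, the row merge this helper performs is ordinary mixed-radix flattening over the leading axes; a short Python sketch of the assumed semantics (illustrative only, not repo code):

```python
def flatten_nd_to_2d(indices, shape):
    """Fold all leading axes into one row index via Horner's rule;
    the last axis stays as the column index. Requires rank >= 2."""
    assert len(indices) == len(shape) and len(shape) >= 2, "rank must be >= 2"
    row = indices[0]
    for idx, dim in zip(indices[1:-1], shape[1:-1]):
        row = row * dim + idx
    return row, indices[-1]
```

At `rank == 2` the middle loop is empty and the indices pass through unchanged, matching the "no multiplication happens" behavior noted above; at `rank == 1` the assertion fires instead of silently returning an uninitialized row.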
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@python/pypto/language/op/tensor_ops.py`:
- Line 806: Update the docstring for the parameter named reduce in
python/pypto/language/op/tensor_ops.py to list all valid values used by the
backend — "none", "add", and "multiply" — so it matches the validation in the
C++ op conversion registry; locate the docstring block that documents reduce and
replace the current quoted values ("add", "multiply") with the full set ("none",
"add", "multiply") and ensure the description clearly states what each mode
means if present.
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 729-748: The code only casts ir::Var in the ensure_index lambda,
so non-Var index expressions (ConstInt, IterArg, or compound Expr like add) keep
their integer scalar type and produce ill-typed arith.muli/arith.addi ... :
index ops; update ensure_index (used above the loop that computes term and in
the flat-offset math) to inspect the expression's type or kind and call
codegen.EmitCastToIndex for any expression that is not already of index type
(not just ir::Var), ensuring term is always the casted SSA string; use the same
pattern around codegen.GetExprAsCode(indices[i]) so all terms used in the
subsequent arith.muli/arith.addi emissions are index-typed.
---
Duplicate comments:
In `@python/pypto/ir/op/tensor_ops.py`:
- Around line 1002-1007: The scalar constructors currently hardcode
DataType.INT32/DataType.FP32; instead, when wrapping Python scalars (the
int/float branches) pass the target tensor's dtype so the scalar matches the
input tensor precision—use the input tensor's dtype (e.g., input.dtype or the
relevant tensor Expr's .dtype) as the dtype argument to ConstInt and ConstFloat
rather than DataType.INT32/DataType.FP32; keep the Expr type check and the
TypeError as-is.
---
Nitpick comments:
In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp`:
- Around line 663-699: The two branches handling op_name == "tile.read" and else
(tile.write) duplicate the same logic creating new_var, pushing AssignStmt, and
calling ctx.Insert; collapse them into a single shared block after new_call is
created: construct new_var with assign->var_->name_hint_/GetType()/span_, push
the new AssignStmt to result, and call ctx.Insert(assign->var_, new_var); remove
the separate per-op duplicate code paths (keep op_name checks only for any
op-specific behavior if later needed). Use symbols new_call, assign, result,
ctx.Insert, Var, and AssignStmt to locate and replace the duplicated code.
- Around line 71-76: Add a defensive precondition and explicit handling for tiny
ranks so merged_row can't remain null: at the start of the function that
contains the loop (referencing merged_row, rank, idx_tuple and the loop using
Substitute/MakeMul/MakeAdd) add a CHECK or assert that rank >= 2 (or
return/handle rank == 1 explicitly), and if rank == 1 set merged_row to
Substitute(idx_tuple->elements_[0], var_map) before the loop so the subsequent
logic always has a valid merged_row; keep the existing loop intact for rank >=
2.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4b56488a-1733-4d02-b7e4-6e85cdc4ea04
📒 Files selected for processing (14)
- docs/en/user/02-operation_reference.md
- docs/zh-cn/user/02-operation_reference.md
- python/pypto/ir/op/tensor_ops.py
- python/pypto/language/__init__.py
- python/pypto/language/op/__init__.py
- python/pypto/language/op/tensor_ops.py
- src/backend/common/pto_ops_common.cpp
- src/codegen/pto/pto_codegen.cpp
- src/ir/op/tensor_ops/scatter.cpp
- src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
- src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
- src/ir/transforms/op_conversion_registry.cpp
- tests/st/runtime/test_scatter.py
- tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
🚧 Files skipped from review as they are similar to previous changes (6)
- docs/en/user/02-operation_reference.md
- python/pypto/language/op/__init__.py
- src/codegen/pto/pto_codegen.cpp
- src/ir/op/tensor_ops/scatter.cpp
- docs/zh-cn/user/02-operation_reference.md
- src/ir/transforms/op_conversion_registry.cpp
```
        dim: Dimension along which to scatter
        index: Index tensor (N-D, same rank as input) of integer dtype
        src: Source tensor (same shape as index) or scalar value
        reduce: Optional reduce mode ("add" or "multiply")
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if reduce validation exists anywhere in the scatter implementation chain
rg -n "reduce" --type cpp src/ir/op/tensor_ops/scatter.cpp | head -20
rg -n "reduce.*add|reduce.*multiply" --type cpp src/ir/transforms/op_conversion_registry.cpp | head -10

Repository: hw-native-sys/pypto
Length of output: 103
🏁 Script executed:
# Check the context around line 806 in tensor_ops.py
cd python/pypto/language/op && sed -n '795,820p' tensor_ops.py

Repository: hw-native-sys/pypto
Length of output: 957
🏁 Script executed:
# Look at the full scatter function implementation
cd python/pypto/language/op && grep -n "def scatter" tensor_ops.py

Repository: hw-native-sys/pypto
Length of output: 104
🏁 Script executed:
# Check the C++ scatter implementation for reduce handling
cat src/ir/op/tensor_ops/scatter.cpp | head -150

Repository: hw-native-sys/pypto
Length of output: 5645
🏁 Script executed:
# Search for where the reduce parameter is used in C++ code
rg -B5 -A10 "reduce" --type cpp src/ir/op/tensor_ops/scatter.cpp

Repository: hw-native-sys/pypto
Length of output: 661
🏁 Script executed:
# Search for reduce parameter usage in tests or examples
rg -n "reduce" --type py python/pypto | grep -E "(test|example)" | head -20

Repository: hw-native-sys/pypto
Length of output: 93
🏁 Script executed:
# Search for scatter tests
fd -e py "test.*scatter" python/

Repository: hw-native-sys/pypto
Length of output: 45
🏁 Script executed:
# Look for where scatter_ is actually called with reduce parameter
rg -B2 -A2 'scatter_.*reduce' --type py python/

Repository: hw-native-sys/pypto
Length of output: 93
🏁 Script executed:
# Search for scatter tests and reduce usage
fd --type f "scatter" python/

Repository: hw-native-sys/pypto
Length of output: 233
🏁 Script executed:
# Search for reduce parameter usage in Python files
rg "reduce\s*=" --type py python/ | head -20

Repository: hw-native-sys/pypto
Length of output: 396
🏁 Script executed:
# Look at the actual backend to understand what reduce modes are supported
rg -n "reduce" src/ | grep -E "(add|multiply)" | head -20

Repository: hw-native-sys/pypto
Length of output: 541
🏁 Script executed:
# Check if there are any comments or documentation about valid reduce modes
rg -B3 -A3 "reduce.*add.*multiply" --type py --type cpp --type txt

Repository: hw-native-sys/pypto
Length of output: 1343
Update docstring to include all valid reduce values.
The Python docstring lists only "add" and "multiply", but the C++ backend (src/ir/transforms/op_conversion_registry.cpp:546) validates against three modes: "none", "add", and "multiply". Update the docstring to reflect the complete set of valid values.
```diff
-    reduce: Optional reduce mode ("add" or "multiply")
+    reduce: Optional reduce mode ("none", "add", or "multiply")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```
        reduce: Optional reduce mode ("none", "add", or "multiply")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@python/pypto/language/op/tensor_ops.py` at line 806, Update the docstring for
the parameter named reduce in python/pypto/language/op/tensor_ops.py to list all
valid values used by the backend — "none", "add", and "multiply" — so it matches
the validation in the C++ op conversion registry; locate the docstring block
that documents reduce and replace the current quoted values ("add", "multiply")
with the full set ("none", "add", "multiply") and ensure the description clearly
states what each mode means if present.
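As a behavioral reference for the three modes, a 1-D pure-Python sketch (illustrative only, not the repo implementation):

```python
def scatter_1d(out, index, src, reduce="none"):
    """out[index[i]] op= src[i] for each i, where op depends on reduce."""
    for i, j in enumerate(index):
        if reduce == "add":
            out[j] += src[i]
        elif reduce == "multiply":
            out[j] *= src[i]
        else:  # "none": plain overwrite, the default
            out[j] = src[i]
    return out
```

For example, `scatter_1d([0, 0, 0, 0], [1, 3], [5, 7])` overwrites positions 1 and 3, while `reduce="add"` accumulates into them instead — which is why the docstring should name all three modes.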
```cpp
// Helper: ensure an index element SSA value has `index` type.
// If the expression is a non-index integer (e.g. i32 from tile.read on an
// INT32 tile), emit arith.index_cast to convert it.
auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
  if (auto var = ir::As<ir::Var>(expr)) {
    return codegen.EmitCastToIndex(var, ssa);
  }
  return ssa;
};

// For each dimension i, compute: index[i] * (shape[i+1] * shape[i+2] * ... * shape[rank-1])
// then sum all terms with arith.addi.
std::string accumulator;
for (size_t i = 0; i < indices.size(); ++i) {
  if (i > 0) idx_oss << " + ";
  idx_oss << codegen.GetExprAsCode(indices[i]);
  std::string term = ensure_index(indices[i], codegen.GetExprAsCode(indices[i]));
  // Multiply by each trailing dimension size
  for (size_t j = i + 1; j < shape.size(); ++j) {
    idx_oss << " * " << codegen.GetExprAsCode(shape[j]);
    std::string dim = codegen.GetExprAsCode(shape[j]);
    std::string tmp = codegen.NewTemp();
    codegen.Emit(tmp + " = arith.muli " + term + ", " + dim + " : index");
```
Cast all non-index scalar index expressions before the flat-offset math.
Line 733 only normalizes plain `ir::Var`. If a tuple element is an `IterArg`, an integer `ConstInt`, or a computed scalar like `idx32 + 1`, `term` stays i32/i64 and the emitted `arith.muli` / `arith.addi ... : index` becomes ill-typed.
🔧 Suggested fix
```diff
- auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
-   if (auto var = ir::As<ir::Var>(expr)) {
-     return codegen.EmitCastToIndex(var, ssa);
-   }
-   return ssa;
- };
+ auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
+   auto scalar_type = As<ScalarType>(expr->GetType());
+   INTERNAL_CHECK(scalar_type) << "flat index expression must be scalar";
+   if (scalar_type->dtype_ == DataType::INDEX) {
+     return ssa;
+   }
+   CHECK(!scalar_type->dtype_.IsFloat()) << "flat index expression must be integer/index typed";
+   std::string idx_ssa = codegen.NewTemp();
+   codegen.Emit(idx_ssa + " = arith.index_cast " + ssa + " : " +
+                codegen.GetTypeString(scalar_type->dtype_) + " to index");
+   return idx_ssa;
+ };
```
Verify each finding against the current code and only fix it if needed.
In `@src/backend/common/pto_ops_common.cpp` around lines 729 - 748, The code only
casts ir::Var in the ensure_index lambda, so non-Var index expressions
(ConstInt, IterArg, or compound Expr like add) keep their integer scalar type
and produce ill-typed arith.muli/arith.addi ... : index ops; update ensure_index
(used above the loop that computes term and in the flat-offset math) to inspect
the expression's type or kind and call codegen.EmitCastToIndex for any
expression that is not already of index type (not just ir::Var), ensuring term
is always the casted SSA string; use the same pattern around
codegen.GetExprAsCode(indices[i]) so all terms used in the subsequent
arith.muli/arith.addi emissions are index-typed.
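The flat-offset formula being emitted here is plain row-major linearization; in Python (a sketch of the math only, not the codegen):

```python
def flat_offset(indices, shape):
    """Row-major offset: sum_i indices[i] * prod(shape[i+1:])."""
    off = 0
    for i, idx in enumerate(indices):
        stride = 1
        for dim in shape[i + 1:]:
            stride *= dim  # product of trailing dimension sizes
        off += idx * stride
    return off
```

Every `idx` and `dim` term must carry the same (index) type for the emitted `arith.muli`/`arith.addi` chain to type-check, which is what the cast-before-math fix above enforces.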
- Add int src support to scatter_ signature (docs, DSL, IR validation)
- Add reduce kwarg forwarding in scatter_ DSL and IR layers
- Add scalar dtype pair validation (float src → float input, int src → int input)
- Extract duplicate ND→2D flatten logic into FlattenNdIndicesToTwoD helper
- Add barrier argument validation in MakeSystemBarrierCodegenPTO
- Extract _scatter_handler from _register_ops to fix ruff PLR0915
- Fix scatter UT: add Out params to Before programs matching pass output
- Fix InOut param test: align assertion with post-rebase pass behavior
- Apply clang-format and ruff-format fixes
Add a minimal 2x8 scatter test case with unique src values for easy tracing. Add diagnostic prints to all scatter test methods showing dim, index, src, and expected values.
```cpp
static std::string MakeSystemBarrierCodegenPTO(const std::string& pipe_name, const CallPtr& op,
                                               codegen::CodegenBase& codegen_base) {
  CHECK(op->args_.empty()) << "system.barrier_" << pipe_name << " expects 0 arguments, got "
```
Why insert the barrier here? That work should be done by ptoas.
Summary
Implement `tensor.scatter_` following PyTorch `torch.Tensor.scatter_` semantics (issue #677).

- `scatter.cpp`: Register `tensor.scatter_` with type deduction for input/index/src validation
- `op_conversion_registry.cpp`: Decompose into nested `scf.for` loops with `tile.read`/`tile.write` (scalar `tgetval`/`tsetval`)
- PTO codegen: `system.bar_v`/`bar_m`/`bar_all` barrier codegen, fix `ComputeFlatOffsetPTO` for SSA index_cast, fix `ConstInt` dtype dispatch (INT32 vs INDEX)
- `flatten_tile_nd_to_2d_pass.cpp`: Support `tile.read`/`tile.write` index flattening for >2D tiles
- `convert_tensor_to_tile_ops_pass.cpp`: Phase 3 detects unreferenced `Out` params and reuses them for auto-inserted `tile.store`, enabling scatter's implicit output pattern
- DSL: `pl.scatter_(input, dim, index, src)` with full parameter validation

Design decisions

- `tgetval`/`tsetval` loops instead of vectorized tile ops, due to RowMajor 32-byte alignment constraints and TINSERT limitations
- `bar_all` before/after nested loops + `bar_v` after each `tsetval` — required because PTOAS `PTOInsertSync` can't auto-sync these patterns
- `scatter_` returns the input tile directly (no copy), with alias tracking in the tensor→tile map
- `pl.Out` param without explicit `pl.store()`; the pass auto-reuses it for `tile.store` insertion

Test plan

- `pytest tests/ut/ir/operators/test_tensor_ops.py -k scatter`
- `pytest tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py -k scatter`
- `pytest tests/st/runtime/test_scatter.py -v --platform=a2a3`
- `cmake --build build --parallel` passes cleanly
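The runtime tests can be cross-checked against a dependency-free Python reference of the `dim=0` case (`scatter_dim0` is an illustrative helper, not part of the test suite):

```python
def scatter_dim0(out, index, src):
    """PyTorch scatter_ semantics for dim=0: out[index[i][j]][j] = src[i][j]."""
    for i, row in enumerate(index):
        for j, tgt in enumerate(row):
            out[tgt][j] = src[i][j]
    return out

# Example with unique src values for easy tracing, in the spirit of the
# minimal scatter test case added in this PR.
out = [[0] * 4 for _ in range(3)]
scatter_dim0(out, [[0, 1, 2, 0]], [[10, 20, 30, 40]])
```

After the call, column `j` of row `index[0][j]` holds `src[0][j]`; everything else stays zero.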