
feat(op): implement tensor.scatter_ element-level scatter#898

Open
Little-oil wants to merge 3 commits into hw-native-sys:main from Little-oil:scatter

Conversation

@Little-oil
Contributor

Summary

Implement tensor.scatter_ following PyTorch torch.Tensor.scatter_ semantics (issue #677).

  • IR op definition (scatter.cpp): Register tensor.scatter_ with type deduction for input/index/src validation
  • Tensor→Tile conversion (op_conversion_registry.cpp): Decompose into nested scf.for loops with tile.read/tile.write (scalar tgetval/tsetval)
  • PTO codegen: Add system.bar_v/bar_m/bar_all barrier codegen, fix ComputeFlatOffsetPTO for SSA index_cast, fix ConstInt dtype dispatch (INT32 vs INDEX)
  • ND→2D flattening (flatten_tile_nd_to_2d_pass.cpp): Support tile.read/tile.write index flattening for >2D tiles
  • Unused Out param reuse (convert_tensor_to_tile_ops_pass.cpp): Phase 3 detects unreferenced Out params and reuses them for auto-inserted tile.store, enabling scatter's implicit output pattern
  • Python DSL/IR layers: pl.scatter_(input, dim, index, src) with full parameter validation
  • Tests: 7 UT (operator + transform) + 4 ST hardware tests (2D dim=0/1, scalar src, 3D dim=2)
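For readers unfamiliar with the semantics, the PyTorch-style behavior implemented here can be modeled in a few lines of plain Python (an illustrative sketch; scatter_2d is a hypothetical name, not an API in this repo):

```python
# Hypothetical pure-Python model of the 2D case of torch.Tensor.scatter_.
# The real op lowers to nested scalar loops (tgetval/tsetval) over tiles.
def scatter_2d(dst, dim, index, src):
    for i in range(len(index)):
        for j in range(len(index[0])):
            if dim == 0:
                dst[index[i][j]][j] = src[i][j]  # index picks the destination row
            else:
                dst[i][index[i][j]] = src[i][j]  # index picks the destination column
    return dst

dst = [[0.0] * 3 for _ in range(3)]
scatter_2d(dst, 0, [[1, 2, 0]], [[10.0, 20.0, 30.0]])
# → dst[1][0] == 10.0, dst[2][1] == 20.0, dst[0][2] == 30.0
```

The 2D dim=0/1 ST cases in the test plan exercise exactly these two branches against PyTorch-computed expectations.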

Design decisions

  1. ForStmt approach: Uses scalar tgetval/tsetval loops instead of vectorized tile ops, due to RowMajor 32-byte alignment constraints and TINSERT limitations
  2. Hardware sync: bar_all before/after nested loops + bar_v after each tsetval — required because PTOAS PTOInsertSync can't auto-sync these patterns
  3. In-place alias: scatter_ returns the input tile directly (no copy), with alias tracking in the tensor→tile map
  4. Implicit output: InCore declares pl.Out param without explicit pl.store(); the pass auto-reuses it for tile.store insertion

Test plan

  • UT: pytest tests/ut/ir/operators/test_tensor_ops.py -k scatter
  • UT: pytest tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py -k scatter
  • ST: pytest tests/st/runtime/test_scatter.py -v --platform=a2a3
  • Build: cmake --build build --parallel passes cleanly

…w-native-sys#677)

Add scatter_ op following PyTorch torch.Tensor.scatter_ semantics.
Decomposes into nested ForStmt + tile.read/tile.write loops in the
ConvertTensorToTileOps pass. Includes ND→2D index flattening, PTO
barrier codegen, ConstInt dtype dispatch, and unused Out param reuse
for implicit tile.store insertion.

coderabbitai bot commented Apr 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR adds a new tensor operation tensor.scatter_ end-to-end: docs, Python API, IR registration and type deduction, tensor→tile lowering, PTO codegen/backend updates, and unit + runtime tests.

Changes

  • Documentation (docs/en/user/02-operation_reference.md, docs/zh-cn/user/02-operation_reference.md): Added scatter_ operation docs and signature (input: Tensor, dim: int, index: Tensor, src: Tensor | float | int) -> Tensor.
  • Python Language API Layer (python/pypto/language/__init__.py, python/pypto/language/op/__init__.py, python/pypto/language/op/tensor_ops.py): Exported new scatter_ wrapper at the language level; accepts tensor or scalar src and optional reduce.
  • Python IR Layer / Debug (python/pypto/ir/op/tensor_ops.py, python/pypto/debug/torch_codegen.py): Added IR builder scatter_ with positional/keyword parsing and scalar normalization; added debug codegen handler for tensor.scatter_.
  • IR Op Definition (src/ir/op/tensor_ops/scatter.cpp): Registered the tensor.scatter_ op and implemented DeduceTensorScatterType, enforcing rank/dtype/index/dim constraints and reduce/dim attributes.
  • Op Conversion Registry (src/ir/transforms/op_conversion_registry.cpp): Added custom lowering for tensor.scatter_ to nested ForStmt loops using tile.read/tile.write, optional reduce logic, and inserted system.bar_* synchronization.
  • Tensor→Tile Transform (src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp): Added direct var-alias remapping to avoid redundant assigns; improved Phase-3 output parameter reuse (prefer an existing InOut/unused Out before creating a new Out).
  • ND→2D Flatten Transform (src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp): Added an ND→2D index flattening helper, rewrites for tile.read/tile.write (>2D), and offset padding for ND tile.store.
  • PTO Codegen, const-int handling (src/codegen/pto/pto_codegen.cpp, src/codegen/pto/pto_scalar_expr_codegen.cpp): ConstInt emission now honors integer dtype (INT32→i32, INT64→i64, other ints→dtype-specific arith.constant) and updates type annotations accordingly.
  • PTO Backend Ops (src/backend/common/pto_ops_common.cpp): Refactored flat-offset emission to MLIR SSA arithmetic, added an index-cast helper, and registered system.bar_v/m/all barrier ops.
  • Build (CMakeLists.txt): Added src/ir/op/tensor_ops/scatter.cpp to project sources.
  • Unit Tests (tests/ut/ir/operators/test_tensor_ops.py, tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py): Added IR-level tests for tensor.scatter_ (2D/3D, negative dim, error paths) and conversion tests asserting tile.read/tile.write lowering and Out/InOut reuse.
  • Runtime Tests (tests/st/runtime/test_scatter.py): Added end-to-end runtime tests covering 2D/3D and scalar/tensor src cases with PyTorch-based expected outputs.
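The ND→2D index flattening mentioned for flatten_tile_nd_to_2d_pass.cpp can be sketched in plain Python (a sketch assuming row-major layout, per the merged_row formula in the pass; flatten_nd_to_2d is an illustrative name):

```python
# Fold all leading dims of an ND index into a single row index (Horner's
# scheme); the last dim stays as the column. Assumes row-major layout.
def flatten_nd_to_2d(indices, shape):
    merged_row = 0
    for k in range(len(shape) - 1):
        merged_row = merged_row * shape[k] + indices[k]
    return merged_row, indices[-1]

# (1, 2, 3) in a (2, 3, 4) tile → row 1*3 + 2 = 5, col 3;
# 5 * 4 + 3 == 23, the same row-major flat offset as 1*12 + 2*4 + 3.
```

The (merged_row, col) pair is what the rewritten tile.read/tile.write calls receive after flattening.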

Sequence Diagram

sequenceDiagram
    participant API as Language API
    participant IR as IR Builder
    participant Type as Type Deduction
    participant Conv as Tensor→Tile Conversion
    participant CG as PTO Codegen
    participant BE as PTO Backend

    API->>IR: scatter_(input, dim, index, src, reduce?)
    IR->>IR: parse args, normalize dim, wrap scalar src
    IR->>Type: create `tensor.scatter_` Call
    Type->>Type: validate ranks, index dtype, dim range -> TensorType
    Conv->>Conv: detect `tensor.scatter_`, load index/src as tiles if needed
    Conv->>Conv: build nested ForStmt loops, read index/src, compute write indices
    Conv->>Conv: tile.write (with optional reduce using tile.read)
    Conv->>BE: insert system.bar_v per write and system.bar_all around loop
    CG->>CG: emit ConstInt constants per integer dtype (i32/i64/typed arith.constant)
    CG->>BE: emit MLIR SSA ops

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~80 minutes

Possibly related issues

  • [New Op] scatter_ #677: Adds the same scatter_ tensor operation (signature and semantics); likely the same feature request.

Possibly related PRs

Suggested reviewers

  • Hzfengsy
  • lyfne123

🐰 A scatter of joy through the pipeline hops,
Indexes set, the kernel never stops,
Loops and barriers hum in sync so bright,
Values land exactly where they’re right. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 33.77%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Description check ✅ Passed: The PR description is comprehensive and directly related to the changeset, covering IR implementation, tensor-to-tile conversion, codegen updates, ND-to-2D flattening, parameter reuse, Python DSL additions, and test coverage.
  • Title check ✅ Passed: The PR title "feat(op): implement tensor.scatter_ element-level scatter" clearly and specifically describes the main change, is concise, and uses the standard conventional commit format.



@gemini-code-assist bot left a comment


Code Review

This pull request introduces the scatter_ operation, which implements element-level scatter semantics similar to PyTorch. The implementation includes the core IR operation, Python DSL bindings, and a conversion pass that lowers the tensor-level operation into nested loops with tile.read and tile.write. Additionally, the PR improves N-D tile handling in the flattening pass and enhances PTO codegen for integer constants and system barriers. Feedback focuses on ensuring the reduce argument is correctly propagated through the Python IR layer, registering it as an official attribute in the C++ op definition, and refactoring duplicated index-flattening logic into a helper method.

Comment on lines +945 to +1008
def scatter_(
    input: Expr,
    *args: Expr | int | float,
    dim: int | Expr | None = None,
    index: Expr | None = None,
    src: Expr | float | int | None = None,
    span: Span | None = None,
) -> Call:
    """Element-level scatter into tensor along a dimension.

    For each position (i₀,…,iₙ) in index, sets:
      input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ]

    Follows PyTorch ``torch.Tensor.scatter_`` semantics.

    Accepts call forms:
    - scatter_(input, dim, index, src)
    - scatter_(input, dim, index, src=1.0)

    Args:
        input: Destination tensor (N-D).
        dim: Dimension along which to scatter.
        index: Index tensor (same rank as input, integer dtype).
        src: Source tensor (same shape as index) or scalar value.
        span: Optional source span for debugging (auto-captured if not provided).

    Returns:
        Call expression returning the updated input tensor.
    """
    if len(args) == 3 and dim is None and index is None and src is None:
        dim, index, src = args
    elif len(args) == 2 and dim is not None and index is None and src is None:
        index, src = args
    elif len(args) == 1 and dim is None and index is not None and src is not None:
        dim = args[0]
    elif len(args) != 0:
        raise TypeError(
            "scatter_ expects (input, dim, index, src), "
            "(input, index, src, dim=...), or (input, dim, index=..., src=...)"
        )

    if dim is None or index is None or src is None:
        raise TypeError("scatter_ requires input, dim, index, and src")

    actual_span = _get_span_or_capture(span)
    if isinstance(dim, ConstInt):
        dim_val = int(dim.value)
    elif isinstance(dim, int):
        dim_val = dim
    else:
        raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")

    if not isinstance(index, Expr):
        raise TypeError(f"index must be Expr, got {type(index)}")

    # src can be Expr or scalar (int/float → ConstFloat)
    if isinstance(src, (int, float)):
        src = ConstFloat(float(src), DataType.FP32, actual_span)
    elif not isinstance(src, Expr):
        raise TypeError(f"src must be Expr or scalar, got {type(src)}")

    op_args: list[Expr] = [input, index, src]
    kwargs: dict[str, Any] = {"dim": dim_val}
    return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span)


Severity: high

The language layer pl.scatter_ calls this function with a reduce keyword argument, but this function's signature doesn't accept it. This will lead to a TypeError at runtime. The reduce argument should be added to the signature and passed to create_op_call. Additionally, ensure that user-provided arguments are validated at this level to provide clear error messages.

def scatter_(
    input: Expr,
    *args: Expr | int | float,
    dim: int | Expr | None = None,
    index: Expr | None = None,
    src: Expr | float | int | None = None,
    reduce: str | None = None,
    span: Span | None = None,
) -> Call:
    """Element-level scatter into tensor along a dimension.

    For each position (i₀,…,iₙ) in index, sets:
      input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ]

    Follows PyTorch ``torch.Tensor.scatter_`` semantics.

    Accepts call forms:
    - scatter_(input, dim, index, src)
    - scatter_(input, dim, index, src=1.0)

    Args:
        input: Destination tensor (N-D).
        dim: Dimension along which to scatter.
        index: Index tensor (same rank as input, integer dtype).
        src: Source tensor (same shape as index) or scalar value.
        reduce: Optional reduction mode ("add" or "multiply").
        span: Optional source span for debugging (auto-captured if not provided).

    Returns:
        Call expression returning the updated input tensor.
    """
    if len(args) == 3 and dim is None and index is None and src is None:
        dim, index, src = args
    elif len(args) == 2 and dim is not None and index is None and src is None:
        index, src = args
    elif len(args) == 1 and dim is None and index is not None and src is not None:
        dim = args[0]
    elif len(args) != 0:
        raise TypeError(
            "scatter_ expects (input, dim, index, src), "
            "(input, index, src, dim=...), or (input, dim, index=..., src=...)"
        )

    if dim is None or index is None or src is None:
        raise TypeError("scatter_ requires input, dim, index, and src")

    actual_span = _get_span_or_capture(span)
    if isinstance(dim, ConstInt):
        dim_val = int(dim.value)
    elif isinstance(dim, int):
        dim_val = dim
    else:
        raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")

    if not isinstance(index, Expr):
        raise TypeError(f"index must be Expr, got {type(index)}")

    # src can be Expr or scalar (int/float → ConstFloat)
    if isinstance(src, (int, float)):
        src = ConstFloat(float(src), DataType.FP32, actual_span)
    elif not isinstance(src, Expr):
        raise TypeError(f"src must be Expr or scalar, got {type(src)}")

    op_args: list[Expr] = [input, index, src]
    kwargs: dict[str, Any] = {"dim": dim_val}
    if reduce is not None:
        kwargs["reduce"] = reduce
    return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span)
References
  1. Validate user-provided arguments for DSL functions at the parser level to provide early and clear error messages, rather than relying solely on backend C++ validation.
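For intuition on what the reduce kwarg changes, here is a pure-Python 1D model of the PyTorch-style write (illustrative only; scatter_reduce_1d is not part of either codebase):

```python
# Hypothetical 1D model: reduce selects how the scattered value combines
# with what is already stored at the destination slot.
def scatter_reduce_1d(dst, index, src, reduce=None):
    for i, pos in enumerate(index):
        if reduce == "add":
            dst[pos] += src[i]
        elif reduce == "multiply":
            dst[pos] *= src[i]
        else:  # plain scatter_: overwrite
            dst[pos] = src[i]
    return dst

scatter_reduce_1d([1.0, 1.0, 1.0], [0, 0, 2], [2.0, 3.0, 4.0], reduce="add")
# → [6.0, 1.0, 5.0]  (both writes to slot 0 accumulate onto the initial 1.0)
```

This is why the conversion pass needs a tile.read before the tile.write when reduce is set: the existing destination value participates in the result.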

Contributor Author


Fixed: pl.scatter_() now passes reduce kwarg through both the IR and language layers.

Comment on lines +91 to +106
REGISTER_OP("tensor.scatter_")
    .set_op_category("TensorOp")
    .set_description(
        "Element-level scatter: write src values into input at positions given by index along dim. "
        "For each element position (i₀,…,iₙ) in index, sets "
        "input[i₀]…[i_{d-1}][index[i₀…iₙ]][i_{d+1}]…[iₙ] = src[i₀…iₙ]. "
        "Supports arbitrary rank and any valid dim ∈ [-rank, rank). "
        "src can be a tensor (same shape as index) or a scalar value.")
    .add_argument("input", "Destination tensor (N-D)")
    .add_argument("index", "Index tensor (N-D, same rank as input) of integer dtype")
    .add_argument("src", "Source tensor (same shape as index) or scalar value")
    .set_attr<int>("dim")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceTensorScatterType(args, kwargs);
    });


Severity: medium

The reduce functionality is a core part of tensor.scatter_, as implemented in the conversion pass and exposed in the Python layers. For consistency with the dim attribute and for better IR clarity, reduce should be registered as an optional string attribute of the op. This makes the op's definition self-contained and improves discoverability.

REGISTER_OP("tensor.scatter_")
    .set_op_category("TensorOp")
    .set_description(
        "Element-level scatter: write src values into input at positions given by index along dim. "
        "For each element position (i₀,…,iₙ) in index, sets "
        "input[i₀]…[i_{d-1}][index[i₀…iₙ]][i_{d+1}]…[iₙ] = src[i₀…iₙ]. "
        "Supports arbitrary rank and any valid dim ∈ [-rank, rank). "
        "src can be a tensor (same shape as index) or a scalar value. "
        "Also supports 'add' and 'multiply' reduction modes.")
    .add_argument("input", "Destination tensor (N-D)")
    .add_argument("index", "Index tensor (N-D, same rank as input) of integer dtype")
    .add_argument("src", "Source tensor (same shape as index) or scalar value")
    .set_attr<int>("dim")
    .set_attr<std::string>("reduce", /*is_optional=*/true, /*default_value=*/"none")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceTensorScatterType(args, kwargs);
    });

Contributor Author


Fixed: scatter.cpp now stores the reduce attribute via set_attr.

Comment on lines +651 to +701
// ---- tile.read/tile.write on >2D tiles: flatten ND indices to 2D ----
// tile.read(tile, (i0, i1, ..., in)) → tile.read(tile_2d, (merged_row, col))
// where merged_row = i0 * d1 * d2 * ... * d_{n-2} + i1 * d2 * ... + i_{n-2}
// and col = i_{n-1}
// tile.write(tile, (i0, ..., in), val) → tile.write(tile_2d, (merged_row, col), val)
if (op_name == "tile.read" || op_name == "tile.write") {
  auto orig_tile_type = As<TileType>(call->args_[0]->GetType());
  if (orig_tile_type && IsNdTile(orig_tile_type)) {
    std::vector<ExprPtr> new_args;
    new_args.reserve(call->args_.size());
    // args[0]: tile (substitute)
    new_args.push_back(Substitute(call->args_[0], ctx.var_map));
    // args[1]: indices tuple — flatten from ND to 2D
    auto idx_tuple = As<MakeTuple>(call->args_[1]);
    INTERNAL_CHECK(idx_tuple) << "tile.read/tile.write indices must be MakeTuple";
    const auto& nd_shape = orig_tile_type->shape_;
    const size_t rank = nd_shape.size();
    // Compute merged_row = sum of i_k * product(nd_shape[k+1..rank-2]) for k in [0..rank-2)
    ExprPtr merged_row;
    for (size_t k = 0; k + 1 < rank; ++k) {
      ExprPtr term = Substitute(idx_tuple->elements_[k], ctx.var_map);
      // Multiply by trailing dimensions (excluding last)
      for (size_t j = k + 1; j + 1 < rank; ++j) {
        term = MakeMul(term, nd_shape[j], span);
      }
      merged_row = merged_row ? MakeAdd(merged_row, term, span) : term;
    }
    ExprPtr col = Substitute(idx_tuple->elements_[rank - 1], ctx.var_map);
    new_args.push_back(std::make_shared<MakeTuple>(
        std::vector<ExprPtr>{merged_row, col}, span));
    // Remaining args (e.g., value for tile.write)
    for (size_t i = 2; i < call->args_.size(); ++i) {
      new_args.push_back(Substitute(call->args_[i], ctx.var_map));
    }
    auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
    if (op_name == "tile.read") {
      // tile.read returns scalar — assign to var
      auto new_var = std::make_shared<Var>(
          assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
      result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
      ctx.Insert(assign->var_, new_var);
    } else {
      // tile.write returns tile — assign to var and update mapping
      auto new_var = std::make_shared<Var>(
          assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
      result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
      ctx.Insert(assign->var_, new_var);
    }
    continue;
  }
}


Severity: medium

The logic for flattening N-D indices to 2D for tile.read and tile.write is duplicated here and in the EvalStmt handler for tile.write around line 411. This duplicated logic for calculating merged_row should be extracted into a private helper method to improve maintainability and avoid future inconsistencies.

References
  1. Extract duplicate logic into a private helper method to improve maintainability and avoid future inconsistencies.

Contributor Author


Fixed: extracted FlattenNdIndicesToTwoD helper to eliminate duplicated code.


@coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Nitpick comments (1)
tests/st/runtime/test_scatter.py (1)

268-319: Please add one 3D non-last-dim runtime case.

This is the only ND hardware test, and dim=2 keeps the scattered index in the final column. The new ND→2D flattening logic is riskier when the scattered dimension participates in merged_row (dim=0 or dim=1 on 3D), so one case there would exercise the new math end to end.
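To see why, a tiny plain-Python illustration (to_2d is a hypothetical helper, not code from the pass) of the row-major fold applied to a 3D index (i, j, k) with middle dim D1:

```python
D1 = 3  # middle dim of a hypothetical (D0, 3, D2) tile

def to_2d(i, j, k):
    # leading dims fold into the merged row; the last dim becomes the column
    return (i * D1 + j, k)

# A dim=2 scatter only moves the column; merged_row is untouched:
assert to_2d(1, 2, 0)[0] == to_2d(1, 2, 3)[0]
# A dim=1 scatter moves merged_row, which is the new math this comment asks to cover:
assert to_2d(1, 0, 3)[0] != to_2d(1, 2, 3)[0]
```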

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/runtime/test_scatter.py` around lines 268 - 319, Add a sibling 3D
test where the scatter dimension is not the last axis (e.g., dim=1) to exercise
ND→2D flattening when the scattered dim participates in merged_row: copy
Scatter3dDim2Program and Scatter3dDim2TestCase to create Scatter3dDim1Program
and Scatter3dDim1TestCase, update the kernel/orchestrator to call
pl.scatter_(..., dim=1, index=index, src=src), update compute_expected to use
expected.scatter_(1, tensors["index"].long(), tensors["src"].float()), and
ensure the test uses a proper index init constant (e.g., _IDX_3D_DIM1) with
appropriate values selecting positions in the middle dimension.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/en/user/02-operation_reference.md`:
- Line 46: The docs for scatter_ list src as `Tensor | float` but the Python API
accepts integer scalars too; update the operation signature and descriptive text
for `scatter_` so the scalar case includes integers (e.g., `Tensor | int |
float`) and clarify “src can be a tensor (same shape as `index`) or a scalar
(int or float)”; make this change on the `scatter_` entry so the documented API
surface matches the actual `scatter_` Python entry points.

In `@docs/zh-cn/user/02-operation_reference.md`:
- Line 43: The Chinese doc for scatter_ currently lists src as `Tensor | float`
but the implementation accepts integer scalars too; update the signature in the
`scatter_` entry to `Tensor | int | float` (matching the English reference and
the implementation) so the zh-CN scalar type is in sync with the actual API, and
ensure the descriptive sentence mentions that `src` may be a tensor or a scalar
(int or float).

In `@python/pypto/ir/op/tensor_ops.py`:
- Around line 1000-1004: The current coercion in tensor_ops.py turns integer
scalar inputs into ConstFloat(DataType.FP32) which is wrong; change the handling
for src so that ints become ConstInt (e.g. ConstInt(int(src), appropriate
integer dtype, actual_span)) instead of ConstFloat, leave floats as ConstFloat,
and keep the isinstance(src, Expr) check; after wrapping, validate and/or cast
the resulting ConstInt/ConstFloat against the destination/input tensor dtype
(use the input's dtype validation codepath already used elsewhere) so integer
literals are preserved and properly converted to the tensor's dtype rather than
always becoming FP32; update references to ConstFloat, ConstInt, DataType.FP32
and the src variable to implement this logic.
- Around line 945-952: The wrapper forwards a reduce=... kwarg into the IR
builder but the IR builder function scatter_ (in
python/pypto/ir/op/tensor_ops.py) does not accept it, causing a TypeError; fix
by adding a reduce parameter to scatter_ (e.g., reduce: Expr | str | None =
None) and propagate/validate it into the Call construction (or explicitly coerce
allowed values like "sum"/"mean"/"prod"), or alternatively modify the language
wrapper to stop forwarding reduce; update the scatter_ signature and any
internal packing/args handling so reduce is accepted, validated, and passed
through to whatever Call or node creation uses the other params (input, dim,
index, src, span) to keep pl.scatter_ working end-to-end.

In `@python/pypto/language/op/tensor_ops.py`:
- Around line 785-813: The scatter_ implementation currently only unwraps Tensor
sources, so when src is a DSL Scalar wrapper it gets forwarded as the wrapper
instead of the underlying IR expression; update scatter_ (function scatter_) to
detect and unwrap Scalar as well (i.e., if isinstance(src, Tensor) use
src.unwrap(), elif isinstance(src, Scalar) use src.unwrap(), else pass the raw
Python scalar) before calling _ir_ops.scatter_, and ensure the Scalar symbol is
imported where needed.

In `@src/backend/common/pto_ops_common.cpp`:
- Around line 959-965: MakeSystemBarrierCodegenPTO currently ignores any
operands on the `op` and always emits a barrier; add a fast-fail guard that
validates the call is zero-argument and reports a clear error if not. Inside
MakeSystemBarrierCodegenPTO (using the `CallPtr& op`), check the
operand/argument count (e.g., `op->args().size()` or `op->numArgs()` /
`op->operands().size()` depending on your CallPtr API) and if it is non-zero
throw or log a fatal error (e.g., throw std::runtime_error with a message that
includes the op identity or text) instead of emitting the barrier; keep the
existing emit (`codegen.Emit("pto.barrier `#pto.pipe`<" + pipe_name + ">");`) only
when the argument count is zero.

In `@src/codegen/pto/pto_codegen.cpp`:
- Around line 909-915: GetExprAsCode() currently maps only INT32 ConstInt to
GetOrEmitI32Constant and sends every other integer to GetIndexConstant, causing
MLIR type mismatches versus GetExprTypeAnnotation(); update the ConstInt branch
in GetExprAsCode() to mirror the full typed-constant handling used in the
ConstInt visitor: inspect const_int->dtype() and dispatch to the appropriate
emitter (e.g., GetOrEmitI32Constant, GetOrEmitI64Constant,
GetOrEmitUI64Constant, GetOrEmitI8Constant, GetOrEmitUI8Constant, etc.) for each
DataType rather than defaulting to GetIndexConstant so the emitted constant type
matches the annotated MLIR type.

In `@src/ir/op/tensor_ops/scatter.cpp`:
- Around line 55-85: DeduceTensorScatterType currently accepts many invalid
src/index combos; update it (inside DeduceTensorScatterType near use of args,
index_type, src_is_scalar, src_type, kwargs and dim handling) to additionally
validate: for non-scalar src, require src_type->shape_.size() == rank (already
checked) and enforce that for every axis k != dim the src_type->shape_[k] ==
index_type->shape_[k] and also src_type->shape_[k] <= input_type->shape_[k]; for
the scatter axis require src_type->shape_[dim] <= input_type->shape_[dim]
(reject if larger); for scalar src (src_is_scalar) check the literal dtype is
implicitly convertible/writable into input_type->dtype_ (use the same dtype
comparison/convertibility logic used elsewhere) and reject incompatible scalar
dtypes; keep using the existing CHECK macro and error wording consistent with
other messages.

In `@src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp`:
- Around line 1763-1783: The code only checks unused_out_params.begin() for a
matching Out param, causing missed compatible params later in the set; change
the logic in the block handling unused_out_params so you iterate through all
entries in unused_out_params (e.g., a for-loop over unused_out_params) and for
each candidate verify dtype and shape equality (use As<TensorType>(*) and
As<ConstInt> on dimensions, comparing values as in the existing checks); when
you find the first matching entry, set out_param, set is_existing_param = true
and erase that entry from unused_out_params, then break out of the loop.
- Around line 723-733: The shortcut only handles As<Var>(new_result) but misses
when new_result is an IterArg (e.g. scatter_ returning the loop-carried
IterArg), so detect IterArg aliases as well: after the As<Var> branch, check if
new_result can be cast to IterArg (As<IterArg>(new_result)) and record the same
aliasing in var_remap_ (mapping the original op->var_ to that IterArg alias),
then remove the redundant assignment from stmts and return the prologue SeqStmts
exactly as done for the Var case so the synthetic assignment is suppressed for
IterArg results too.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2cd416c7-1b1f-432b-a1f9-209749e81ef3

📥 Commits

Reviewing files that changed from the base of the PR and between b567e2a and 8bcf11e.

📒 Files selected for processing (18)
  • CMakeLists.txt
  • docs/en/user/02-operation_reference.md
  • docs/zh-cn/user/02-operation_reference.md
  • python/pypto/debug/torch_codegen.py
  • python/pypto/ir/op/tensor_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/__init__.py
  • python/pypto/language/op/tensor_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/codegen/pto/pto_codegen.cpp
  • src/codegen/pto/pto_scalar_expr_codegen.cpp
  • src/ir/op/tensor_ops/scatter.cpp
  • src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
  • src/ir/transforms/op_conversion_registry.cpp
  • tests/st/runtime/test_scatter.py
  • tests/ut/ir/operators/test_tensor_ops.py
  • tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py

Comment on lines +55 to +85
auto index_type = As<TensorType>(args[1]->GetType());
CHECK(index_type) << "tensor.scatter_: index must be TensorType, got " << args[1]->GetType()->TypeName();
CHECK(index_type->shape_.size() == rank)
    << "tensor.scatter_: index rank (" << index_type->shape_.size() << ") must match input rank (" << rank
    << ")";
CHECK(index_type->dtype_.IsInt()) << "tensor.scatter_: index dtype must be integer, got "
                                  << index_type->dtype_.ToString();

// src can be TensorType or scalar (ConstFloat / ConstInt)
bool src_is_scalar = As<ConstFloat>(args[2]) || As<ConstInt>(args[2]);
if (!src_is_scalar) {
  auto src_type = As<TensorType>(args[2]->GetType());
  CHECK(src_type) << "tensor.scatter_: src must be TensorType or scalar, got "
                  << args[2]->GetType()->TypeName();
  CHECK(src_type->shape_.size() == rank)
      << "tensor.scatter_: src rank (" << src_type->shape_.size() << ") must match input rank (" << rank
      << ")";
  CHECK(src_type->dtype_ == input_type->dtype_)
      << "tensor.scatter_: src dtype (" << src_type->dtype_.ToString() << ") must match input dtype ("
      << input_type->dtype_.ToString() << ")";
}

// Validate dim kwarg
for (const auto& [key, val] : kwargs) {
  if (key == "dim") {
    int dim_val = AnyCast<int>(val, "kwarg key: dim");
    int irank = static_cast<int>(rank);
    CHECK(dim_val >= -irank && dim_val < irank)
        << "tensor.scatter_: dim must be in [" << -irank << ", " << irank << ") for " << rank
        << "D input, got " << dim_val;
  }

@coderabbitai coderabbitai bot Apr 8, 2026

⚠️ Potential issue | 🟠 Major

Reject impossible scatter shapes and scalar/input dtype pairs here.

DeduceTensorScatterType currently accepts any tensor src with the right rank and any scalar literal at all. That means obviously invalid programs—e.g. index larger than input on a non-scatter axis, tensor src smaller than the index iteration space, or a scalar literal whose dtype cannot be written into input—pass type deduction and only fail later in the elementwise lowering/codegen path. Tighten those checks before returning the result TensorType.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/op/tensor_ops/scatter.cpp` around lines 55 - 85,
DeduceTensorScatterType currently accepts many invalid src/index combos; update
it (inside DeduceTensorScatterType near use of args, index_type, src_is_scalar,
src_type, kwargs and dim handling) to additionally validate: for non-scalar src,
require src_type->shape_.size() == rank (already checked) and enforce that for
every axis k != dim the src_type->shape_[k] == index_type->shape_[k] and also
src_type->shape_[k] <= input_type->shape_[k]; for the scatter axis require
src_type->shape_[dim] <= input_type->shape_[dim] (reject if larger); for scalar
src (src_is_scalar) check the literal dtype is implicitly convertible/writable
into input_type->dtype_ (use the same dtype comparison/convertibility logic used
elsewhere) and reject incompatible scalar dtypes; keep using the existing CHECK
macro and error wording consistent with other messages.
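For intuition about why these shape checks matter, the target semantics can be modeled in a few lines of plain Python. This is a minimal sketch of `torch.Tensor.scatter_` for the 2D case only, using nested lists, no reduce mode, and no bounds checking; the function name and representation are illustrative, not part of the PR.

```python
def scatter_2d(inp, dim, index, src):
    # Reference model of torch.Tensor.scatter_ for 2D inputs:
    #   dim == 0: out[index[i][j]][j] = src[i][j]
    #   dim == 1: out[i][index[i][j]] = src[i][j]
    out = [row[:] for row in inp]  # copy, since scatter_ is in-place
    for i in range(len(index)):
        for j in range(len(index[i])):
            if dim == 0:
                out[index[i][j]][j] = src[i][j]
            else:  # dim == 1
                out[i][index[i][j]] = src[i][j]
    return out
```

An out-of-range index row here would raise an IndexError at runtime, which is exactly the class of error the type-deduction checks are meant to reject statically.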

Contributor Author

Fixed: added dtype pair validation in scatter.cpp — float src requires float input, int src requires int input.


🧩 Analysis chain

@coderabbit-user
Thanks for the quick update — I see you’ve added the scalar dtype pair validation (float→float, int→int). That covers the dtype part of the request.

Before we can mark this thread as addressed, can you also add the remaining shape validations in DeduceTensorScatterType?

  • For non-scalar src:
    • For every axis k ≠ dim: enforce src_type->shape_[k] == index_type->shape_[k] and src_type->shape_[k] <= input_type->shape_[k].
    • For the scatter axis dim: enforce src_type->shape_[dim] <= input_type->shape_[dim].
  • For scalar src: you’ve added dtype checks; that’s good. If your project has a central “is-writable-to” helper for dtype convertibility, using it here would be ideal; otherwise the float↔int disallow rule you added is acceptable.

If helpful, here’s a targeted patch you can adapt inside DeduceTensorScatterType right after you parse dim and have input_type/index_type/src_is_scalar:

// Normalize dim to [0, rank)
int dim_norm = 0;
for (const auto& [key, val] : kwargs) {
  if (key == "dim") {
    int d = AnyCast<int>(val, "kwarg key: dim");
    int irank = static_cast<int>(rank);
    CHECK(d >= -irank && d < irank) << "tensor.scatter_: dim must be in [" << -irank << ", " << irank
                                    << ") for " << rank << "D input, got " << d;
    dim_norm = d < 0 ? d + irank : d;
    break;
  }
}

// Additional shape checks
if (!src_is_scalar) {
  auto src_type = As<TensorType>(args[2]->GetType());
  // already CHECKed rank and dtype match earlier
  for (size_t k = 0; k < rank; ++k) {
    const auto& idx_dim = index_type->shape_[k];
    const auto& in_dim  = input_type->shape_[k];
    const auto& src_dim = src_type->shape_[k];

    if (static_cast<int>(k) != dim_norm) {
      CHECK(src_dim == idx_dim)
          << "tensor.scatter_: src shape must match index shape on non-scatter axes; "
          << "mismatch at axis " << k << " (src " << src_dim.ToString() << " vs index "
          << idx_dim.ToString() << ")";
      CHECK(src_dim <= in_dim)
          << "tensor.scatter_: src extent exceeds input extent at axis " << k << " ("
          << src_dim.ToString() << " > " << in_dim.ToString() << ")";
    } else {
      CHECK(src_dim <= in_dim)
          << "tensor.scatter_: src extent exceeds input extent on scatter axis " << dim_norm
          << " (" << src_dim.ToString() << " > " << in_dim.ToString() << ")";
    }
  }
} else {
  // scalar src: ensure scalar dtype is writable to input dtype
  if (As<ConstFloat>(args[2])) {
    CHECK(input_type->dtype_.IsFloat())
        << "tensor.scatter_: float scalar src requires float input dtype, got "
        << input_type->dtype_.ToString();
  } else if (As<ConstInt>(args[2])) {
    CHECK(input_type->dtype_.IsInt())
        << "tensor.scatter_: int scalar src requires int input dtype, got "
        << input_type->dtype_.ToString();
  }
}

Once these guards are in, invalid shapes will be rejected at type deduction instead of failing later in lowering/codegen. Let me know if you’d like me to expand this into a full diff for the file.


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Locate scatter.cpp =="
fd -a --strip-cwd-prefix 'scatter.cpp'

echo
echo "== Show DeduceTensorScatterType signature and surrounding logic =="
rg -n -C4 --type=cpp 'DeduceTensorScatterType\(' src/ir/op/tensor_ops/scatter.cpp || true

echo
echo "== Show checks around src/type handling (scalar vs tensor) =="
rg -n -C3 --type=cpp 'src_is_scalar|As<ConstFloat>|As<ConstInt>|src_type|dtype_' src/ir/op/tensor_ops/scatter.cpp || true

echo
echo "== Show dim kwarg handling =="
rg -n -C3 --type=cpp 'key == "dim"|dim_val' src/ir/op/tensor_ops/scatter.cpp || true

Length of output: 2740


@Little-oil, thanks — I've confirmed the scalar dtype checks are now in place (lines 75–83). That part is done ✅

However, the per-axis shape constraints for non-scalar src are still missing. The current code only checks that src rank matches and src dtype matches input, but does not verify the extents are compatible:

Missing checks (per axis):
  • k ≠ dim: src_type->shape_[k] == index_type->shape_[k]
  • k ≠ dim: src_type->shape_[k] <= input_type->shape_[k]
  • k == dim: src_type->shape_[k] <= input_type->shape_[k]

Without these, a caller can pass a src that is larger than input on the scatter axis (or mismatched with index on other axes) and the error won't surface until deep in the loop-based lowering in op_conversion_registry.cpp.
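In plain Python, the per-axis constraints requested above amount to the following sketch (the function name and ValueError messages are illustrative; the real implementation would use the project's CHECK macro on symbolic shape expressions):

```python
def check_src_shape(input_shape, index_shape, src_shape, dim):
    # Per-axis compatibility for tensor.scatter_ with a tensor src.
    rank = len(input_shape)
    dim = dim + rank if dim < 0 else dim  # normalize negative dim
    for k in range(rank):
        if k != dim and src_shape[k] != index_shape[k]:
            raise ValueError(f"src/index shape mismatch on non-scatter axis {k}")
        if src_shape[k] > input_shape[k]:
            raise ValueError(f"src extent exceeds input extent at axis {k}")
```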

Here is a targeted patch to add inside the if (!src_is_scalar) block (after retrieving dim_norm):

// Normalize dim once so we can do per-axis checks
int dim_norm = 0;
for (const auto& [key, val] : kwargs) {
  if (key == "dim") {
    int d = AnyCast<int>(val, "kwarg key: dim");
    dim_norm = d < 0 ? d + static_cast<int>(rank) : d;
    break;
  }
}

if (!src_is_scalar) {
  // ... existing rank/dtype checks ...

  // Per-axis shape compatibility
  for (size_t k = 0; k < rank; ++k) {
    if (static_cast<int>(k) != dim_norm) {
      CHECK(src_type->shape_[k] == index_type->shape_[k])
          << "tensor.scatter_: src shape must equal index shape on non-scatter axis " << k
          << " (src=" << src_type->shape_[k].ToString()
          << ", index=" << index_type->shape_[k].ToString() << ")";
      CHECK(src_type->shape_[k] <= input_type->shape_[k])
          << "tensor.scatter_: src extent exceeds input extent on non-scatter axis " << k
          << " (src=" << src_type->shape_[k].ToString()
          << ", input=" << input_type->shape_[k].ToString() << ")";
    } else {
      CHECK(src_type->shape_[k] <= input_type->shape_[k])
          << "tensor.scatter_: src extent exceeds input extent on scatter axis " << dim_norm
          << " (src=" << src_type->shape_[k].ToString()
          << ", input=" << input_type->shape_[k].ToString() << ")";
    }
  }
}

Note: you'll want to move the dim_norm computation before the if (!src_is_scalar) block, and de-duplicate it with the existing dim range-check loop (lines 86–94) so dim is normalized and range-checked in one pass.


🧠 Learnings used
Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 431
File: src/ir/op/tensor_ops/broadcast.cpp:56-79
Timestamp: 2026-03-10T09:46:22.193Z
Learning: In `src/ir/op/tensor_ops/broadcast.cpp` and `src/ir/op/tile_ops/broadcast.cpp` (pypto project), the rank checks in `DeduceTensorRowExpandType`, `DeduceTensorColExpandType`, and their tile-level equivalents intentionally use `>= 2` (not `== 2`). This allows row_expand_mul and col_expand_mul to operate on higher-rank tensors (e.g., `[B, M, N]`), applying the broadcast over the last two dimensions. Do not flag this as a bug or suggest tightening to exact rank-2 checks.

Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/op/tensor_ops/matmul.cpp:133-138
Timestamp: 2026-04-01T08:56:12.231Z
Learning: In hw-native-sys/pypto `src/ir/op/tensor_ops/matmul.cpp`, `tensor.matmul` is strictly 2D-only and `tensor.batch_matmul` requires both operands to have rank >= 3. Mixed-rank broadcasting (e.g., 2D×3D) is intentionally out of scope and requires explicit unsqueeze by the user. Do not flag the dual rank >= 3 guards in `DeduceTensorBatchMatMulType` as a missing-broadcast-case bug.

Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/transforms/op_conversion_registry.cpp:307-312
Timestamp: 2026-04-02T02:12:12.162Z
Learning: In hw-native-sys/pypto `src/ir/transforms/op_conversion_registry.cpp` (around lines 307–312), the `tensor.batch_matmul` conversion lambda intentionally requires both `lhs_shape.size() >= 3` and `rhs_shape.size() >= 3`. This mirrors the rank >= 3 constraint in `DeduceTensorBatchMatMulType` in `src/ir/op/tensor_ops/matmul.cpp`. Mixed-rank 2D×3D broadcasting is explicitly out of scope for this PR; 2D inputs must use `tensor.matmul`. Do not flag these guards as an over-restriction of the batch matmul conversion.

Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 405
File: src/ir/op/sync_ops/cross_core.cpp:57-67
Timestamp: 2026-03-08T15:07:33.101Z
Learning: In `src/ir/op/sync_ops/cross_core.cpp` and `python/pypto/ir/op/system_ops.py` (pypto project), the zero-argument `system.tpop_from_aic` and `system.tpop_from_aiv` ops intentionally fall back to `UnknownType` via `create_op_call` when no `result_type`/`shape`/`dtype` is provided. The C++ `ExpandMixedKernel` pass emits these tpop calls without explicit types; the concrete result type is resolved from the LHS variable's type annotation during DSL parsing. Rejecting no-type calls at the Python IR layer would break the pass. The existing validation in `_resolve_tpop_type` (mutual exclusivity of `result_type` vs `shape`/`dtype`, both-or-neither for `shape`+`dtype`) is sufficient.

Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 834
File: src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp:455-503
Timestamp: 2026-04-01T16:17:03.789Z
Learning: In hw-native-sys/pypto `src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp`, `MatmulSlicePatternCollector::targets_` (and the old `PreScanSliceMatmulPatterns`) intentionally stores only one `MatmulSliceInfo` per slice result var — if the same slice feeds multiple matmul consumers, the last recorded consumer wins. This is a pre-existing limitation carried over from `PreScanSliceMatmulPatterns`; do not flag the absence of a multi-consumer guard in `CollectMatmulOperands` as a regression introduced by the refactoring. A follow-up improvement should add a guard that falls back to generic Vec-space slice conversion when consumers are incompatible.

Learnt from: lyfne123
Repo: hw-native-sys/pypto PR: 732
File: src/codegen/pto/pto_scalar_expr_codegen.cpp:221-226
Timestamp: 2026-03-26T07:19:54.729Z
Learning: In hw-native-sys/pypto `src/codegen/pto/pto_scalar_expr_codegen.cpp`, the `And`/`Or`/`Xor`/`BitAnd`/`BitOr`/`BitXor` `VisitExpr_` visitors call `VisitBinaryArithExpr` with identical `int_op` and `float_op` MLIR op strings (e.g., `"arith.andi"` for both). This is intentional and safe: `MakeAnd`/`MakeOr` factory functions enforce `DataType::BOOL` on operands, and `MakeBitAnd`/`MakeBitOr`/`MakeBitXor` enforce integer types via `PromoteIntBinaryOperands`. Float operands structurally cannot reach these visitors. Do not flag the identical int/float op strings as a missing float-guard bug in codegen PRs.

Learnt from: zhangqi-chen
Repo: hw-native-sys/pypto PR: 833
File: src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp:1688-1716
Timestamp: 2026-04-01T14:12:07.565Z
Learning: In hw-native-sys/pypto `src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp`, within Phase 3a of `TransformIncoreFunction`, the `sink_candidates` loop over IfStmt return vars is guaranteed to never produce two candidates sharing the same `ifstmt_rv_index`. A function's return statement will never reference the same IfStmt return var in multiple positions with different iter-arg mappings. Do not flag the absence of a duplicate-index guard in this loop as a correctness issue.

Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp:610-625
Timestamp: 2026-04-01T08:56:31.235Z
Learning: In hw-native-sys/pypto `src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp`, the `DetectDirectStore` fusion in `LowerBatchMatmul` does not need an explicit single-use check on the `tile.batch_matmul` result var. `ConvertTensorToTileOps` always emits `tile.batch_matmul → tile.store` as consecutive statements, so the batch_matmul result is invariably single-use in generated IR. In the fused path, `assign->var_` is intentionally left unmapped in `ctx.var_map` (only the store var is remapped via `ctx.Insert(lowering.store_orig_var, lowering.store_result_var)`), which is safe because no later references to the batch_matmul result exist. Do not flag the absence of a single-use guard in `DetectDirectStore` as a correctness issue.

Learnt from: zhaozhaozz
Repo: hw-native-sys/pypto PR: 700
File: src/ir/transforms/op_conversion_registry.cpp:314-332
Timestamp: 2026-04-01T09:49:00.207Z
Learning: In hw-native-sys/pypto `src/ir/transforms/op_conversion_registry.cpp`, the `tensor.batch_matmul` conversion lambda intentionally does NOT forward the `out_dtype` kwarg to `tile.batch_matmul`. The `out_dtype` is handled in two phases: (1) `DeduceTensorBatchMatMulType` in `src/ir/op/tensor_ops/matmul.cpp` records the requested `out_dtype` in the output `TensorType`; (2) `FlattenTileNdTo2D` in `src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` (around lines 753–774) detects when the tile.matmul accumulator dtype (FP32/INT32) differs from the expected result dtype and inserts a cast. This two-phase pattern is the same as `tensor.matmul`. Do not flag the missing `out_dtype` forwarding in the conversion as a bug.

Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 386
File: src/ir/op/tile_ops/reduction.cpp:160-170
Timestamp: 2026-03-06T08:17:18.156Z
Learning: In `src/ir/op/tile_ops/reduction.cpp`, the `REGISTER_OP("tile.sum")` registration intentionally includes a second positional argument `tmp_tile` (described as "Temporary tile (TileType)"), which is used as a hardware workspace tile for reduction. However, `DeduceTileReductionType` currently CHECKs for exactly 1 positional argument. The arity mismatch between the registration and the deduce function may need to be reconciled separately — either by updating `DeduceTileReductionType` to accept and handle `args[1]`, or by confirming the `tmp_tile` is not passed through the type-deduction path.

Learnt from: luohuan19
Repo: hw-native-sys/pypto PR: 786
File: src/codegen/tensor_op_codegen.cpp:61-95
Timestamp: 2026-03-30T17:23:12.620Z
Learning: In hw-native-sys/pypto `src/codegen/tensor_op_codegen.cpp`, the `tensor.create` handler (REGISTER_ORCHESTRATION_OP(tensor_create)) intentionally emits two declarations for every internal tensor: (1) `TensorCreateInfo <var>_ci(...)` used with `add_output(<var>_ci)` for runtime device-memory allocation, and (2) `Tensor <var> = make_tensor_external(nullptr, ...)` as a null-addr placeholder that receives a real address only after `<var> = outs.get_ref(i)`. The codegen tracks these variables in `tensor_create_var_names_` (orchestration_codegen.cpp:928) and guarantees they are always routed through `add_output` before any `add_input` use. DN tensors likewise start as null-addr internally. Do not flag `make_tensor_external(nullptr, ...)` in tensor.create as a use-before-initialize bug; the output-first ordering is enforced by the codegen.

Learnt from: Hzfengsy
Repo: hw-native-sys/pypto PR: 308
File: docs/zh-cn/dev/ir/05-operators.md:80-96
Timestamp: 2026-03-01T09:05:27.522Z
Learning: In the PyPTO C++ codebase, kwargs should be stored as std::vector<std::pair<std::string, std::any>> and looked up with range-based for loops (linear search). Do not switch to std::unordered_map for kwargs. Use the GetKwarg helper in src/ir/op/block_ops/memory.cpp as the standard reference for how this pattern is implemented.

Learnt from: lyfne123
Repo: hw-native-sys/pypto PR: 319
File: src/ir/transforms/pass_context.cpp:0-0
Timestamp: 2026-03-02T08:47:16.111Z
Learning: Do not rely on C++ filesystem for directory creation. cpplint disallows the <filesystem> header; ensure directories for report output are created by the Python layer (e.g., in compile.py using os.makedirs) and pass pre-created directory paths to C++ components like ReportInstrument via their constructors. This pattern should apply to all C++ source files that handle output directories.

Little-oil pushed a commit to Little-oil/pypto that referenced this pull request Apr 8, 2026
- Add int src support to scatter_ signature (docs, DSL, IR validation)
- Add reduce kwarg forwarding in scatter_ DSL and IR layers
- Extract duplicate ND→2D flatten logic into FlattenNdIndicesToTwoD helper
- Add barrier argument validation in MakeSystemBarrierCodegenPTO
- Handle INT64 and general integer ConstInt in PTO codegen
- Fix scatter UT: add Out params to Before programs matching pass output
- Fix InOut param test: align assertion with post-rebase pass behavior

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
python/pypto/ir/op/tensor_ops.py (1)

1002-1007: ⚠️ Potential issue | 🟡 Minor

Infer scalar dtype from input tensor dtype instead of hardcoding INT32/FP32.

Lines 1002–1007 create scalars with fixed dtypes regardless of the input tensor's actual dtype. If input is INT64 or FP16, the hardcoded INT32/FP32 scalar will cause a dtype precision mismatch. The C++ type deduction only validates category (IsInt/IsFloat), not exact dtype. Pass the input tensor's dtype to ConstInt and ConstFloat to ensure compatibility downstream.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/ir/op/tensor_ops.py` around lines 1002 - 1007, The scalar
constructors currently hardcode DataType.INT32/DataType.FP32; instead, when
wrapping Python scalars (the int/float branches) pass the target tensor's dtype
so the scalar matches the input tensor precision—use the input tensor's dtype
(e.g., input.dtype or the relevant tensor Expr's .dtype) as the dtype argument
to ConstInt and ConstFloat rather than DataType.INT32/DataType.FP32; keep the
Expr type check and the TypeError as-is.
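The fix being asked for is a small one: thread the input tensor's dtype into the scalar constructor instead of a fixed INT32/FP32. A hedged sketch of the intended wrapping logic, with `wrap_scalar` and the tuple representation purely illustrative (the real code calls ConstInt/ConstFloat IR constructors):

```python
def wrap_scalar(value, input_dtype):
    # Infer the constant's dtype from the input tensor rather than
    # hardcoding INT32/FP32, so an INT64 or FP16 input stays consistent.
    if isinstance(value, bool):
        raise TypeError("bool src is not supported")
    if isinstance(value, int):
        return ("ConstInt", value, input_dtype)
    if isinstance(value, float):
        return ("ConstFloat", value, input_dtype)
    raise TypeError(f"src must be int, float, or Expr, got {type(value).__name__}")
```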
🧹 Nitpick comments (2)
src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp (2)

663-699: tile.read/tile.write ND→2D index flattening is correctly implemented.

The implementation:

  • Only triggers for >2D tiles (via IsNdTile check)
  • Validates indices are MakeTuple with INTERNAL_CHECK
  • Flattens indices using the helper function
  • Correctly handles both ops (read returns scalar, write returns tile)

However, the code at lines 684-696 has identical logic for both branches (both create a new var and insert into var_map). This could be simplified.

♻️ Simplify duplicated branches
         auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
-        if (op_name == "tile.read") {
-          // tile.read returns scalar — assign to var
-          auto new_var =
-              std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
-          result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
-          ctx.Insert(assign->var_, new_var);
-        } else {
-          // tile.write returns tile — assign to var and update mapping
-          auto new_var =
-              std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
-          result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
-          ctx.Insert(assign->var_, new_var);
-        }
+        // Both tile.read (scalar) and tile.write (tile) produce a result to assign
+        auto new_var =
+            std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
+        result.push_back(std::make_shared<AssignStmt>(new_var, new_call, assign->span_));
+        ctx.Insert(assign->var_, new_var);
         continue;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 663 - 699, The
two branches handling op_name == "tile.read" and else (tile.write) duplicate the
same logic creating new_var, pushing AssignStmt, and calling ctx.Insert;
collapse them into a single shared block after new_call is created: construct
new_var with assign->var_->name_hint_/GetType()/span_, push the new AssignStmt
to result, and call ctx.Insert(assign->var_, new_var); remove the separate
per-op duplicate code paths (keep op_name checks only for any op-specific
behavior if later needed). Use symbols new_call, assign, result, ctx.Insert,
Var, and AssignStmt to locate and replace the duplicated code.

71-76: Edge case: rank-1 input leaves merged_row uninitialized.

When rank == 2, the outer loop condition k + 1 < rank means k < 1, so k takes only value 0. The inner loop j = k + 1 = 1 checks j + 1 < rank i.e. 2 < 2 which is false, so no multiplication happens. This is correct behavior.

However, if rank == 1 (single dimension), the loop body never executes and merged_row remains nullptr. While IsNdTile guards ensure rank > 2 before calling this function, there's no explicit precondition check.

🛡️ Consider adding a defensive CHECK
 ExprPtr FlattenNdIndicesToTwoD(const MakeTuplePtr& idx_tuple,
                                const std::vector<ExprPtr>& nd_shape,
                                const std::unordered_map<const Var*, VarPtr>& var_map,
                                const Span& span) {
   const size_t rank = nd_shape.size();
+  CHECK(rank >= 2) << "FlattenNdIndicesToTwoD requires rank >= 2, got " << rank;
   ExprPtr merged_row;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 71 - 76, Add a
defensive precondition and explicit handling for tiny ranks so merged_row can't
remain null: at the start of the function that contains the loop (referencing
merged_row, rank, idx_tuple and the loop using Substitute/MakeMul/MakeAdd) add a
CHECK or assert that rank >= 2 (or return/handle rank == 1 explicitly), and if
rank == 1 set merged_row to Substitute(idx_tuple->elements_[0], var_map) before
the loop so the subsequent logic always has a valid merged_row; keep the
existing loop intact for rank >= 2.
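The Horner-style loop under discussion is easier to reason about in scalar Python. This sketch assumes concrete integer indices and extents (the pass operates on symbolic Expr nodes) and shows why rank 2 is a no-op pass-through while rank 1 has no valid row index to build:

```python
def flatten_nd_to_2d(indices, shape):
    # Collapse all leading dims into one row index (row-major);
    # the last dim stays as the column index.
    rank = len(shape)
    assert rank >= 2, "flatten requires rank >= 2"
    row = indices[0]
    for k in range(1, rank - 1):
        row = row * shape[k] + indices[k]
    return row, indices[-1]
```

For rank 2 the loop body never runs and `(indices[0], indices[1])` comes back unchanged, matching the "correct behavior" observation above.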
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/language/op/tensor_ops.py`:
- Line 806: Update the docstring for the parameter named reduce in
python/pypto/language/op/tensor_ops.py to list all valid values used by the
backend — "none", "add", and "multiply" — so it matches the validation in the
C++ op conversion registry; locate the docstring block that documents reduce and
replace the current quoted values ("add", "multiply") with the full set ("none",
"add", "multiply") and ensure the description clearly states what each mode
means if present.

In `@src/backend/common/pto_ops_common.cpp`:
- Around line 729-748: The code only casts ir::Var in the ensure_index lambda,
so non-Var index expressions (ConstInt, IterArg, or compound Expr like add) keep
their integer scalar type and produce ill-typed arith.muli/arith.addi ... :
index ops; update ensure_index (used above the loop that computes term and in
the flat-offset math) to inspect the expression's type or kind and call
codegen.EmitCastToIndex for any expression that is not already of index type
(not just ir::Var), ensuring term is always the casted SSA string; use the same
pattern around codegen.GetExprAsCode(indices[i]) so all terms used in the
subsequent arith.muli/arith.addi emissions are index-typed.

---

Duplicate comments:
In `@python/pypto/ir/op/tensor_ops.py`:
- Around line 1002-1007: The scalar constructors currently hardcode
DataType.INT32/DataType.FP32; instead, when wrapping Python scalars (the
int/float branches) pass the target tensor's dtype so the scalar matches the
input tensor precision—use the input tensor's dtype (e.g., input.dtype or the
relevant tensor Expr's .dtype) as the dtype argument to ConstInt and ConstFloat
rather than DataType.INT32/DataType.FP32; keep the Expr type check and the
TypeError as-is.

---

Nitpick comments:
In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp`:
- Around line 663-699: The two branches handling op_name == "tile.read" and else
(tile.write) duplicate the same logic creating new_var, pushing AssignStmt, and
calling ctx.Insert; collapse them into a single shared block after new_call is
created: construct new_var with assign->var_->name_hint_/GetType()/span_, push
the new AssignStmt to result, and call ctx.Insert(assign->var_, new_var); remove
the separate per-op duplicate code paths (keep op_name checks only for any
op-specific behavior if later needed). Use symbols new_call, assign, result,
ctx.Insert, Var, and AssignStmt to locate and replace the duplicated code.
- Around line 71-76: Add a defensive precondition and explicit handling for tiny
ranks so merged_row can't remain null: at the start of the function that
contains the loop (referencing merged_row, rank, idx_tuple and the loop using
Substitute/MakeMul/MakeAdd) add a CHECK or assert that rank >= 2 (or
return/handle rank == 1 explicitly), and if rank == 1 set merged_row to
Substitute(idx_tuple->elements_[0], var_map) before the loop so the subsequent
logic always has a valid merged_row; keep the existing loop intact for rank >=
2.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 4b56488a-1733-4d02-b7e4-6e85cdc4ea04

📥 Commits

Reviewing files that changed from the base of the PR and between 8bcf11e and f955fa9.

📒 Files selected for processing (14)
  • docs/en/user/02-operation_reference.md
  • docs/zh-cn/user/02-operation_reference.md
  • python/pypto/ir/op/tensor_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/__init__.py
  • python/pypto/language/op/tensor_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/codegen/pto/pto_codegen.cpp
  • src/ir/op/tensor_ops/scatter.cpp
  • src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
  • src/ir/transforms/op_conversion_registry.cpp
  • tests/st/runtime/test_scatter.py
  • tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • docs/en/user/02-operation_reference.md
  • python/pypto/language/op/__init__.py
  • src/codegen/pto/pto_codegen.cpp
  • src/ir/op/tensor_ops/scatter.cpp
  • docs/zh-cn/user/02-operation_reference.md
  • src/ir/transforms/op_conversion_registry.cpp

dim: Dimension along which to scatter
index: Index tensor (N-D, same rank as input) of integer dtype
src: Source tensor (same shape as index) or scalar value
reduce: Optional reduce mode ("add" or "multiply")

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if reduce validation exists anywhere in the scatter implementation chain
rg -n "reduce" --type cpp src/ir/op/tensor_ops/scatter.cpp | head -20
rg -n "reduce.*(add|multiply)" --type cpp src/ir/transforms/op_conversion_registry.cpp | head -10

Repository: hw-native-sys/pypto

Length of output: 103


🏁 Script executed:

# Check the context around line 806 in tensor_ops.py
cd python/pypto/language/op && sed -n '795,820p' tensor_ops.py

Repository: hw-native-sys/pypto

Length of output: 957


🏁 Script executed:

# Look at the full scatter function implementation
cd python/pypto/language/op && grep -n "def scatter" tensor_ops.py

Repository: hw-native-sys/pypto

Length of output: 104


🏁 Script executed:

# Check the C++ scatter implementation for reduce handling
cat src/ir/op/tensor_ops/scatter.cpp | head -150

Repository: hw-native-sys/pypto

Length of output: 5645


🏁 Script executed:

# Search for where the reduce parameter is used in C++ code
rg -B5 -A10 "reduce" --type cpp src/ir/op/tensor_ops/scatter.cpp

Repository: hw-native-sys/pypto

Length of output: 661


🏁 Script executed:

# Search for reduce parameter usage in tests or examples
rg -n "reduce" --type py python/pypto | grep -E "(test|example)" | head -20

Repository: hw-native-sys/pypto

Length of output: 93


🏁 Script executed:

# Search for scatter tests
fd -e py "test.*scatter" python/

Repository: hw-native-sys/pypto

Length of output: 45


🏁 Script executed:

# Look for where scatter_ is actually called with reduce parameter
rg -B2 -A2 'scatter_.*reduce' --type py python/

Repository: hw-native-sys/pypto

Length of output: 93


🏁 Script executed:

# Search for scatter tests and reduce usage
fd --type f scatter python/

Repository: hw-native-sys/pypto

Length of output: 233


🏁 Script executed:

# Search for reduce parameter usage in Python files
rg "reduce\s*=" --type py python/ | head -20

Repository: hw-native-sys/pypto

Length of output: 396


🏁 Script executed:

# Look at the actual backend to understand what reduce modes are supported
rg -n "reduce" src/ | grep -E "(add|multiply)" | head -20

Repository: hw-native-sys/pypto

Length of output: 541


🏁 Script executed:

# Check if there are any comments or documentation about valid reduce modes
rg -B3 -A3 "reduce.*add.*multiply" --type py --type cpp --type txt

Repository: hw-native-sys/pypto

Length of output: 1343


Update docstring to include all valid reduce values.

The Python docstring lists only "add" and "multiply", but the C++ backend (src/ir/transforms/op_conversion_registry.cpp:546) validates against three modes: "none", "add", and "multiply". Update the docstring to reflect the complete set of valid values.

-        reduce: Optional reduce mode ("add" or "multiply")
+        reduce: Optional reduce mode ("none", "add", or "multiply")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/language/op/tensor_ops.py` at line 806, update the docstring for
the `reduce` parameter to list all valid values used by the backend ("none",
"add", and "multiply") so it matches the validation in the C++ op conversion
registry: locate the docstring block that documents `reduce`, replace the
current quoted values ("add", "multiply") with the full set ("none", "add",
"multiply"), and ensure the description clearly states what each mode means.
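For reference, the behavior the three reduce modes should describe can be sketched as a minimal pure-Python model of the torch.Tensor.scatter_ semantics for the 2-D case. This is an illustrative model only, not the pypto implementation; the function name and list-of-lists representation are assumptions for the sketch:

```python
def scatter_2d(input_rows, dim, index, src, reduce="none"):
    """Reference model: for dim=0, out[index[i][j]][j] op= src[i][j];
    for dim=1, out[i][index[i][j]] op= src[i][j]."""
    out = [row[:] for row in input_rows]  # copy; the real op is in-place
    for i in range(len(index)):
        for j in range(len(index[0])):
            ti, tj = (index[i][j], j) if dim == 0 else (i, index[i][j])
            if reduce == "none":
                out[ti][tj] = src[i][j]
            elif reduce == "add":
                out[ti][tj] += src[i][j]
            elif reduce == "multiply":
                out[ti][tj] *= src[i][j]
            else:
                raise ValueError(f"invalid reduce mode: {reduce}")
    return out
```

For example, `scatter_2d([[0, 0], [0, 0]], 0, [[1, 0], [0, 1]], [[1, 2], [3, 4]])` scatters each `src[i][j]` to row `index[i][j]` of column `j`.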

Comment on lines +729 to +748
// Helper: ensure an index element SSA value has `index` type.
// If the expression is a non-index integer (e.g. i32 from tile.read on an
// INT32 tile), emit arith.index_cast to convert it.
auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
if (auto var = ir::As<ir::Var>(expr)) {
return codegen.EmitCastToIndex(var, ssa);
}
return ssa;
};

// For each dimension i, compute: index[i] * (shape[i+1] * shape[i+2] * ... * shape[rank-1])
// then sum all terms with arith.addi.
std::string accumulator;
for (size_t i = 0; i < indices.size(); ++i) {
if (i > 0) idx_oss << " + ";
idx_oss << codegen.GetExprAsCode(indices[i]);
std::string term = ensure_index(indices[i], codegen.GetExprAsCode(indices[i]));
// Multiply by each trailing dimension size
for (size_t j = i + 1; j < shape.size(); ++j) {
idx_oss << " * " << codegen.GetExprAsCode(shape[j]);
std::string dim = codegen.GetExprAsCode(shape[j]);
std::string tmp = codegen.NewTemp();
codegen.Emit(tmp + " = arith.muli " + term + ", " + dim + " : index");
⚠️ Potential issue | 🟠 Major

Cast all non-index scalar index expressions before the flat-offset math.

Line 733 only normalizes plain ir::Var. If a tuple element is an IterArg, an integer ConstInt, or a computed scalar like idx32 + 1, term stays i32/i64 and the emitted arith.muli / arith.addi ... : index becomes ill-typed.

🔧 Suggested fix
-  auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
-    if (auto var = ir::As<ir::Var>(expr)) {
-      return codegen.EmitCastToIndex(var, ssa);
-    }
-    return ssa;
-  };
+  auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
+    auto scalar_type = As<ScalarType>(expr->GetType());
+    INTERNAL_CHECK(scalar_type) << "flat index expression must be scalar";
+    if (scalar_type->dtype_ == DataType::INDEX) {
+      return ssa;
+    }
+    CHECK(!scalar_type->dtype_.IsFloat()) << "flat index expression must be integer/index typed";
+    std::string idx_ssa = codegen.NewTemp();
+    codegen.Emit(idx_ssa + " = arith.index_cast " + ssa + " : " +
+                 codegen.GetTypeString(scalar_type->dtype_) + " to index");
+    return idx_ssa;
+  };
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/backend/common/pto_ops_common.cpp` around lines 729-748, the code only
casts ir::Var in the ensure_index lambda, so non-Var index expressions
(ConstInt, IterArg, or a compound expression like an add) keep their integer
scalar type and produce ill-typed `arith.muli`/`arith.addi ... : index` ops.
Update ensure_index (used above the loop that computes term and in the
flat-offset math) to inspect the expression's type and call
codegen.EmitCastToIndex for any expression that is not already index-typed (not
just ir::Var), ensuring term is always the casted SSA string; apply the same
pattern around codegen.GetExprAsCode(indices[i]) so all terms used in the
subsequent arith.muli/arith.addi emissions are index-typed.
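The flat-offset formula in the code comment (index[i] multiplied by the product of all trailing dimension sizes, summed over i) can be checked against a plain Python model of row-major linearization. This is a sketch of the arithmetic only, not the SSA-emitting codegen:

```python
def flat_offset(indices, shape):
    """Row-major linearization: sum over i of indices[i] * prod(shape[i+1:])."""
    offset = 0
    stride = 1
    # Walk dimensions from innermost to outermost, accumulating strides,
    # which is equivalent to multiplying each index by its trailing dims.
    for idx, dim in zip(reversed(indices), reversed(shape)):
        offset += idx * stride
        stride *= dim
    return offset
```

For a [4, 5, 6] tile, `flat_offset([1, 2, 3], [4, 5, 6])` gives 1*30 + 2*6 + 3 = 45, matching the nested arith.muli/arith.addi chain the codegen emits.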

- Add int src support to scatter_ signature (docs, DSL, IR validation)
- Add reduce kwarg forwarding in scatter_ DSL and IR layers
- Add scalar dtype pair validation (float src → float input, int src → int input)
- Extract duplicate ND→2D flatten logic into FlattenNdIndicesToTwoD helper
- Add barrier argument validation in MakeSystemBarrierCodegenPTO
- Extract _scatter_handler from _register_ops to fix ruff PLR0915
- Fix scatter UT: add Out params to Before programs matching pass output
- Fix InOut param test: align assertion with post-rebase pass behavior
- Apply clang-format and ruff-format fixes

Add a minimal 2x8 scatter test case with unique src values for
easy tracing. Add diagnostic prints to all scatter test methods
showing dim, index, src, and expected values.
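The FlattenNdIndicesToTwoD helper extracted above collapses ND tile indices for >2D tiles. Assuming the conventional mapping where the innermost dimension stays as the column axis and all leading dimensions fold row-major into the row (a layout assumption for this sketch; the actual pass may map differently), the transformation looks like:

```python
def flatten_nd_to_2d(indices, shape):
    # Fold all leading dimensions into a single row index (row-major),
    # keeping the innermost dimension as the column index.
    row = 0
    for idx, dim in zip(indices[:-1], shape[:-1]):
        row = row * dim + idx
    return row, indices[-1]
```

E.g. index (1, 2, 3) into a [4, 5, 6] tile maps to row 1*5 + 2 = 7, column 3 of the flattened [20, 6] tile.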
Collaborator @lyfne123 left a comment:

The newly added test cases are not sufficient to cover the modified code.


static std::string MakeSystemBarrierCodegenPTO(const std::string& pipe_name, const CallPtr& op,
codegen::CodegenBase& codegen_base) {
CHECK(op->args_.empty()) << "system.barrier_" << pipe_name << " expects 0 arguments, got "
Collaborator comment:

Why insert barriers here? This work should be done by ptoas.

@Hzfengsy changed the title from "feat(op): implement tensor.scatter_ element-level scatter (#677)" to "feat(op): implement tensor.scatter_ element-level scatter" on Apr 10, 2026