fix(pass): preserve singleton broadcast dims in SplitVectorKernel#984
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughPreserve singleton split-axis dimensions for broadcast tiles in SplitVectorKernel and add predicates to detect singleton dims and reductions on the split axis. Halving/offset adjustments are skipped for singleton split dims, and reductions that target the split axis are rejected. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request enhances the split_vector_kernel_pass by adding logic to handle singleton dimensions and prevent invalid reductions on the split axis. It introduces helper functions IsSingletonDim and IsReduceOnSplitAxis to ensure that singleton tiles are preserved during splitting and that unsupported partial reductions trigger a ValueError. Corresponding unit tests have been added to verify these cases. Review feedback recommends optimizing performance by using GetKind() for type checking and removing a redundant cast in the IsReduceOnSplitAxis function.
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/ut/ir/transforms/test_split_vector_kernel.py (1)
447-466: Please add one row-reduction regression too.The C++ change has a separate
tile.row_sum/max/minbranch, but this file only exercises the genericpl.sum(axis=0)path. One LEFT_RIGHT row-reduction case would lock down that branch as well.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ut/ir/transforms/test_split_vector_kernel.py` around lines 447 - 466, Add a second test that mirrors test_reduce_on_split_axis_rejected but exercises the LEFT_RIGHT split and a row-reduction: create a new test (e.g., test_reduce_on_split_axis_rejected_row) that defines a Before program with `@pl.function`(..., attrs={"split": pl.SplitMode.LEFT_RIGHT}) and loads a Tile like in the original test, then calls pl.sum(prev, axis=1, keepdim=True) (row reduction) and asserts it raises with the same "reduces on the split axis" message by calling _run_split_vector_kernel(Before).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ir/transforms/split_vector_kernel_pass.cpp`:
- Around line 325-334: The early return in split_vector_kernel_pass.cpp that
bypasses split tracking for singleton-result tiles (when is_aiv && op_name ==
"tile.load" and IsSingletonDim on TileType) is too permissive; update the check
so only proven broadcast-safe ops are exempt (e.g., explicitly allow known
broadcast-only ops) or conservatively reject ambiguous singleton-result cases so
tile_vars still track offsets; concretely, in the tile.load / tile.slice
handling replace the unconditional is_singleton return with a predicate that
verifies the op is in a whitelist of broadcast-safe ops or that the source is
split-invariant, and ensure the same change is applied to the analogous
early-return branch later (the one referencing tile.slice/tile.store and
tile_vars tracking).
---
Nitpick comments:
In `@tests/ut/ir/transforms/test_split_vector_kernel.py`:
- Around line 447-466: Add a second test that mirrors
test_reduce_on_split_axis_rejected but exercises the LEFT_RIGHT split and a
row-reduction: create a new test (e.g., test_reduce_on_split_axis_rejected_row)
that defines a Before program with `@pl.function`(..., attrs={"split":
pl.SplitMode.LEFT_RIGHT}) and loads a Tile like in the original test, then calls
pl.sum(prev, axis=1, keepdim=True) (row reduction) and asserts it raises with
the same "reduces on the split axis" message by calling
_run_split_vector_kernel(Before).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 6d371d74-94ef-4338-935c-5ceaee3b9ab2
📒 Files selected for processing (2)
src/ir/transforms/split_vector_kernel_pass.cpptests/ut/ir/transforms/test_split_vector_kernel.py
Redesign the split decision algorithm in SplitVectorKernel to be op-semantics-aware instead of unconditionally halving all tile dims: - Add IsSingletonDim check: tiles with split-axis extent == 1 (e.g. broadcast [1, 128] under UP_DOWN) are now preserved as-is without halving shape, adjusting offsets, or tracking in tile_vars - Add IsReduceOnSplitAxis detection: reduce ops (tile.sum/max/min, tile.row_sum/max/min) that reduce on the split axis are rejected with a clear error, since partial reduction is semantically incorrect - Add regression tests for both UP_DOWN and LEFT_RIGHT singleton broadcast scenarios, plus a reduce-on-split-axis rejection test Fixes hw-native-sys#976 Closes hw-native-sys#975 Made-with: Cursor
cd66135 to
5b7d946
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/ir/transforms/split_vector_kernel_pass.cpp (1)
325-334:⚠️ Potential issue | 🟠 MajorDon’t treat every singleton split-axis tile as split-invariant.
These early returns still skip split rewriting for any
[1, N]/[M, 1]producer. A non-broadcast singletontile.loador singleton-result op such astile.slicecan now bypass offset rewriting andtile_varstracking, so later stores may duplicate the same row/column from both lanes. Preserve only proven broadcast-safe cases, or reject ambiguous singleton producers instead of returning the original stmt.Based on learnings,
tile.sliceis already emitted in this IR insrc/ir/transforms/flatten_tile_nd_to_2d_pass.cpp, so singleton-result tiles are not broadcast-only by construction.Also applies to: 377-391
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/ir/transforms/split_vector_kernel_pass.cpp` around lines 325 - 334, The early-return that preserves singleton split-dim tiles (in the is_aiv && op_name == "tile.load" branch using TileType, split_dim and IsSingletonDim) is too permissive and skips necessary offset rewriting and tile_vars tracking; instead, change the logic to only bypass rewriting for proven broadcast cases (e.g., detect explicit broadcast/stride metadata or inspect the call's shape/valid_shape args to confirm the singleton is a broadcast source), and otherwise do not return early—allow the existing tile.load/tile.slice rewrite path to apply or explicitly reject/raise for ambiguous singleton producers so stores don't duplicate rows/columns; apply the same fix to the analogous block around lines 377-391.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ir/transforms/split_vector_kernel_pass.cpp`:
- Around line 133-147: In IsReduceOnSplitAxis, tile.row_sum/row_max/row_min
should check the input tile's last axis instead of hardcoding split_dim == 1;
obtain the input TileType (like the tile.sum branch does) from
call->args_[0]->GetType(), compute last_axis = tt ?
static_cast<int>(tt->shape_.size()) - 1 : 1 (or bail if tt is null), and return
last_axis == split_dim; update the tile.row_* branch in IsReduceOnSplitAxis to
use that last-axis comparison so rank>2 tiles are handled correctly.
In `@tests/ut/ir/transforms/test_split_vector_kernel.py`:
- Around line 465-466: The test currently uses pytest.raises(Exception,
match="reduces on the split axis") which is too broad; update the assertion to
expect the specific exception type (e.g. ValueError or the project-specific
pypto exception) by changing the pytest.raises call around
_run_split_vector_kernel(Before) to pytest.raises(ValueError, match="reduces on
the split axis") (or the named pypto exception) so the test only passes when the
split-axis reduction check fails in _run_split_vector_kernel.
---
Duplicate comments:
In `@src/ir/transforms/split_vector_kernel_pass.cpp`:
- Around line 325-334: The early-return that preserves singleton split-dim tiles
(in the is_aiv && op_name == "tile.load" branch using TileType, split_dim and
IsSingletonDim) is too permissive and skips necessary offset rewriting and
tile_vars tracking; instead, change the logic to only bypass rewriting for
proven broadcast cases (e.g., detect explicit broadcast/stride metadata or
inspect the call's shape/valid_shape args to confirm the singleton is a
broadcast source), and otherwise do not return early—allow the existing
tile.load/tile.slice rewrite path to apply or explicitly reject/raise for
ambiguous singleton producers so stores don't duplicate rows/columns; apply the
same fix to the analogous block around lines 377-391.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 51d839b8-6294-4df8-8543-ebcc6babaa93
📒 Files selected for processing (2)
src/ir/transforms/split_vector_kernel_pass.cpptests/ut/ir/transforms/test_split_vector_kernel.py
- Remove redundant dynamic_pointer_cast<Call> in IsReduceOnSplitAxis; extract input_tile_type lambda to avoid repetition - Fix tile.row_* reduce detection to use last axis of input tile instead of hardcoded dim 1, supporting rank>2 tiles - Narrow pytest.raises(Exception) to pytest.raises(ValueError) for more targeted reduce-on-split-axis rejection test Made-with: Cursor
Summary
SplitVectorKernelto be op-semantics-aware instead of unconditionally halving all tile dimensions on the split axisIsSingletonDimcheck: tiles with split-axis extent == 1 (e.g. broadcast[1, 128]underUP_DOWN) are preserved as-is without halving, offset adjustment, or tile trackingIsReduceOnSplitAxisdetection: reduce ops (tile.sum/max/min,tile.row_sum/max/min) that reduce on the split axis are rejected with a clear error since partial reduction is semantically incorrectFixes #976
Closes #975
Test plan
test_split_vector_kernel.pytests pass (12 existing + 3 new)tests/ut/ir/transforms/tests pass with no regressionsMade with Cursor