Skip to content

Commit cd66135

Browse files
committed
fix(pass): preserve singleton broadcast dims in SplitVectorKernel
Redesign the split decision algorithm in SplitVectorKernel to be op-semantics-aware instead of unconditionally halving all tile dims:

- Add an IsSingletonDim check: tiles whose split-axis extent is 1 (e.g. a broadcast [1, 128] tile under UP_DOWN) are now preserved as-is — no shape halving, no offset adjustment, and no tracking in tile_vars.
- Add IsReduceOnSplitAxis detection: reduce ops (tile.sum/max/min, tile.row_sum/max/min) that reduce along the split axis are rejected with a clear error, since a partial reduction over half the axis is semantically incorrect.
- Add regression tests for the singleton-broadcast scenario under both UP_DOWN and LEFT_RIGHT, plus a rejection test for reduce-on-split-axis.

Fixes #976
Closes #975
Made-with: Cursor
1 parent 460d257 commit cd66135

2 files changed

Lines changed: 131 additions & 2 deletions

File tree

src/ir/transforms/split_vector_kernel_pass.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ std::vector<std::pair<std::string, std::any>> WithSplitAttr(const FunctionPtr& f
123123
return attrs;
124124
}
125125

126+
bool IsSingletonDim(const ExprPtr& dim_size) {
127+
if (auto ci = std::dynamic_pointer_cast<const ConstInt>(dim_size)) {
128+
return ci->value_ == 1;
129+
}
130+
return false;
131+
}
132+
133+
bool IsReduceOnSplitAxis(const CallPtr& call, int split_dim) {
134+
if (!call->op_) return false;
135+
const auto& name = call->op_->name_;
136+
if (name == "tile.row_sum" || name == "tile.row_max" || name == "tile.row_min") {
137+
return split_dim == 1;
138+
}
139+
if (name == "tile.sum" || name == "tile.max" || name == "tile.min") {
140+
int axis = call->GetKwarg<int>("axis", -1);
141+
auto tt = std::dynamic_pointer_cast<const TileType>(
142+
std::dynamic_pointer_cast<const Call>(call)->args_.empty() ? nullptr : call->args_[0]->GetType());
143+
if (axis < 0 && tt) {
144+
axis = static_cast<int>(tt->shape_.size()) + axis;
145+
}
146+
return axis == split_dim;
147+
}
148+
return false;
149+
}
150+
126151
ExprPtr ComputeHalfDimSize(const ExprPtr& dim_size) {
127152
if (auto ci = std::dynamic_pointer_cast<const ConstInt>(dim_size)) {
128153
if ((ci->value_ % 2) != 0) {
@@ -297,9 +322,17 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
297322
return std::make_shared<AssignStmt>(new_var, new_call, assign->span_);
298323
}
299324

300-
// AIV only: tile.load — halve result shape, halve shape/valid_shape args, adjust offset
325+
// AIV only: tile.load — halve result shape, halve shape/valid_shape args, adjust offset.
326+
// Singleton split-dim tiles (e.g. broadcast [1, 128] under UP_DOWN) are preserved as-is.
301327
if (is_aiv && op_name == "tile.load" && call->args_.size() >= 4) {
302328
auto tt = std::dynamic_pointer_cast<const TileType>(call->GetType());
329+
bool is_singleton =
330+
tt && split_dim < static_cast<int>(tt->shape_.size()) && IsSingletonDim(tt->shape_[split_dim]);
331+
332+
if (is_singleton) {
333+
return stmt;
334+
}
335+
303336
ExprPtr half_dim_size;
304337
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
305338
half_dim_size = ComputeHalfDimSize(tt->shape_[split_dim]);
@@ -341,10 +374,21 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
341374
}
342375
}
343376

344-
// AIV only: any other op producing TileType — halve result shape (and static shape args when present)
377+
// AIV only: any other op producing TileType — halve result shape (and static shape args when present).
378+
// Reject reduce ops that reduce on the split axis (partial reduction is semantically incorrect).
379+
// Skip halving when the output split-dim is singleton (broadcast / degenerate tiles).
345380
if (is_aiv) {
381+
if (IsReduceOnSplitAxis(call, split_dim)) {
382+
throw pypto::ValueError("SplitVectorKernel: reduce op '" + op_name +
383+
"' reduces on the split axis (dim " + std::to_string(split_dim) +
384+
"); partial reduction in a split kernel is not supported");
385+
}
386+
346387
auto tt = std::dynamic_pointer_cast<const TileType>(call->GetType());
347388
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
389+
if (IsSingletonDim(tt->shape_[split_dim])) {
390+
return stmt;
391+
}
348392
auto half_dim_size = ComputeHalfDimSize(tt->shape_[split_dim]);
349393
auto new_result_type = HalveTileShape(call->GetType(), split_dim);
350394
std::vector<ExprPtr> new_args = call->args_;

tests/ut/ir/transforms/test_split_vector_kernel.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,59 @@ def main_aic(self, x: pl.Tensor[[16, 128], pl.BF16]):
412412

413413
_assert_split_matches_expected(Before, Expected)
414414

415+
def test_singleton_broadcast_tile_preserved(self):
    """Broadcast tile [1, 128] on split axis dim0 must stay unchanged under UP_DOWN."""

    @pl.program
    class Before:
        @pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.UP_DOWN})
        def main_aiv(
            self,
            data: pl.Tensor[[16, 128], pl.FP32],
            gamma: pl.Tensor[[1, 128], pl.FP32],
            out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
        ) -> pl.Tensor[[16, 128], pl.FP32]:
            prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
            )
            gamma_tile: pl.Tile[[1, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                gamma, [0, 0], [1, 128], target_memory=pl.MemorySpace.Vec
            )
            result: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.col_expand_mul(prev, gamma_tile)
            out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(result, [0, 0], out_0)
            return out_0_store

    # Run the pass and inspect the printed result.
    split_program = _run_split_vector_kernel(Before)
    rendered = python_print(split_program)
    assert split_program.get_function("main_aiv") is not None
    # The full [16, 128] tile is halved to [8, 128] and offset by the subblock index.
    assert "pl.tile.get_subblock_idx()" in rendered
    assert "pl.tile.load(data__ssa_v0, [0 + subblock_idx * 8, 0], [8, 128], [8, 128]" in rendered
    # The singleton broadcast tile [1, 128] keeps its shape and zero offset.
    assert "pl.tile.load(gamma__ssa_v0, [0, 0], [1, 128], [1, 128]" in rendered
    assert "pl.tile.col_expand_mul(" in rendered
    assert "pl.tile.store(" in rendered
446+
447+
def test_reduce_on_split_axis_rejected(self):
    """Reduce on split axis (dim0 under UP_DOWN) must raise ValueError."""

    @pl.program
    class Before:
        @pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.UP_DOWN})
        def main_aiv(
            self,
            data: pl.Tensor[[16, 128], pl.FP32],
            out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
        ) -> pl.Tensor[[16, 128], pl.FP32]:
            prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
            )
            reduced: pl.Tile[[1, 128], pl.FP32, pl.MemorySpace.Vec] = pl.sum(prev, axis=0, keepdim=True)
            out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(reduced, [0, 0], out_0)
            return out_0_store

    # Splitting dim 0 would make the sum a partial reduction, so the pass must refuse.
    with pytest.raises(Exception, match="reduces on the split axis"):
        _run_split_vector_kernel(Before)
467+
415468

416469
class TestSplitVectorKernelLeftRight:
417470
"""Tests for SplitMode.LEFT_RIGHT (halve width, dim 1)."""
@@ -515,3 +568,35 @@ def main_aiv(
515568
return out_0_store
516569

517570
_assert_split_matches_expected(Before, Expected)
571+
572+
def test_singleton_broadcast_tile_preserved_left_right(self):
    """Broadcast tile [128, 1] on split axis dim1 must stay unchanged under LEFT_RIGHT."""

    @pl.program
    class Before:
        @pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.LEFT_RIGHT})
        def main_aiv(
            self,
            data: pl.Tensor[[16, 128], pl.FP32],
            gamma: pl.Tensor[[16, 1], pl.FP32],
            out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
        ) -> pl.Tensor[[16, 128], pl.FP32]:
            prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
            )
            gamma_tile: pl.Tile[[16, 1], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                gamma, [0, 0], [16, 1], target_memory=pl.MemorySpace.Vec
            )
            result: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.row_expand_mul(prev, gamma_tile)
            out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(result, [0, 0], out_0)
            return out_0_store

    # Run the pass and inspect the printed result.
    split_program = _run_split_vector_kernel(Before)
    rendered = python_print(split_program)
    assert split_program.get_function("main_aiv") is not None
    # The full [16, 128] tile is halved to [16, 64] and offset along dim 1.
    assert "pl.tile.get_subblock_idx()" in rendered
    assert "pl.tile.load(data__ssa_v0, [0, 0 + subblock_idx * 64], [16, 64], [16, 64]" in rendered
    # The singleton broadcast tile [16, 1] keeps its shape and zero offset.
    assert "pl.tile.load(gamma__ssa_v0, [0, 0], [16, 1], [16, 1]" in rendered
    assert "pl.tile.row_expand_mul(" in rendered
    assert "pl.tile.store(" in rendered

0 commit comments

Comments (0)