Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions src/ir/transforms/split_vector_kernel_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,38 @@ std::vector<std::pair<std::string, std::any>> WithSplitAttr(const FunctionPtr& f
return attrs;
}

// Returns true iff `dim_size` is a compile-time constant equal to 1.
// Dynamic (non-constant) dimensions are conservatively treated as non-singleton.
bool IsSingletonDim(const ExprPtr& dim_size) {
  const auto as_const_int = std::dynamic_pointer_cast<const ConstInt>(dim_size);
  return as_const_int != nullptr && as_const_int->value_ == 1;
}

// Returns true when `call` is a reduction op whose reduced axis is exactly
// `split_dim`, i.e. splitting the kernel on that axis would yield a partial
// (and therefore incorrect) reduction.
//
// Two op families are recognized:
//   - tile.row_sum / tile.row_max / tile.row_min: always reduce the last axis.
//   - tile.sum / tile.max / tile.min: reduce the axis given by the "axis"
//     kwarg (negative values are normalized against the input tile's rank).
// Any other op returns false.
bool IsReduceOnSplitAxis(const CallPtr& call, int split_dim) {
  if (call->op_ == nullptr) {
    return false;
  }
  const std::string& op_name = call->op_->name_;

  // Lazily resolve the tile type of the first argument (nullptr if absent or
  // not a TileType).
  const auto first_arg_tile = [&call]() -> std::shared_ptr<const TileType> {
    if (call->args_.empty()) {
      return nullptr;
    }
    return std::dynamic_pointer_cast<const TileType>(call->args_[0]->GetType());
  };

  const bool is_row_reduce =
      op_name == "tile.row_sum" || op_name == "tile.row_max" || op_name == "tile.row_min";
  if (is_row_reduce) {
    const auto tile = first_arg_tile();
    // Row reductions collapse the innermost axis. When the input tile type is
    // unavailable, fall back to axis 1 — NOTE(review): this assumes 2-D tiles;
    // confirm against the pass's supported ranks.
    const int reduced_axis = tile ? static_cast<int>(tile->shape_.size()) - 1 : 1;
    return reduced_axis == split_dim;
  }

  const bool is_axis_reduce =
      op_name == "tile.sum" || op_name == "tile.max" || op_name == "tile.min";
  if (is_axis_reduce) {
    int axis = call->GetKwarg<int>("axis", -1);
    if (axis < 0) {
      // Normalize a negative axis (e.g. -1 == last) using the input rank, when known.
      if (const auto tile = first_arg_tile()) {
        axis += static_cast<int>(tile->shape_.size());
      }
    }
    return axis == split_dim;
  }

  return false;
}

ExprPtr ComputeHalfDimSize(const ExprPtr& dim_size) {
if (auto ci = std::dynamic_pointer_cast<const ConstInt>(dim_size)) {
if ((ci->value_ % 2) != 0) {
Expand Down Expand Up @@ -297,9 +329,17 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
return std::make_shared<AssignStmt>(new_var, new_call, assign->span_);
}

// AIV only: tile.load — halve result shape, halve shape/valid_shape args, adjust offset
// AIV only: tile.load — halve result shape, halve shape/valid_shape args, adjust offset.
// Singleton split-dim tiles (e.g. broadcast [1, 128] under UP_DOWN) are preserved as-is.
if (is_aiv && op_name == "tile.load" && call->args_.size() >= 4) {
auto tt = std::dynamic_pointer_cast<const TileType>(call->GetType());
bool is_singleton =
tt && split_dim < static_cast<int>(tt->shape_.size()) && IsSingletonDim(tt->shape_[split_dim]);

if (is_singleton) {
return stmt;
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

ExprPtr half_dim_size;
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
half_dim_size = ComputeHalfDimSize(tt->shape_[split_dim]);
Expand Down Expand Up @@ -341,10 +381,21 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
}
}

// AIV only: any other op producing TileType — halve result shape (and static shape args when present)
// AIV only: any other op producing TileType — halve result shape (and static shape args when present).
// Reject reduce ops that reduce on the split axis (partial reduction is semantically incorrect).
// Skip halving when the output split-dim is singleton (broadcast / degenerate tiles).
if (is_aiv) {
if (IsReduceOnSplitAxis(call, split_dim)) {
throw pypto::ValueError("SplitVectorKernel: reduce op '" + op_name +
"' reduces on the split axis (dim " + std::to_string(split_dim) +
"); partial reduction in a split kernel is not supported");
}

auto tt = std::dynamic_pointer_cast<const TileType>(call->GetType());
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
if (IsSingletonDim(tt->shape_[split_dim])) {
return stmt;
}
auto half_dim_size = ComputeHalfDimSize(tt->shape_[split_dim]);
auto new_result_type = HalveTileShape(call->GetType(), split_dim);
std::vector<ExprPtr> new_args = call->args_;
Expand Down
85 changes: 85 additions & 0 deletions tests/ut/ir/transforms/test_split_vector_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,59 @@ def main_aic(self, x: pl.Tensor[[16, 128], pl.BF16]):

_assert_split_matches_expected(Before, Expected)

def test_singleton_broadcast_tile_preserved(self):
    """Broadcast tile [1, 128] on split axis dim0 must stay unchanged under UP_DOWN."""

    @pl.program
    class Before:
        @pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.UP_DOWN})
        def main_aiv(
            self,
            data: pl.Tensor[[16, 128], pl.FP32],
            gamma: pl.Tensor[[1, 128], pl.FP32],
            out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
        ) -> pl.Tensor[[16, 128], pl.FP32]:
            # Regular [16, 128] tile: the pass is expected to halve this one.
            prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
            )
            # Broadcast [1, 128] tile: singleton on the UP_DOWN split axis (dim 0),
            # so the pass must leave this load untouched.
            gamma_tile: pl.Tile[[1, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                gamma, [0, 0], [1, 128], target_memory=pl.MemorySpace.Vec
            )
            result: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.col_expand_mul(prev, gamma_tile)
            out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(result, [0, 0], out_0)
            return out_0_store

    actual = _run_split_vector_kernel(Before)
    printed = python_print(actual)
    main_aiv = actual.get_function("main_aiv")
    assert main_aiv is not None
    # The split kernel indexes by subblock.
    assert "pl.tile.get_subblock_idx()" in printed
    # data load halved on dim 0: [16, 128] -> [8, 128] with a subblock offset.
    assert "pl.tile.load(data__ssa_v0, [0 + subblock_idx * 8, 0], [8, 128], [8, 128]" in printed
    # gamma load preserved verbatim: singleton dim must not be halved or offset.
    assert "pl.tile.load(gamma__ssa_v0, [0, 0], [1, 128], [1, 128]" in printed
    assert "pl.tile.col_expand_mul(" in printed
    assert "pl.tile.store(" in printed

def test_reduce_on_split_axis_rejected(self):
    """Reduce on split axis (dim0 under UP_DOWN) must raise ValueError."""

    @pl.program
    class Before:
        @pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.UP_DOWN})
        def main_aiv(
            self,
            data: pl.Tensor[[16, 128], pl.FP32],
            out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
        ) -> pl.Tensor[[16, 128], pl.FP32]:
            prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
            )
            # Sum over axis 0 — the same axis UP_DOWN splits on. Splitting
            # would produce a partial reduction, so the pass must reject it.
            reduced: pl.Tile[[1, 128], pl.FP32, pl.MemorySpace.Vec] = pl.sum(prev, axis=0, keepdim=True)
            out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(reduced, [0, 0], out_0)
            return out_0_store

    # Message text comes from IsReduceOnSplitAxis's rejection path in the pass.
    with pytest.raises(ValueError, match="reduces on the split axis"):
        _run_split_vector_kernel(Before)


class TestSplitVectorKernelLeftRight:
"""Tests for SplitMode.LEFT_RIGHT (halve width, dim 1)."""
Expand Down Expand Up @@ -515,3 +568,35 @@ def main_aiv(
return out_0_store

_assert_split_matches_expected(Before, Expected)

def test_singleton_broadcast_tile_preserved_left_right(self):
    """Broadcast tile [128, 1] on split axis dim1 must stay unchanged under LEFT_RIGHT."""

    @pl.program
    class Before:
        @pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.LEFT_RIGHT})
        def main_aiv(
            self,
            data: pl.Tensor[[16, 128], pl.FP32],
            gamma: pl.Tensor[[16, 1], pl.FP32],
            out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
        ) -> pl.Tensor[[16, 128], pl.FP32]:
            # Regular [16, 128] tile: expected to be halved on dim 1.
            prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
            )
            # [16, 1] tile: singleton on the LEFT_RIGHT split axis (dim 1),
            # so the pass must leave this load untouched.
            gamma_tile: pl.Tile[[16, 1], pl.FP32, pl.MemorySpace.Vec] = pl.load(
                gamma, [0, 0], [16, 1], target_memory=pl.MemorySpace.Vec
            )
            result: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.row_expand_mul(prev, gamma_tile)
            out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(result, [0, 0], out_0)
            return out_0_store

    actual = _run_split_vector_kernel(Before)
    printed = python_print(actual)
    main_aiv = actual.get_function("main_aiv")
    assert main_aiv is not None
    assert "pl.tile.get_subblock_idx()" in printed
    # data load halved on dim 1: [16, 128] -> [16, 64] with a subblock offset.
    assert "pl.tile.load(data__ssa_v0, [0, 0 + subblock_idx * 64], [16, 64], [16, 64]" in printed
    # gamma load preserved verbatim: singleton dim must not be halved or offset.
    assert "pl.tile.load(gamma__ssa_v0, [0, 0], [16, 1], [16, 1]" in printed
    assert "pl.tile.row_expand_mul(" in printed
    assert "pl.tile.store(" in printed
Loading