Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions src/ir/transforms/split_vector_kernel_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,24 @@ std::vector<std::pair<std::string, std::any>> WithSplitAttr(const FunctionPtr& f
return attrs;
}

ExprPtr ComputeHalfDimSize(const ExprPtr& dim_size) {
// Outcome of attempting to halve one dimension for the split.
struct SplitDimInfo {
  // Size expression to use for the dimension after the attempt: the halved
  // size when the split was applied, otherwise the original size.
  ExprPtr dim_size;
  // True when dim_size was halved; false when the dimension was kept as-is
  // (a preserved singleton). Defaults to false so a default-initialized
  // instance never carries an indeterminate flag.
  bool split_applied = false;
};

// Computes the size of a dimension after it is split in half.
//
// dim_size: size expression of the dimension being split.
// preserve_singleton: when true, a constant size of exactly 1 is returned
//   unchanged with split_applied == false instead of being rejected as odd.
//
// Returns the resulting size expression together with a flag telling whether
// halving was actually applied.
// Throws pypto::ValueError when the size is an odd constant (including 1 when
// preserve_singleton is false).
SplitDimInfo ComputeSplitDimInfo(const ExprPtr& dim_size, bool preserve_singleton = false) {
if (auto ci = std::dynamic_pointer_cast<const ConstInt>(dim_size)) {
// Singleton dimension (e.g. a [1, N] broadcast tile): keep it and report that no split happened.
if (preserve_singleton && ci->value_ == 1) {
return SplitDimInfo{dim_size, false};
}
// An odd constant cannot be divided evenly between the two halves.
if ((ci->value_ % 2) != 0) {
throw pypto::ValueError("SplitVectorKernel requires an even split dimension, got " +
std::to_string(ci->value_));
}
return std::make_shared<ConstInt>(ci->value_ / 2, ci->dtype(), ci->span_);
return SplitDimInfo{std::make_shared<ConstInt>(ci->value_ / 2, ci->dtype(), ci->span_), true};
}
// Non-constant size: build a floor division by a constant 2 whose dtype comes from GetScalarDtype(dim_size).
auto two = std::make_shared<ConstInt>(2, GetScalarDtype(dim_size), dim_size->span_);
return MakeFloorDiv(dim_size, two, dim_size->span_);
return SplitDimInfo{MakeFloorDiv(dim_size, two, dim_size->span_), true};
}

CallPtr RebuildCallWithSplit(const CallPtr& call, int split_int) {
Expand All @@ -152,31 +160,31 @@ CallPtr RebuildCallWithSplit(const CallPtr& call, int split_int) {
return std::make_shared<Call>(call->op_, call->args_, std::move(new_kwargs), call->GetType(), call->span_);
}

// Returns a TileType equal to `type` except that dimension `dim` of its shape
// (and of TileView.valid_shape, when a tile view is present and has that
// dimension) is replaced by the split size from ComputeSplitDimInfo.
//
// type: type to rewrite; returned unchanged when it is not a TileType.
// dim: index of the dimension to split; `type` is returned unchanged when it is
//   negative or not smaller than the shape rank.
// preserve_singleton: forwarded to ComputeSplitDimInfo.
TypePtr HalveTileShape(const TypePtr& type, int dim) {
TypePtr HalveTileShape(const TypePtr& type, int dim, bool preserve_singleton = false) {
auto tt = std::dynamic_pointer_cast<const TileType>(type);
if (!tt || dim < 0 || dim >= static_cast<int>(tt->shape_.size())) return type;

// Work on a copy of the shape so the original type is not modified.
std::vector<ExprPtr> new_shape = tt->shape_;
new_shape[dim] = ComputeHalfDimSize(tt->shape_[dim]);
new_shape[dim] = ComputeSplitDimInfo(tt->shape_[dim], preserve_singleton).dim_size;

// Keep TileView.valid_shape consistent with halved physical shape (was left at pre-split size).
std::optional<TileView> new_tile_view = tt->tile_view_;
if (const auto& tile_view = tt->tile_view_; tile_view.has_value()) {
TileView tv = tile_view.value();
// valid_shape may have fewer entries than shape_; only touch it when `dim` exists there.
if (dim < static_cast<int>(tv.valid_shape.size())) {
tv.valid_shape[dim] = ComputeHalfDimSize(tv.valid_shape[dim]);
tv.valid_shape[dim] = ComputeSplitDimInfo(tv.valid_shape[dim], preserve_singleton).dim_size;
}
new_tile_view = std::move(tv);
}

// dtype, memref and memory space are carried over from the original tile type.
return std::make_shared<TileType>(new_shape, tt->dtype_, tt->memref_, new_tile_view, tt->memory_space_);
}

// Returns a MakeTuple equal to `tuple_expr` except that element `dim` is
// replaced by the split size from ComputeSplitDimInfo.
//
// tuple_expr: expression to rewrite; returned unchanged when it is not a
//   MakeTuple.
// dim: index of the element to split; `tuple_expr` is returned unchanged when
//   it is negative or not smaller than the element count.
// preserve_singleton: forwarded to ComputeSplitDimInfo.
ExprPtr HalveTupleElement(const ExprPtr& tuple_expr, int dim) {
ExprPtr HalveTupleElement(const ExprPtr& tuple_expr, int dim, bool preserve_singleton = false) {
auto tuple = std::dynamic_pointer_cast<const MakeTuple>(tuple_expr);
if (!tuple || dim < 0 || dim >= static_cast<int>(tuple->elements_.size())) return tuple_expr;
// Copy the elements so the original tuple expression is left untouched.
std::vector<ExprPtr> new_elements = tuple->elements_;
new_elements[dim] = ComputeHalfDimSize(new_elements[dim]);
new_elements[dim] = ComputeSplitDimInfo(new_elements[dim], preserve_singleton).dim_size;
// The rebuilt tuple keeps the source span of the original expression.
return std::make_shared<MakeTuple>(std::move(new_elements), tuple_expr->span_);
}

Expand Down Expand Up @@ -289,7 +297,7 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
auto new_var =
std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
TileInfo info{ComputeHalfDimSize(tt->shape_[split_dim])};
TileInfo info{ComputeSplitDimInfo(tt->shape_[split_dim]).dim_size};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The handling of the tile.tpop_from_aic operation is incomplete: it will still crash when it encounters a singleton dimension on the split axis.

  1. This line calls ComputeSplitDimInfo with the default preserve_singleton=false, so a dimension of size 1 is treated as an odd size and a ValueError is thrown.
  2. Additionally, tile_vars should only be updated if the split was actually applied (split_applied == true). Otherwise, subsequent tile.store operations will incorrectly adjust offsets for the non-split singleton tile.
  3. The helper RebuildTpopWithHalvedShape (called at line 296) also needs to accept the preserve_singleton flag and forward it to HalveTileShape; otherwise type reconstruction will crash in the same way.
References
  1. When an AIV operation produces a TileType, ensure that any shape-related arguments within the Call itself are also updated (e.g., halved) to maintain type consistency and prevent failures in subsequent passes or codegen.

tile_vars[assign->var_.get()] = info;
tile_vars[new_var.get()] = info;
}
Expand All @@ -300,24 +308,24 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
// AIV only: tile.load — halve result shape, halve shape/valid_shape args, adjust offset
if (is_aiv && op_name == "tile.load" && call->args_.size() >= 4) {
auto tt = std::dynamic_pointer_cast<const TileType>(call->GetType());
ExprPtr half_dim_size;
std::optional<SplitDimInfo> split_info;
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
half_dim_size = ComputeHalfDimSize(tt->shape_[split_dim]);
split_info = ComputeSplitDimInfo(tt->shape_[split_dim], /*preserve_singleton=*/true);
}

auto new_result_type = HalveTileShape(call->GetType(), split_dim);
auto new_result_type = HalveTileShape(call->GetType(), split_dim, /*preserve_singleton=*/true);
std::vector<ExprPtr> new_args = call->args_;
if (half_dim_size) {
new_args[1] = AdjustOffsets(call->args_[1], split_dim, half_dim_size, subblock_idx);
if (split_info.has_value() && split_info->split_applied) {
new_args[1] = AdjustOffsets(call->args_[1], split_dim, split_info->dim_size, subblock_idx);
}
new_args[2] = HalveTupleElement(call->args_[2], split_dim);
new_args[3] = HalveTupleElement(call->args_[3], split_dim);
new_args[2] = HalveTupleElement(call->args_[2], split_dim, /*preserve_singleton=*/true);
new_args[3] = HalveTupleElement(call->args_[3], split_dim, /*preserve_singleton=*/true);

auto new_call =
std::make_shared<Call>(call->op_, std::move(new_args), call->kwargs_, new_result_type, call->span_);
auto new_var = std::make_shared<Var>(assign->var_->name_hint_, new_result_type, assign->var_->span_);
if (half_dim_size) {
TileInfo info{half_dim_size};
if (split_info.has_value() && split_info->split_applied) {
TileInfo info{split_info->dim_size};
tile_vars[assign->var_.get()] = info;
tile_vars[new_var.get()] = info;
}
Expand Down Expand Up @@ -345,18 +353,20 @@ StmtPtr ProcessStmt(const StmtPtr& stmt, SplitMode mode, int split_int, int spli
if (is_aiv) {
auto tt = std::dynamic_pointer_cast<const TileType>(call->GetType());
if (tt && split_dim < static_cast<int>(tt->shape_.size())) {
auto half_dim_size = ComputeHalfDimSize(tt->shape_[split_dim]);
auto new_result_type = HalveTileShape(call->GetType(), split_dim);
auto split_info = ComputeSplitDimInfo(tt->shape_[split_dim], /*preserve_singleton=*/true);
auto new_result_type = HalveTileShape(call->GetType(), split_dim, /*preserve_singleton=*/true);
std::vector<ExprPtr> new_args = call->args_;
if ((op_name == "tile.full" || op_name == "tile.create") && call->args_.size() >= 1) {
new_args[0] = HalveTupleElement(call->args_[0], split_dim);
new_args[0] = HalveTupleElement(call->args_[0], split_dim, /*preserve_singleton=*/true);
}
auto new_call = std::make_shared<Call>(call->op_, std::move(new_args), call->kwargs_, new_result_type,
call->span_);
auto new_var = std::make_shared<Var>(assign->var_->name_hint_, new_result_type, assign->var_->span_);
TileInfo info{half_dim_size};
tile_vars[assign->var_.get()] = info;
tile_vars[new_var.get()] = info;
if (split_info.split_applied) {
TileInfo info{split_info.dim_size};
tile_vars[assign->var_.get()] = info;
tile_vars[new_var.get()] = info;
}
var_replacements[assign->var_.get()] = new_var;
return std::make_shared<AssignStmt>(new_var, new_call, assign->span_);
}
Expand Down
46 changes: 46 additions & 0 deletions tests/ut/ir/transforms/test_split_vector_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,52 @@ def main_aiv(

_assert_split_matches_expected(Before, Expected)

def test_load_preserves_singleton_broadcast_dim(self):
"""UP_DOWN split should preserve singleton broadcast tiles like [1, N] in AIV."""

# Input program: an AIV function tagged with the UP_DOWN split that loads a
# [16, 128] data tile and a [1, 128] gamma tile, multiplies them with
# col_expand_mul and stores the [16, 128] result.
@pl.program
class Before:
@pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.UP_DOWN})
def main_aiv(
self,
data: pl.Tensor[[16, 128], pl.FP32],
gamma: pl.Tensor[[1, 128], pl.FP32],
out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
) -> pl.Tensor[[16, 128], pl.FP32]:
prev: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
data, [0, 0], [16, 128], target_memory=pl.MemorySpace.Vec
)
gamma_tile: pl.Tile[[1, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
gamma, [0, 0], [1, 128], target_memory=pl.MemorySpace.Vec
)
result: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Vec] = pl.col_expand_mul(prev, gamma_tile)
out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(result, [0, 0], out_0)
return out_0_store

# Expected program: the data load, the result tile and the store work on
# [8, 128] halves whose row offset is shifted by subblock_idx * 8, while the
# gamma load keeps its [1, 128] shape and its [0, 0] offsets.
@pl.program
class Expected:
@pl.function(type=pl.FunctionType.AIV, attrs={"split": pl.SplitMode.UP_DOWN})
def main_aiv(
self,
data: pl.Tensor[[16, 128], pl.FP32],
gamma: pl.Tensor[[1, 128], pl.FP32],
out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
) -> pl.Tensor[[16, 128], pl.FP32]:
subblock_idx: pl.Scalar[pl.INT64] = pl.tile.get_subblock_idx()
prev: pl.Tile[[8, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
data, [0 + subblock_idx * 8, 0], [8, 128], target_memory=pl.MemorySpace.Vec
)
gamma_tile: pl.Tile[[1, 128], pl.FP32, pl.MemorySpace.Vec] = pl.load(
gamma, [0, 0], [1, 128], target_memory=pl.MemorySpace.Vec
)
result: pl.Tile[[8, 128], pl.FP32, pl.MemorySpace.Vec] = pl.col_expand_mul(prev, gamma_tile)
out_0_store: pl.Tensor[[16, 128], pl.FP32] = pl.store(
result, [0 + subblock_idx * 8, 0], out_0
)
return out_0_store

# NOTE(review): presumably runs the split pass on Before and compares the
# result with Expected - confirm against the helper's definition.
_assert_split_matches_expected(Before, Expected)

def test_loop_iter_arg_keeps_split_tracking(self):
"""Loop iter_args seeded by halved tiles must keep split-aware store offsets."""

Expand Down
Loading