Skip to content

Commit 19f8526

Browse files
committed
feat(ir): Added expand_clone ops and substitute_tiles pass
1 parent d765fc0 commit 19f8526

28 files changed

Lines changed: 897 additions & 144 deletions

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ set(PYPTO_SOURCES
148148
src/ir/transforms/expand_mixed_kernel_pass.cpp
149149
src/ir/transforms/split_vector_kernel_pass.cpp
150150
src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
151+
src/ir/transforms/substitute_tiles_pass.cpp
151152
src/ir/transforms/pass_context.cpp
152153
src/ir/transforms/passes.cpp
153154
src/ir/transforms/resolve_backend_op_layouts_pass.cpp

docs/en/dev/passes/00-pass_manager.md

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ struct PassProperties {
6868
| OutlineIncoreScopes | TypeChecked, SSAForm | SplitIncoreOrch | — |
6969
| OutlineClusterScopes | TypeChecked, SSAForm | ClusterOutlined | — |
7070
| ConvertTensorToTileOps | SplitIncoreOrch | IncoreTileOps | — |
71+
| SubstituteTiles | SSAForm, IncoreTileOps, NormalizedStmtStructure | SSAForm, IncoreTileOps, NormalizedStmtStructure | — |
7172
| FlattenTileNdTo2D | SSAForm, IncoreTileOps | SSAForm, TileOps2D | — |
7273
| ResolveBackendOpLayouts | SSAForm, IncoreTileOps, SplitIncoreOrch, TileOps2D | SSAForm, IncoreTileOps, SplitIncoreOrch, TileOps2D | NormalizedStmtStructure |
7374
| ExpandMixedKernel | SSAForm, IncoreTileOps, SplitIncoreOrch, TileOps2D | SSAForm, MixedKernelExpanded | — |
@@ -359,15 +360,16 @@ with passes.PassContext([passes.VerificationInstrument(passes.VerificationMode.A
359360

360361
The PTO-oriented tile stage shared by `Default` and `DebugTileOptimization` is:
361362

362-
1. `FlattenTileNdTo2D`
363-
2. `InferTileMemorySpace`
364-
3. `ResolveTransposeLayout`
365-
4. `ResolveBackendOpLayouts`
366-
5. `ExpandMixedKernel`
367-
6. `InitMemRef`
368-
7. `MemoryReuse`
369-
8. `LegalizePTOBufferReuse`
370-
9. `AllocateMemoryAddr`
363+
1. `SubstituteTiles`
364+
2. `FlattenTileNdTo2D`
365+
3. `InferTileMemorySpace`
366+
4. `ResolveTransposeLayout`
367+
5. `ResolveBackendOpLayouts`
368+
6. `ExpandMixedKernel`
369+
7. `InitMemRef`
370+
8. `MemoryReuse`
371+
9. `LegalizePTOBufferReuse`
372+
10. `AllocateMemoryAddr`
371373

372374
`DebugTileOptimization` is a debug-only strategy for inspecting this tile stage
373375
without the tensor-only prefix passes. Use `Default` for normal compilation and

docs/zh-cn/dev/passes/00-pass_manager.md

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ struct PassProperties {
6868
| OutlineIncoreScopes | TypeChecked, SSAForm | SplitIncoreOrch | — |
6969
| OutlineClusterScopes | TypeChecked, SSAForm | ClusterOutlined | — |
7070
| ConvertTensorToTileOps | SplitIncoreOrch | IncoreTileOps | — |
71+
| SubstituteTiles | SSAForm, IncoreTileOps, NormalizedStmtStructure | SSAForm, IncoreTileOps, NormalizedStmtStructure | — |
7172
| FlattenTileNdTo2D | SSAForm, IncoreTileOps | SSAForm, TileOps2D | — |
7273
| ResolveBackendOpLayouts | SSAForm, IncoreTileOps, SplitIncoreOrch, TileOps2D | SSAForm, IncoreTileOps, SplitIncoreOrch, TileOps2D | NormalizedStmtStructure |
7374
| ExpandMixedKernel | SSAForm, IncoreTileOps, SplitIncoreOrch, TileOps2D | SSAForm, MixedKernelExpanded | — |
@@ -359,15 +360,16 @@ with passes.PassContext([passes.VerificationInstrument(passes.VerificationMode.A
359360

360361
`Default` 和 `DebugTileOptimization` 共享的 PTO tile 阶段顺序为:
361362

362-
1. `FlattenTileNdTo2D`
363-
2. `InferTileMemorySpace`
364-
3. `ResolveTransposeLayout`
365-
4. `ResolveBackendOpLayouts`
366-
5. `ExpandMixedKernel`
367-
6. `InitMemRef`
368-
7. `MemoryReuse`
369-
8. `LegalizePTOBufferReuse`
370-
9. `AllocateMemoryAddr`
363+
1. `SubstituteTiles`
364+
2. `FlattenTileNdTo2D`
365+
3. `InferTileMemorySpace`
366+
4. `ResolveTransposeLayout`
367+
5. `ResolveBackendOpLayouts`
368+
6. `ExpandMixedKernel`
369+
7. `InitMemRef`
370+
8. `MemoryReuse`
371+
9. `LegalizePTOBufferReuse`
372+
10. `AllocateMemoryAddr`
371373

372374
`DebugTileOptimization` 只是用于排查 PTO tile 阶段的调试策略,会跳过
373375
tensor-only 前缀 pass。正常编译和非 strategy 专项测试都应优先使用

include/pypto/ir/transforms/pass_properties.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ inline const PassProperties kConvertTensorToTileOpsProperties{
8383
.required = {IRProperty::SSAForm, IRProperty::SplitIncoreOrch, IRProperty::NormalizedStmtStructure},
8484
.produced = {IRProperty::SSAForm, IRProperty::IncoreTileOps, IRProperty::NormalizedStmtStructure}};
8585

86+
// -- Tile op substitution pass ----------------------------------------------
87+
88+
inline const PassProperties kSubstituteTilesProperties{
89+
.required = {IRProperty::SSAForm, IRProperty::IncoreTileOps, IRProperty::NormalizedStmtStructure},
90+
.produced = {IRProperty::SSAForm, IRProperty::IncoreTileOps, IRProperty::NormalizedStmtStructure}};
91+
8692
// -- Tile ND-to-2D flattening pass --------------------------------------------
8793

8894
inline const PassProperties kFlattenTileNdTo2DProperties{

include/pypto/ir/transforms/passes.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,17 @@ Pass OutlineClusterScopes();
265265
*/
266266
Pass ConvertTensorToTileOps();
267267

268+
/**
269+
* @brief Substitute unsupported tile ops with PTO-supported tile ops
270+
*
271+
* Rewrites tile ops that lack direct PTO instruction support into equivalent
272+
* combinations of supported tile ops.
273+
*
274+
* Requirements:
275+
* - Input IR must have tile ops (run ConvertTensorToTileOps first)
276+
*/
277+
Pass SubstituteTiles();
278+
268279
/**
269280
* @brief Flatten ND tile ops to 2D in InCore functions
270281
*

include/pypto/ir/type_inference.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#ifndef PYPTO_IR_TYPE_INFERENCE_H_
2222
#define PYPTO_IR_TYPE_INFERENCE_H_
2323

24+
#include <cstddef>
2425
#include <cstdint>
2526
#include <memory>
2627
#include <optional>
@@ -187,6 +188,17 @@ bool IsBroadcastable(const ExprPtr& source_dim, const ExprPtr& target_dim);
187188
*/
188189
std::string FormatShape(const std::vector<ExprPtr>& shape);
189190

191+
int NormalizeAxis(int axis, size_t ndim);
192+
193+
int64_t ComputeShapeProduct(const std::vector<ExprPtr>& shape);
194+
195+
bool IsIndexLikeDtype(DataType dtype);
196+
197+
TileLayout InferTileLayoutFromShape(const std::vector<ExprPtr>& shape);
198+
199+
void ValidateIndexTupleElements(const TupleTypePtr& tuple_type, const std::string& op_name,
200+
const std::string& arg_name);
201+
190202
/**
191203
* @brief Propagate blayout and pad from a source TileType's tile_view into a new TileView
192204
*

python/bindings/modules/passes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ void BindPass(nb::module_& m) {
312312
"Create a pass that outlines Hierarchy scopes into separate level/role functions");
313313
passes.def("convert_tensor_to_tile_ops", &pass::ConvertTensorToTileOps,
314314
"Create a pass that converts tensor ops to tile ops in InCore functions");
315+
passes.def("substitute_tiles", &pass::SubstituteTiles,
316+
"Create a pass that substitutes unsupported tile ops with PTO-supported tile ops");
315317
passes.def("flatten_tile_nd_to_2d", &pass::FlattenTileNdTo2D,
316318
"Create a pass that flattens ND tile ops to 2D in InCore functions\n\n"
317319
"Merges all dimensions except the last into a single dimension.\n"

python/pypto/ir/op/tensor_ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,33 @@ def expands(target: Expr, scalar: int | float | Expr, span: Span | None = None)
681681
return _ir_core.create_op_call("tensor.expands", [target, scalar_expr], {}, actual_span)
682682

683683

684+
def expand_clone(
685+
tensor: Expr,
686+
shape: list[int | Expr] | _ir_core.MakeTuple,
687+
valid_shape: list[int | Expr] | _ir_core.MakeTuple | None = None,
688+
span: Span | None = None,
689+
) -> Call:
690+
"""Expand tensor to new shape by cloning elements.
691+
692+
Args:
693+
tensor: Input tensor expression
694+
shape: New shape dimensions, or a MakeTuple
695+
valid_shape: Valid shape dimensions, or a MakeTuple (optional; omitted from the op call when None)
696+
span: Optional source span for debugging (auto-captured if not provided)
697+
698+
Returns:
699+
Call expression for tensor expand_clone
700+
"""
701+
actual_span = _get_span_or_capture(span)
702+
703+
shape_tuple = _to_make_tuple(shape, actual_span)
704+
705+
args = [tensor, shape_tuple]
706+
if valid_shape is not None:
707+
args.append(_to_make_tuple(valid_shape, actual_span))
708+
return _ir_core.create_op_call("tensor.expand_clone", args, {}, actual_span)
709+
710+
684711
def exp(input: Expr, span: Span | None = None) -> Call:
685712
"""Element-wise exponential operation.
686713

python/pypto/ir/op/tile_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,29 @@ def expands(target: Expr, scalar: int | float | Expr, span: Span | None = None)
15061506
return _ir_core.create_op_call("tile.expands", [target, scalar_expr], {}, actual_span)
15071507

15081508

1509+
def expand_clone(
1510+
tile: Expr,
1511+
shape: Sequence[int | Expr] | _ir_core.MakeTuple,
1512+
span: Span | None = None,
1513+
) -> Call:
1514+
"""Expand tile to new shape by cloning elements.
1515+
1516+
Args:
1517+
tile: Input tile expression
1518+
shape: New shape dimensions, or a MakeTuple
1519+
span: Optional source span for debugging (auto-captured if not provided)
1520+
1521+
Returns:
1522+
Call expression for tile expand_clone
1523+
"""
1524+
actual_span = _get_span_or_capture(span)
1525+
1526+
shape_tuple = _to_make_tuple(shape, actual_span)
1527+
1528+
args = [tile, shape_tuple]
1529+
return _ir_core.create_op_call("tile.expand_clone", args, {}, actual_span)
1530+
1531+
15091532
def maximum(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
15101533
"""Element-wise maximum of two tiles.
15111534

python/pypto/ir/pass_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _register_passes(cls):
131131
("ConvertTensorToTileOps", lambda: passes.convert_tensor_to_tile_ops()),
132132
]
133133
tile_pto_passes: list[PassSpec] = [
134+
("SubstituteTiles", lambda: passes.substitute_tiles()),
134135
("FlattenTileNdTo2D", lambda: passes.flatten_tile_nd_to_2d()),
135136
("InferTileMemorySpace", lambda: passes.infer_tile_memory_space()),
136137
("ResolveTransposeLayout", lambda: passes.resolve_transpose_layout()),

0 commit comments

Comments
 (0)