Summary
APIs like expand_clone and scatter are composite operations: a single high-level API that, depending on its arguments (e.g., mode, reduction type, dtype, layout), must lower to completely different sets of underlying instructions. There is currently no principled mechanism for expressing and implementing this parameter-dependent lowering. A general framework is needed so composite APIs can be defined cleanly, with their lowering rules registered per-parameter-variant.
Motivation / Use Case
The problem with the current approach:
When implementing a composite op like expand_clone, the lowering logic ends up as a monolithic function full of parameter checks:
```python
# Current: monolithic lowering with scattered conditionals
def lower_expand_clone(op, context):
    if op.mode == ExpandMode.BROADCAST:
        # emit broadcast instructions
        ...
    elif op.mode == ExpandMode.REPEAT:
        # emit repeat instructions
        ...
    elif op.mode == ExpandMode.TILE:
        # emit tile instructions
        ...
    # grows unboundedly as more modes are added
```
Similarly for scatter:
```python
def lower_scatter(op, context):
    if op.reduction == ReductionKind.NONE:
        ...  # emit direct scatter
    elif op.reduction == ReductionKind.ADD:
        ...  # emit atomic add scatter
    elif op.reduction == ReductionKind.MAX:
        ...  # emit atomic max scatter
    # ...
```
Problems:
- Adding a new variant requires modifying the existing lowering function
- No compile-time or registration-time guarantee that all variants are handled
What is needed:
A mechanism where each (op, parameter-variant) pair can register its own lowering rule independently, without modifying the lowering of any other variant.
Alternatives Considered
- Subclass per variant (e.g., `BroadcastExpandClone`, `RepeatExpandClone`): pushes the dispatch to the IR construction site and bloats the IR node hierarchy. Users must pick the concrete subclass, which defeats the purpose of a unified high-level API.
- Pattern-matching passes: each variant is a separate pass that matches on `(op_type + parameter predicate)` and rewrites. This is workable but heavyweight for simple variants, and the ordering between variant passes must be managed carefully.
Recommendation: a lightweight registry with decorator-based registration, similar to how PyTorch's `CompositeImplicitAutograd` and `torch.library` work.
Proposed API / Behavior
A `CompositeLoweringRegistry` (or equivalent) that:
- registers lowering rules keyed by `(OpType, **parameter_constraints)`, e.g., `(ExpandClone, mode=ExpandMode.BROADCAST)`.
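The following is a minimal sketch of such a registry in plain Python with no dependencies. Every name in it is hypothetical: `CompositeLoweringRegistry`, `register`, `lower`, `check_coverage`, and the stand-in `ExpandMode` enum and `ExpandClone` dataclass. Each rule returns a placeholder list of strings rather than emitting real instructions. Dispatch walks the rules registered for the op's type in registration order and calls the first one whose constraints all equal the op's attributes. `check_coverage` is one possible way to get the registration-time guarantee that the monolithic function lacks.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable


class CompositeLoweringRegistry:
    """Hypothetical registry: (op type, parameter constraints) -> lowering rule."""

    def __init__(self) -> None:
        # op type -> list of (constraints, lowering function), in registration order
        self._rules: dict[type, list[tuple[dict[str, Any], Callable]]] = {}

    def register(self, op_type: type, **constraints: Any) -> Callable:
        """Decorator: register `fn` as the lowering for ops matching `constraints`."""
        def decorator(fn: Callable) -> Callable:
            rules = self._rules.setdefault(op_type, [])
            # Reject a second rule with exactly the same constraints.
            if any(existing == constraints for existing, _ in rules):
                raise ValueError(f"duplicate rule for {op_type.__name__} {constraints}")
            rules.append((constraints, fn))
            return fn
        return decorator

    def lower(self, op: Any, context: Any) -> Any:
        """Call the first rule whose constraints all match the op's attributes."""
        for constraints, fn in self._rules.get(type(op), []):
            if all(getattr(op, name) == value for name, value in constraints.items()):
                return fn(op, context)
        raise NotImplementedError(f"no lowering registered for {op!r}")

    def check_coverage(self, op_type: type, param: str, values: Any) -> None:
        """Raise if any value of `param` has no registered rule."""
        covered = {c[param] for c, _ in self._rules.get(op_type, []) if param in c}
        missing = [v for v in values if v not in covered]
        if missing:
            raise RuntimeError(f"{op_type.__name__}: no lowering for {param}={missing}")


class ExpandMode(Enum):
    BROADCAST = "broadcast"
    REPEAT = "repeat"
    TILE = "tile"


@dataclass
class ExpandClone:
    # Stand-in for the real IR node; only the attribute used for dispatch.
    mode: ExpandMode


registry = CompositeLoweringRegistry()


@registry.register(ExpandClone, mode=ExpandMode.BROADCAST)
def lower_expand_broadcast(op, context):
    return ["broadcast"]  # placeholder for the emitted broadcast instructions


@registry.register(ExpandClone, mode=ExpandMode.REPEAT)
def lower_expand_repeat(op, context):
    return ["repeat"]


@registry.register(ExpandClone, mode=ExpandMode.TILE)
def lower_expand_tile(op, context):
    return ["tile"]


# Fails loudly if a member is added to ExpandMode without a matching rule.
registry.check_coverage(ExpandClone, "mode", ExpandMode)
print(registry.lower(ExpandClone(ExpandMode.REPEAT), context=None))  # prints ['repeat']
```

With this shape, adding a mode means registering one more decorated function, and no existing rule is edited. Because rules are matched in registration order, overlapping constraints would need an explicit tie-breaking policy. This sketch only rejects exact duplicates.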
Additional Context
`expand_clone`, `scatter_`, `gather`: these would be the first consumers of this mechanism.
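Some of these consumers dispatch on more than one argument; the summary lists mode, reduction type, dtype, and layout. The sketch below shows one way `scatter_` variants might be keyed on `reduction` and `dtype` together. A rule may constrain only some parameters. When several rules match, the one with the most constraints wins, and a tie is rejected as ambiguous.

This specificity policy is an assumption, not part of the proposal. `SpecificityRegistry`, the `Scatter` dataclass, the `"float32"` dtype strings, and the returned descriptions are all hypothetical stand-ins.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable


class ReductionKind(Enum):
    NONE = "none"
    ADD = "add"
    MAX = "max"


@dataclass
class Scatter:
    # Hypothetical stand-in for the scatter_ IR node.
    reduction: ReductionKind
    dtype: str


class SpecificityRegistry:
    """Hypothetical registry where the most-constrained matching rule wins."""

    def __init__(self) -> None:
        self._rules: list[tuple[type, dict[str, Any], Callable]] = []

    def register(self, op_type: type, **constraints: Any) -> Callable:
        def decorator(fn: Callable) -> Callable:
            self._rules.append((op_type, constraints, fn))
            return fn
        return decorator

    def lower(self, op: Any, context: Any = None) -> Any:
        # Collect (number of constraints, rule) for every rule that matches the op.
        matches = [
            (len(c), fn)
            for t, c, fn in self._rules
            if t is type(op) and all(getattr(op, k) == v for k, v in c.items())
        ]
        if not matches:
            raise NotImplementedError(f"no lowering registered for {op!r}")
        best = max(n for n, _ in matches)
        winners = [fn for n, fn in matches if n == best]
        if len(winners) > 1:
            raise RuntimeError(f"ambiguous lowering for {op!r}")
        return winners[0](op, context)


registry = SpecificityRegistry()


@registry.register(Scatter, reduction=ReductionKind.NONE)
def scatter_direct(op, context):
    return "direct scatter"


@registry.register(Scatter, reduction=ReductionKind.ADD)
def scatter_atomic_add(op, context):
    return "atomic add scatter"


# Hypothetical dtype-specific variant: more constraints, so it wins when it matches.
@registry.register(Scatter, reduction=ReductionKind.ADD, dtype="float32")
def scatter_atomic_add_f32(op, context):
    return "atomic add scatter (float32 path)"


@registry.register(Scatter, reduction=ReductionKind.MAX)
def scatter_atomic_max(op, context):
    return "atomic max scatter"


print(registry.lower(Scatter(ReductionKind.ADD, "int32")))    # prints: atomic add scatter
print(registry.lower(Scatter(ReductionKind.ADD, "float32")))  # prints: atomic add scatter (float32 path)
```

Under this policy a generic rule (`reduction=ADD`) and a narrower one (`reduction=ADD, dtype="float32"`) can coexist without any ordering between them. Registration order would matter only if it were used as a tie-breaker, which this sketch avoids.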