[Feature] Composite API lowering mechanism: parameter-dependent dispatch to different underlying instructions #949

@lyfne123

Summary

APIs like expand_clone and scatter are composite operations: a single high-level API that, depending on its arguments (e.g., mode, reduction type, dtype, layout), must lower to completely different sets of underlying instructions. There is currently no principled mechanism for expressing and implementing this parameter-dependent lowering. A general framework is needed so composite APIs can be defined cleanly, with their lowering rules registered per-parameter-variant.

Motivation / Use Case

The problem with the current approach:

When implementing a composite op like expand_clone, the lowering logic ends up as a monolithic function full of parameter checks:

# Current: monolithic lowering with scattered conditionals
def lower_expand_clone(op, context):
    if op.mode == ExpandMode.BROADCAST:
        # emit broadcast instructions
        ...
    elif op.mode == ExpandMode.REPEAT:
        # emit repeat instructions
        ...
    elif op.mode == ExpandMode.TILE:
        # emit tile instructions
        ...
    # grows unboundedly as more modes are added

Similarly for scatter:

def lower_scatter(op, context):
    if op.reduction == ReductionKind.NONE:
        # emit direct scatter
        ...
    elif op.reduction == ReductionKind.ADD:
        # emit atomic add scatter
        ...
    elif op.reduction == ReductionKind.MAX:
        # emit atomic max scatter
        ...
    # ...

Problems:

What is needed:

A mechanism where each (op, parameter-variant) pair can register its own lowering rule independently:

# Proposed: each variant registers its lowering separately
@composite_lowering(ExpandClone, mode=ExpandMode.BROADCAST)
def lower_expand_clone_broadcast(op, context):
    # emit broadcast instructions only
    ...

@composite_lowering(ExpandClone, mode=ExpandMode.REPEAT)
def lower_expand_clone_repeat(op, context):
    # emit repeat instructions only
    ...

Proposed API / Behavior

A CompositeLoweringRegistry (or equivalent) that:

  1. Accepts registrations keyed by (OpType, **parameter_constraints) — e.g., (ExpandClone, mode=ExpandMode.BROADCAST)
  2. Dispatches at lowering time by matching the op's runtime parameters against registered rules, selecting the most specific match
  3. Raises a clear error if no rule matches (i.e., unimplemented variant), rather than silently falling through to wrong behavior
  4. Integrates with the existing pass infrastructure — the lowering pass queries the registry rather than containing inline dispatch logic

Sketch:

# Registration
registry = CompositeLoweringRegistry()

@registry.register(ExpandClone, mode=ExpandMode.BROADCAST)
def lower_broadcast(op, ctx): ...

@registry.register(ExpandClone, mode=ExpandMode.REPEAT)
def lower_repeat(op, ctx): ...

# Lowering pass — clean of variant logic
class ExpandCompositeLoweringPass(Pass):
    def visit_expand_clone(self, op):
        handler = self.registry.lookup(op)  # dispatches by op.mode
        handler(op, self.context)
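
One way the registry behind this sketch could be implemented is shown below. It is a minimal illustration under stated assumptions, not the project's actual API: `ExpandMode` and `ExpandClone` are stand-in definitions, the `dtype` field and `NoLoweringRuleError` are hypothetical names, and "most specific match" is assumed to mean "the matching rule with the most parameter constraints", with ties reported as errors.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Tuple

Handler = Callable[[Any, Any], Any]
_MISSING = object()  # sentinel: the op lacks the constrained attribute


class ExpandMode(Enum):  # stand-in for the project's enum
    BROADCAST = auto()
    REPEAT = auto()
    TILE = auto()


@dataclass
class ExpandClone:  # stand-in IR node; `dtype` is a hypothetical second parameter
    mode: ExpandMode
    dtype: str = "f32"


class NoLoweringRuleError(LookupError):
    """Hypothetical error for an unimplemented or ambiguous variant."""


class CompositeLoweringRegistry:
    def __init__(self) -> None:
        # op type -> list of (parameter constraints, handler)
        self._rules: Dict[type, List[Tuple[Dict[str, Any], Handler]]] = {}

    def register(self, op_type: type, **constraints: Any) -> Callable[[Handler], Handler]:
        def decorator(fn: Handler) -> Handler:
            rules = self._rules.setdefault(op_type, [])
            if any(existing == constraints for existing, _ in rules):
                raise ValueError(f"duplicate rule: {op_type.__name__} {constraints}")
            rules.append((constraints, fn))
            return fn
        return decorator

    def lookup(self, op: Any) -> Handler:
        # Keep every rule whose constraints all equal the op's attribute values.
        matches = [
            (constraints, fn)
            for constraints, fn in self._rules.get(type(op), [])
            if all(getattr(op, k, _MISSING) == v for k, v in constraints.items())
        ]
        if not matches:
            raise NoLoweringRuleError(f"no lowering rule registered for {op!r}")
        # Assumed specificity order: the rule with the most constraints wins.
        most = max(len(c) for c, _ in matches)
        winners = [fn for c, fn in matches if len(c) == most]
        if len(winners) > 1:
            raise NoLoweringRuleError(f"ambiguous lowering rules for {op!r}")
        return winners[0]


registry = CompositeLoweringRegistry()


@registry.register(ExpandClone, mode=ExpandMode.BROADCAST)
def lower_broadcast(op, ctx):
    return "broadcast"  # placeholder for the emitted instructions


@registry.register(ExpandClone, mode=ExpandMode.BROADCAST, dtype="bf16")
def lower_broadcast_bf16(op, ctx):
    return "broadcast_bf16"  # placeholder for a dtype-specialized lowering


generic = ExpandClone(ExpandMode.BROADCAST)
special = ExpandClone(ExpandMode.BROADCAST, dtype="bf16")
print(registry.lookup(generic)(generic, None))
print(registry.lookup(special)(special, None))
```

In this sketch, looking up an `ExpandClone` with `mode=ExpandMode.TILE` raises `NoLoweringRuleError` instead of falling through. Rules are keyed on the exact op type, so subclass matching, predicate-style constraints, and per-backend keys would each need an extension.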

Alternatives Considered

  • Subclass per variant (e.g., BroadcastExpandClone, RepeatExpandClone): pushes the dispatch to the IR construction site and bloats the IR node hierarchy; users must pick the concrete subclass, which defeats the purpose of a unified high-level API
  • Inline conditionals per lowering pass: the current approach; does not scale as variants and backends multiply (interacts badly with #948, "[Feature] Introduce a principled multi-backend dispatch mechanism to eliminate ad-hoc backendtype checks in passes")
  • Pattern-matching passes: each variant is a separate pass that matches on (op_type + parameter predicate) and rewrites; workable but heavyweight for simple variants, and ordering between variant passes must be managed carefully

Recommendation: a lightweight registry with decorator-based registration, similar in spirit to how PyTorch's torch.library registers operator implementations under dispatch keys such as CompositeImplicitAutograd.

Additional Context

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Project status: Backlog
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests